Skip to content

Commit 0bf163a

Browse files
committed
feat: introduce nullable types
1 parent 4b7b22a commit 0bf163a

File tree

7 files changed

+419
-1
lines changed

7 files changed

+419
-1
lines changed

Diff for: go.mod

+2
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
module github.com/hashicorp/jsonapi
2+
3+
go 1.18

Diff for: models_test.go

+9
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ type TimestampModel struct {
3535
RFC3339P *time.Time `jsonapi:"attr,rfc3339p,rfc3339"`
3636
}
3737

38+
type WithNullables struct {
39+
ID int `jsonapi:"primary,with-nullables"`
40+
Name string `jsonapi:"attr,name"`
41+
IntTime Nullable[time.Time] `jsonapi:"attr,int_time,omitempty"`
42+
RFC3339Time Nullable[time.Time] `jsonapi:"attr,rfc3339_time,rfc3339,omitempty"`
43+
ISO8601Time Nullable[time.Time] `jsonapi:"attr,iso8601_time,iso8601,omitempty"`
44+
Bool Nullable[bool] `jsonapi:"attr,bool,omitempty"`
45+
}
46+
3847
type Car struct {
3948
ID *string `jsonapi:"primary,cars"`
4049
Make *string `jsonapi:"attr,make,omitempty"`

Diff for: nullable.go

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
package jsonapi
2+
3+
import (
4+
"errors"
5+
"reflect"
6+
"time"
7+
)
8+
9+
var supportedNullableTypes = map[string]reflect.Value{
10+
"bool": reflect.ValueOf(false),
11+
"time.Time": reflect.ValueOf(time.Time{}),
12+
}
13+
14+
// Nullable is a generic type, which implements a field that can be one of three states:
15+
//
16+
// - field is not set in the request
17+
// - field is explicitly set to `null` in the request
18+
// - field is explicitly set to a valid value in the request
19+
//
20+
// Nullable is intended to be used with JSON marshalling and unmarshalling.
21+
// This is generally useful for PATCH requests, where attributes with zero
22+
// values are intentionally not marshaled into the request payload so that
23+
// existing attribute values are not overwritten.
24+
//
25+
// Internal implementation details:
26+
//
27+
// - map[true]T means a value was provided
28+
// - map[false]T means an explicit null was provided
29+
// - nil or zero map means the field was not provided
30+
//
31+
// If the field is expected to be optional, add the `omitempty` JSON tags. Do NOT use `*Nullable`!
32+
//
33+
// Adapted from https://www.jvt.me/posts/2024/01/09/go-json-nullable/
34+
35+
type Nullable[T any] map[bool]T
36+
37+
// NewNullableWithValue is a convenience helper to allow constructing a
38+
// Nullable with a given value, for instance to construct a field inside a
39+
// struct without introducing an intermediate variable.
40+
func NewNullableWithValue[T any](t T) Nullable[T] {
41+
var n Nullable[T]
42+
n.Set(t)
43+
return n
44+
}
45+
46+
// NewNullNullable is a convenience helper to allow constructing a Nullable with
47+
// an explicit `null`, for instance to construct a field inside a struct
48+
// without introducing an intermediate variable
49+
func NewNullNullable[T any]() Nullable[T] {
50+
var n Nullable[T]
51+
n.SetNull()
52+
return n
53+
}
54+
55+
// Get retrieves the underlying value, if present, and returns an error if the value was not present
56+
func (t Nullable[T]) Get() (T, error) {
57+
var empty T
58+
if t.IsNull() {
59+
return empty, errors.New("value is null")
60+
}
61+
if !t.IsSpecified() {
62+
return empty, errors.New("value is not specified")
63+
}
64+
return t[true], nil
65+
}
66+
67+
// Set sets the underlying value to a given value
68+
func (t *Nullable[T]) Set(value T) {
69+
*t = map[bool]T{true: value}
70+
}
71+
72+
// Set sets the underlying value to a given value
73+
func (t *Nullable[T]) SetInterface(value interface{}) {
74+
t.Set(value.(T))
75+
}
76+
77+
// IsNull indicate whether the field was sent, and had a value of `null`
78+
func (t Nullable[T]) IsNull() bool {
79+
_, foundNull := t[false]
80+
return foundNull
81+
}
82+
83+
// SetNull indicate that the field was sent, and had a value of `null`
84+
func (t *Nullable[T]) SetNull() {
85+
var empty T
86+
*t = map[bool]T{false: empty}
87+
}
88+
89+
// IsSpecified indicates whether the field was sent
90+
func (t Nullable[T]) IsSpecified() bool {
91+
return len(t) != 0
92+
}
93+
94+
// SetUnspecified indicate whether the field was sent
95+
func (t *Nullable[T]) SetUnspecified() {
96+
*t = map[bool]T{}
97+
}
98+
99+
func NullableBool(v bool) Nullable[bool] {
100+
return NewNullableWithValue[bool](v)
101+
}
102+
103+
func NullBool() Nullable[bool] {
104+
return NewNullNullable[bool]()
105+
}
106+
107+
func NullableTime(v time.Time) Nullable[time.Time] {
108+
return NewNullableWithValue[time.Time](v)
109+
}
110+
111+
func NullTime() Nullable[time.Time] {
112+
return NewNullNullable[time.Time]()
113+
}

Diff for: request.go

+34-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"fmt"
88
"io"
99
"reflect"
10+
"regexp"
1011
"strconv"
1112
"strings"
1213
"time"
@@ -589,6 +590,12 @@ func unmarshalAttribute(
589590
value = reflect.ValueOf(attribute)
590591
fieldType := structField.Type
591592

593+
// Handle Nullable[T]
594+
if strings.HasPrefix(fieldValue.Type().Name(), "Nullable[") {
595+
value, err = handleNullable(attribute, args, structField, fieldValue)
596+
return
597+
}
598+
592599
// Handle field of type []string
593600
if fieldValue.Type() == reflect.TypeOf([]string{}) {
594601
value, err = handleStringSlice(attribute)
@@ -656,6 +663,32 @@ func handleStringSlice(attribute interface{}) (reflect.Value, error) {
656663
return reflect.ValueOf(values), nil
657664
}
658665

666+
func handleNullable(
667+
attribute interface{},
668+
args []string,
669+
structField reflect.StructField,
670+
fieldValue reflect.Value) (reflect.Value, error) {
671+
672+
if a, ok := attribute.(string); ok {
673+
if bytes.Equal([]byte(a), []byte("null")) {
674+
return reflect.ValueOf(nil), nil
675+
}
676+
}
677+
678+
var rgx = regexp.MustCompile(`\[(.*)\]`)
679+
rs := rgx.FindStringSubmatch(fieldValue.Type().Name())
680+
681+
attrVal, err := unmarshalAttribute(attribute, args, structField, supportedNullableTypes[rs[1]])
682+
if err != nil {
683+
return reflect.ValueOf(nil), err
684+
}
685+
686+
fieldValue.Set(reflect.MakeMap(fieldValue.Type()))
687+
fieldValue.SetMapIndex(reflect.ValueOf(true), attrVal)
688+
689+
return fieldValue, nil
690+
}
691+
659692
func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) (reflect.Value, error) {
660693
var isISO8601, isRFC3339 bool
661694
v := reflect.ValueOf(attribute)
@@ -714,7 +747,7 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value)
714747
return reflect.ValueOf(time.Now()), ErrInvalidTime
715748
}
716749

717-
t := time.Unix(at, 0)
750+
t := time.Unix(at, 0).UTC()
718751

719752
return reflect.ValueOf(t), nil
720753
}

Diff for: request_test.go

+97
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,88 @@ func TestStringPointerField(t *testing.T) {
300300
}
301301
}
302302

303+
func TestUnmarshalNullableTime(t *testing.T) {
304+
aTime := time.Date(2016, 8, 17, 8, 27, 12, 23849, time.UTC)
305+
306+
out := new(WithNullables)
307+
308+
attrs := map[string]interface{}{
309+
"name": "Name",
310+
"int_time": aTime.Unix(),
311+
"rfc3339_time": aTime.Format(time.RFC3339),
312+
"iso8601_time": aTime.Format(iso8601TimeFormat),
313+
}
314+
315+
if err := UnmarshalPayload(samplePayloadWithNullables(attrs), out); err != nil {
316+
t.Fatal(err)
317+
}
318+
319+
if out.IntTime == nil {
320+
t.Fatal("Was not expecting a nil pointer for out.IntTime")
321+
}
322+
323+
timeVal, err := out.IntTime.Get()
324+
if err != nil {
325+
t.Fatal(err)
326+
}
327+
328+
if expected, actual := aTime, timeVal; expected.Equal(actual) {
329+
t.Fatalf("Was expecting int_time to be `%s`, got `%s`", expected, actual)
330+
}
331+
332+
timeVal, err = out.IntTime.Get()
333+
if err != nil {
334+
t.Fatal(err)
335+
}
336+
337+
if out.RFC3339Time == nil {
338+
t.Fatal("Was not expecting a nil pointer for out.RFC3339Time")
339+
}
340+
if expected, actual := aTime, timeVal; expected.Equal(actual) {
341+
t.Fatalf("Was expecting descript to be `%s`, got `%s`", expected, actual)
342+
}
343+
344+
timeVal, err = out.IntTime.Get()
345+
if err != nil {
346+
t.Fatal(err)
347+
}
348+
349+
if out.ISO8601Time == nil {
350+
t.Fatal("Was not expecting a nil pointer for out.ISO8601Time")
351+
}
352+
if expected, actual := aTime, timeVal; expected.Equal(actual) {
353+
t.Fatalf("Was expecting descript to be `%s`, got `%s`", expected, actual)
354+
}
355+
}
356+
357+
func TestUnmarshalNullableBool(t *testing.T) {
358+
out := new(WithNullables)
359+
360+
aBool := false
361+
362+
attrs := map[string]interface{}{
363+
"name": "Name",
364+
"bool": aBool,
365+
}
366+
367+
if err := UnmarshalPayload(samplePayloadWithNullables(attrs), out); err != nil {
368+
t.Fatal(err)
369+
}
370+
371+
if out.Bool == nil {
372+
t.Fatal("Was not expecting a nil pointer for out.Bool")
373+
}
374+
375+
boolVal, err := out.Bool.Get()
376+
if err != nil {
377+
t.Fatal(err)
378+
}
379+
380+
if expected, actual := aBool, boolVal; expected != actual {
381+
t.Fatalf("Was expecting bool to be `%t`, got `%t`", expected, actual)
382+
}
383+
}
384+
303385
func TestMalformedTag(t *testing.T) {
304386
out := new(BadModel)
305387
err := UnmarshalPayload(samplePayload(), out)
@@ -1426,6 +1508,21 @@ func sampleWithPointerPayload(m map[string]interface{}) io.Reader {
14261508
return out
14271509
}
14281510

1511+
func samplePayloadWithNullables(m map[string]interface{}) io.Reader {
1512+
payload := &OnePayload{
1513+
Data: &Node{
1514+
ID: "5",
1515+
Type: "with-nullables",
1516+
Attributes: m,
1517+
},
1518+
}
1519+
1520+
out := bytes.NewBuffer(nil)
1521+
json.NewEncoder(out).Encode(payload)
1522+
1523+
return out
1524+
}
1525+
14291526
func testModel() *Blog {
14301527
return &Blog{
14311528
ID: 5,

Diff for: response.go

+16
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,22 @@ func visitModelNode(model interface{}, included *map[string]*Node,
331331
node.Attributes = make(map[string]interface{})
332332
}
333333

334+
// Handle Nullable[T]
335+
if strings.HasPrefix(fieldValue.Type().Name(), "Nullable[") {
336+
// handle unspecified
337+
if fieldValue.IsNil() {
338+
continue
339+
}
340+
341+
// handle null
342+
if fieldValue.MapIndex(reflect.ValueOf(false)).IsValid() {
343+
continue
344+
}
345+
346+
// handle value
347+
fieldValue = fieldValue.MapIndex(reflect.ValueOf(true))
348+
}
349+
334350
if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
335351
t := fieldValue.Interface().(time.Time)
336352

0 commit comments

Comments
 (0)