Skip to content

Commit d9372f5

Browse files
dushaoshuaidavid_dupropel-code-bot[bot]
authored
fix(UnixSecondSerializer.Value): Avoid panic when handling unsigned integer values (#7608)
* fix reflection panic * testing * fix: do not pass a nil Context * //nolint: gosec * reduce cyclomatic complexity * add a type assertion safety check for time.Time Co-authored-by: propel-code-bot[bot] <203372662+propel-code-bot[bot]@users.noreply.github.com> * test coverage for integer overflow edge case --------- Co-authored-by: david_du <[email protected]> Co-authored-by: propel-code-bot[bot] <203372662+propel-code-bot[bot]@users.noreply.github.com>
1 parent d8cdb39 commit d9372f5

File tree

2 files changed

+235
-6
lines changed

2 files changed

+235
-6
lines changed

schema/serializer.go

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"encoding/gob"
99
"encoding/json"
1010
"fmt"
11+
"math"
1112
"reflect"
1213
"strings"
1314
"sync"
@@ -127,16 +128,31 @@ func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.
127128
// Value implements serializer interface
128129
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
129130
rv := reflect.ValueOf(fieldValue)
130-
switch v := fieldValue.(type) {
131-
case int64, int, uint, uint64, int32, uint32, int16, uint16:
132-
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
133-
case *int64, *int, *uint, *uint64, *int32, *uint32, *int16, *uint16:
131+
switch fieldValue.(type) {
132+
case int, int8, int16, int32, int64:
133+
result = time.Unix(rv.Int(), 0).UTC()
134+
case uint, uint8, uint16, uint32, uint64:
135+
if uv := rv.Uint(); uv > math.MaxInt64 {
136+
err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv)
137+
} else {
138+
result = time.Unix(int64(uv), 0).UTC() //nolint:gosec
139+
}
140+
case *int, *int8, *int16, *int32, *int64:
134141
if rv.IsZero() {
135142
return nil, nil
136143
}
137-
result = time.Unix(reflect.Indirect(rv).Int(), 0).UTC()
144+
result = time.Unix(rv.Elem().Int(), 0).UTC()
145+
case *uint, *uint8, *uint16, *uint32, *uint64:
146+
if rv.IsZero() {
147+
return nil, nil
148+
}
149+
if uv := rv.Elem().Uint(); uv > math.MaxInt64 {
150+
err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv)
151+
} else {
152+
result = time.Unix(int64(uv), 0).UTC() //nolint:gosec
153+
}
138154
default:
139-
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", v)
155+
err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue)
140156
}
141157
return
142158
}

schema/serializer_test.go

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
package schema
2+
3+
import (
4+
"context"
5+
"math"
6+
"reflect"
7+
"testing"
8+
"time"
9+
)
10+
11+
func TestUnixSecondSerializer_Value(t *testing.T) {
12+
var (
13+
intValue = math.MaxInt64
14+
int8Value = int8(math.MaxInt8)
15+
int16Value = int16(math.MaxInt16)
16+
int32Value = int32(math.MaxInt32)
17+
int64Value = int64(math.MaxInt64)
18+
uintValue = uint(math.MaxInt64)
19+
uint8Value = uint8(math.MaxUint8)
20+
uint16Value = uint16(math.MaxUint16)
21+
uint32Value = uint32(math.MaxUint32)
22+
uint64Value = uint64(math.MaxInt64)
23+
maxInt64Plus1 = uint64(math.MaxInt64 + 1)
24+
25+
intPtrValue = &intValue
26+
int8PtrValue = &int8Value
27+
int16PtrValue = &int16Value
28+
int32PtrValue = &int32Value
29+
int64PtrValue = &int64Value
30+
uintPtrValue = &uintValue
31+
uint8PtrValue = &uint8Value
32+
uint16PtrValue = &uint16Value
33+
uint32PtrValue = &uint32Value
34+
uint64PtrValue = &uint64Value
35+
maxInt64Plus1Ptr = &maxInt64Plus1
36+
)
37+
tests := []struct {
38+
name string
39+
value interface{}
40+
want interface{}
41+
wantErr bool
42+
}{
43+
{
44+
name: "int",
45+
value: intValue,
46+
want: time.Unix(int64(intValue), 0).UTC(),
47+
wantErr: false,
48+
},
49+
{
50+
name: "int8",
51+
value: int8Value,
52+
want: time.Unix(int64(int8Value), 0).UTC(),
53+
wantErr: false,
54+
},
55+
{
56+
name: "int16",
57+
value: int16Value,
58+
want: time.Unix(int64(int16Value), 0).UTC(),
59+
wantErr: false,
60+
},
61+
{
62+
name: "int32",
63+
value: int32Value,
64+
want: time.Unix(int64(int32Value), 0).UTC(),
65+
wantErr: false,
66+
},
67+
{
68+
name: "int64",
69+
value: int64Value,
70+
want: time.Unix(int64Value, 0).UTC(),
71+
wantErr: false,
72+
},
73+
{
74+
name: "uint",
75+
value: uintValue,
76+
want: time.Unix(int64(uintValue), 0).UTC(), //nolint:gosec
77+
wantErr: false,
78+
},
79+
{
80+
name: "uint8",
81+
value: uint8Value,
82+
want: time.Unix(int64(uint8Value), 0).UTC(),
83+
wantErr: false,
84+
},
85+
{
86+
name: "uint16",
87+
value: uint16Value,
88+
want: time.Unix(int64(uint16Value), 0).UTC(),
89+
wantErr: false,
90+
},
91+
{
92+
name: "uint32",
93+
value: uint32Value,
94+
want: time.Unix(int64(uint32Value), 0).UTC(),
95+
wantErr: false,
96+
},
97+
{
98+
name: "uint64",
99+
value: uint64Value,
100+
want: time.Unix(int64(uint64Value), 0).UTC(), //nolint:gosec
101+
wantErr: false,
102+
},
103+
{
104+
name: "maxInt64+1",
105+
value: maxInt64Plus1,
106+
want: nil,
107+
wantErr: true,
108+
},
109+
{
110+
name: "*int",
111+
value: intPtrValue,
112+
want: time.Unix(int64(*intPtrValue), 0).UTC(),
113+
wantErr: false,
114+
},
115+
{
116+
name: "*int8",
117+
value: int8PtrValue,
118+
want: time.Unix(int64(*int8PtrValue), 0).UTC(),
119+
wantErr: false,
120+
},
121+
{
122+
name: "*int16",
123+
value: int16PtrValue,
124+
want: time.Unix(int64(*int16PtrValue), 0).UTC(),
125+
wantErr: false,
126+
},
127+
{
128+
name: "*int32",
129+
value: int32PtrValue,
130+
want: time.Unix(int64(*int32PtrValue), 0).UTC(),
131+
wantErr: false,
132+
},
133+
{
134+
name: "*int64",
135+
value: int64PtrValue,
136+
want: time.Unix(*int64PtrValue, 0).UTC(),
137+
wantErr: false,
138+
},
139+
{
140+
name: "*uint",
141+
value: uintPtrValue,
142+
want: time.Unix(int64(*uintPtrValue), 0).UTC(), //nolint:gosec
143+
wantErr: false,
144+
},
145+
{
146+
name: "*uint8",
147+
value: uint8PtrValue,
148+
want: time.Unix(int64(*uint8PtrValue), 0).UTC(),
149+
wantErr: false,
150+
},
151+
{
152+
name: "*uint16",
153+
value: uint16PtrValue,
154+
want: time.Unix(int64(*uint16PtrValue), 0).UTC(),
155+
wantErr: false,
156+
},
157+
{
158+
name: "*uint32",
159+
value: uint32PtrValue,
160+
want: time.Unix(int64(*uint32PtrValue), 0).UTC(),
161+
wantErr: false,
162+
},
163+
{
164+
name: "*uint64",
165+
value: uint64PtrValue,
166+
want: time.Unix(int64(*uint64PtrValue), 0).UTC(), //nolint:gosec
167+
wantErr: false,
168+
},
169+
{
170+
name: "pointer to maxInt64+1",
171+
value: maxInt64Plus1Ptr,
172+
want: nil,
173+
wantErr: true,
174+
},
175+
{
176+
name: "nil pointer",
177+
value: (*int)(nil),
178+
want: nil,
179+
wantErr: false,
180+
},
181+
{
182+
name: "invalid type",
183+
value: "invalid",
184+
want: nil,
185+
wantErr: true,
186+
},
187+
}
188+
for _, tt := range tests {
189+
t.Run(tt.name, func(t *testing.T) {
190+
got, err := UnixSecondSerializer{}.Value(context.Background(), nil, reflect.Value{}, tt.value)
191+
if (err != nil) != tt.wantErr {
192+
t.Fatalf("UnixSecondSerializer.Value() error = %v, wantErr %v", err, tt.wantErr)
193+
}
194+
if err != nil {
195+
return
196+
}
197+
if tt.want == nil && got == nil {
198+
return
199+
}
200+
if tt.want == nil {
201+
t.Fatalf("UnixSecondSerializer.Value() = %v, want nil", got)
202+
}
203+
if got == nil {
204+
t.Fatalf("UnixSecondSerializer.Value() = nil, want %v", tt.want)
205+
}
206+
if gotTime, ok := got.(time.Time); !ok {
207+
t.Errorf("UnixSecondSerializer.Value() returned %T, expected time.Time", got)
208+
} else if !tt.want.(time.Time).Equal(gotTime) {
209+
t.Errorf("UnixSecondSerializer.Value() = %v, want %v", got, tt.want)
210+
}
211+
})
212+
}
213+
}

0 commit comments

Comments
 (0)