Skip to content

Commit 17db71f

Browse files
SNOW-1313648 GO - Verify value bindings for all field types while exceeding CLIENT_STAGE_ARRAY_BINDING_THRESHOLD (#1297)
1 parent 580e7e8 commit 17db71f

File tree

4 files changed

+226
-48
lines changed

4 files changed

+226
-48
lines changed

bind_uploader.go

+24-9
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,10 @@ func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte,
136136
}).exceptionTelemetry(bu.sc)
137137
}
138138

139-
_, column := snowflakeArrayToString(&columns[0], true)
139+
_, column, err := snowflakeArrayToString(&columns[0], true)
140+
if err != nil {
141+
return nil, err
142+
}
140143
numRows := len(column)
141144
csvRows := make([][]byte, 0)
142145
rows := make([][]interface{}, 0)
@@ -152,7 +155,10 @@ func (bu *bindUploader) buildRowsAsBytes(columns []driver.NamedValue) ([][]byte,
152155
}
153156
}
154157
for colIdx := 1; colIdx < numColumns; colIdx++ {
155-
_, column = snowflakeArrayToString(&columns[colIdx], true)
158+
_, column, err = snowflakeArrayToString(&columns[colIdx], true)
159+
if err != nil {
160+
return nil, err
161+
}
156162
iNumRows := len(column)
157163
if iNumRows != numRows {
158164
return nil, (&SnowflakeError{
@@ -201,7 +207,10 @@ func (sc *snowflakeConn) processBindings(
201207
requestID UUID,
202208
req *execRequest) error {
203209
arrayBindThreshold := sc.getArrayBindStageThreshold()
204-
numBinds := arrayBindValueCount(bindings)
210+
numBinds, err := arrayBindValueCount(bindings)
211+
if err != nil {
212+
return err
213+
}
205214
if 0 < arrayBindThreshold && arrayBindThreshold <= numBinds && !describeOnly && isArrayBind(bindings) {
206215
uploader := bindUploader{
207216
sc: sc,
@@ -215,7 +224,6 @@ func (sc *snowflakeConn) processBindings(
215224
req.Bindings = nil
216225
req.BindStage = uploader.stagePath
217226
} else {
218-
var err error
219227
req.Bindings, err = getBindValues(bindings, sc.cfg.Params)
220228
if err != nil {
221229
return err
@@ -246,7 +254,10 @@ func getBindValues(bindings []driver.NamedValue, params map[string]*string) (map
246254
var bv bindingValue
247255
if t == sliceType {
248256
// retrieve array binding data
249-
t, val = snowflakeArrayToString(&binding, false)
257+
t, val, err = snowflakeArrayToString(&binding, false)
258+
if err != nil {
259+
return nil, err
260+
}
250261
} else {
251262
bv, err = valueToString(binding.Value, tsmode, params)
252263
val = bv.value
@@ -280,12 +291,16 @@ func bindingName(nv driver.NamedValue, idx int) string {
280291
return strconv.Itoa(idx)
281292
}
282293

283-
func arrayBindValueCount(bindValues []driver.NamedValue) int {
294+
func arrayBindValueCount(bindValues []driver.NamedValue) (int, error) {
284295
if !isArrayBind(bindValues) {
285-
return 0
296+
return 0, nil
286297
}
287-
_, arr := snowflakeArrayToString(&bindValues[0], false)
288-
return len(bindValues) * len(arr)
298+
_, arr, err := snowflakeArrayToString(&bindValues[0], false)
299+
if err != nil {
300+
return 0, err
301+
}
302+
303+
return len(bindValues) * len(arr), nil
289304
}
290305

291306
func isArrayBind(bindings []driver.NamedValue) bool {

bindings_test.go

+132
Original file line numberDiff line numberDiff line change
@@ -874,6 +874,130 @@ func TestBulkArrayBinding(t *testing.T) {
874874
})
875875
}
876876

877+
func TestBindingsWithSameValue(t *testing.T) {
878+
arrayInsertTable := "test_array_binding_insert"
879+
stageBindingTable := "test_stage_binding_insert"
880+
interfaceArrayTable := "test_interface_binding_insert"
881+
882+
runDBTest(t, func(dbt *DBTest) {
883+
dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", arrayInsertTable))
884+
dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", stageBindingTable))
885+
dbt.mustExec(fmt.Sprintf("create or replace table %v (c1 integer, c2 string, c3 timestamp_ltz, c4 timestamp_tz, c5 timestamp_ntz, c6 date, c7 time, c9 boolean, c10 double)", interfaceArrayTable))
886+
887+
defer func() {
888+
dbt.mustExec(fmt.Sprintf("drop table if exists %v", arrayInsertTable))
889+
dbt.mustExec(fmt.Sprintf("drop table if exists %v", stageBindingTable))
890+
dbt.mustExec(fmt.Sprintf("drop table if exists %v", interfaceArrayTable))
891+
}()
892+
893+
numRows := 5
894+
895+
intArr := make([]int, numRows)
896+
strArr := make([]string, numRows)
897+
timeArr := make([]time.Time, numRows)
898+
boolArr := make([]bool, numRows)
899+
doubleArr := make([]float64, numRows)
900+
901+
intAnyArr := make([]any, numRows)
902+
strAnyArr := make([]any, numRows)
903+
timeAnyArr := make([]any, numRows)
904+
boolAnyArr := make([]bool, numRows)
905+
doubleAnyArr := make([]float64, numRows)
906+
907+
for i := 0; i < numRows; i++ {
908+
intArr[i] = i
909+
intAnyArr[i] = i
910+
911+
double := rand.Float64()
912+
doubleArr[i] = double
913+
doubleAnyArr[i] = double
914+
915+
strArr[i] = "test" + strconv.Itoa(i)
916+
strAnyArr[i] = "test" + strconv.Itoa(i)
917+
918+
b := getRandomBool()
919+
boolArr[i] = b
920+
boolAnyArr[i] = b
921+
922+
date := getRandomDate()
923+
timeArr[i] = date
924+
timeAnyArr[i] = date
925+
}
926+
927+
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", interfaceArrayTable), Array(&intAnyArr), Array(&strAnyArr), Array(&timeAnyArr, TimestampLTZType), Array(&timeAnyArr, TimestampTZType), Array(&timeAnyArr, TimestampNTZType), Array(&timeAnyArr, DateType), Array(&timeAnyArr, TimeType), Array(&boolArr), Array(&doubleArr))
928+
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", arrayInsertTable), Array(&intArr), Array(&strArr), Array(&timeArr, TimestampLTZType), Array(&timeArr, TimestampTZType), Array(&timeArr, TimestampNTZType), Array(&timeArr, DateType), Array(&timeArr, TimeType), Array(&boolArr), Array(&doubleArr))
929+
dbt.mustExec("ALTER SESSION SET CLIENT_STAGE_ARRAY_BINDING_THRESHOLD = 1")
930+
dbt.mustExec(fmt.Sprintf("insert into %v values (?, ?, ?, ?, ?, ?, ?, ?, ?)", stageBindingTable), Array(&intArr), Array(&strArr), Array(&timeArr, TimestampLTZType), Array(&timeArr, TimestampTZType), Array(&timeArr, TimestampNTZType), Array(&timeArr, DateType), Array(&timeArr, TimeType), Array(&boolArr), Array(&doubleArr))
931+
932+
insertRows := dbt.mustQuery("select * from " + arrayInsertTable + " order by c1")
933+
bindingRows := dbt.mustQuery("select * from " + stageBindingTable + " order by c1")
934+
interfaceRows := dbt.mustQuery("select * from " + interfaceArrayTable + " order by c1")
935+
936+
defer func() {
937+
assertNilF(t, insertRows.Close())
938+
assertNilF(t, bindingRows.Close())
939+
assertNilF(t, interfaceRows.Close())
940+
}()
941+
var i, bi, ii int
942+
var s, bs, is string
943+
var ltz, bltz, iltz, itz, btz, tz, intz, ntz, bntz, iDate, date, bDate, itt, tt, btt time.Time
944+
var b, bb, ib bool
945+
var d, bd, id float64
946+
947+
timeFormat := "15:04:05"
948+
for k := 0; k < numRows; k++ {
949+
assertTrueF(t, insertRows.Next())
950+
assertNilF(t, insertRows.Scan(&i, &s, &ltz, &tz, &ntz, &date, &tt, &b, &d))
951+
952+
assertTrueF(t, bindingRows.Next())
953+
assertNilF(t, bindingRows.Scan(&bi, &bs, &bltz, &btz, &bntz, &bDate, &btt, &bb, &bd))
954+
955+
assertTrueF(t, interfaceRows.Next())
956+
assertNilF(t, interfaceRows.Scan(&ii, &is, &iltz, &itz, &intz, &iDate, &itt, &ib, &id))
957+
958+
assertEqualE(t, k, i)
959+
assertEqualE(t, k, bi)
960+
assertEqualE(t, k, ii)
961+
962+
assertEqualE(t, "test"+strconv.Itoa(k), s)
963+
assertEqualE(t, "test"+strconv.Itoa(k), bs)
964+
assertEqualE(t, "test"+strconv.Itoa(k), is)
965+
966+
utcTime := timeArr[k].UTC()
967+
assertEqualE(t, ltz.UTC(), utcTime)
968+
assertEqualE(t, bltz.UTC(), utcTime)
969+
assertEqualE(t, iltz.UTC(), utcTime)
970+
971+
assertEqualE(t, tz.UTC(), utcTime)
972+
assertEqualE(t, btz.UTC(), utcTime)
973+
assertEqualE(t, itz.UTC(), utcTime)
974+
975+
assertEqualE(t, ntz.UTC(), utcTime)
976+
assertEqualE(t, bntz.UTC(), utcTime)
977+
assertEqualE(t, intz.UTC(), utcTime)
978+
979+
testingDate := timeArr[k].Truncate(24 * time.Hour)
980+
assertEqualE(t, date, testingDate)
981+
assertEqualE(t, bDate, testingDate)
982+
assertEqualE(t, iDate, testingDate)
983+
984+
testingTime := timeArr[k].Format(timeFormat)
985+
assertEqualE(t, tt.Format(timeFormat), testingTime)
986+
assertEqualE(t, btt.Format(timeFormat), testingTime)
987+
assertEqualE(t, itt.Format(timeFormat), testingTime)
988+
989+
assertEqualE(t, b, boolArr[k])
990+
assertEqualE(t, bb, boolArr[k])
991+
assertEqualE(t, ib, boolArr[k])
992+
993+
assertEqualE(t, d, doubleArr[k])
994+
assertEqualE(t, bd, doubleArr[k])
995+
assertEqualE(t, id, doubleArr[k])
996+
997+
}
998+
})
999+
}
1000+
8771001
func TestBulkArrayBindingTimeWithPrecision(t *testing.T) {
8781002
runDBTest(t, func(dbt *DBTest) {
8791003
dbt.mustExec(fmt.Sprintf("create or replace table %v (s time(0), ms time(3), us time(6), ns time(9))", dbname))
@@ -1423,3 +1547,11 @@ func testInsertLOBData(t *testing.T, useArrowFormat bool, isLiteral bool) {
14231547
dbt.mustExec(unsetFeatureMaxLOBSize)
14241548
})
14251549
}
1550+
1551+
func getRandomDate() time.Time {
1552+
return time.Date(rand.Intn(1582)+1, time.January, rand.Intn(40), rand.Intn(40), rand.Intn(40), rand.Intn(40), rand.Intn(40), time.UTC)
1553+
}
1554+
1555+
func getRandomBool() bool {
1556+
return rand.Int63n(time.Now().Unix())%2 == 0
1557+
}

0 commit comments

Comments
 (0)