Skip to content

Commit 7b3319d

Browse files
authored
SNOW-1016278 Fix panic on empty arrow batches (#1039)
1 parent 611fe9d commit 7b3319d

File tree

3 files changed

+38
-2
lines changed

3 files changed

+38
-2
lines changed

assert_test.go

+12
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ func assertBetweenE(t *testing.T, value float64, min float64, max float64, descr
5858
errorOnNonEmpty(t, validateValueBetween(value, min, max, descriptions...))
5959
}
6060

61+
func assertEmptyE[T any](t *testing.T, actual []T, descriptions ...string) {
62+
errorOnNonEmpty(t, validateEmpty(actual, descriptions...))
63+
}
64+
6165
func fatalOnNonEmpty(t *testing.T, errMsg string) {
6266
if errMsg != "" {
6367
t.Fatal(formatErrorMessage(errMsg))
@@ -122,6 +126,14 @@ func validateValueBetween(value float64, min float64, max float64, descriptions
122126
return fmt.Sprintf("expected \"%f\" should be between \"%f\" and \"%f\" but did not. %s", value, min, max, desc)
123127
}
124128

129+
func validateEmpty[T any](value []T, descriptions ...string) string {
130+
if len(value) == 0 {
131+
return ""
132+
}
133+
desc := joinDescriptions(descriptions...)
134+
return fmt.Sprintf("expected \"%v\" to be empty. %s", value, desc)
135+
}
136+
125137
func joinDescriptions(descriptions ...string) string {
126138
return strings.Join(descriptions, " ")
127139
}

chunk_downloader.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ func (scd *snowflakeChunkDownloader) getRowType() []execResponseRowType {
249249
}
250250

251251
func (scd *snowflakeChunkDownloader) getArrowBatches() []*ArrowBatch {
252-
if scd.FirstBatch.rec == nil {
252+
if scd.FirstBatch == nil || scd.FirstBatch.rec == nil {
253253
return scd.ArrowBatches
254254
}
255255
return append([]*ArrowBatch{scd.FirstBatch}, scd.ArrowBatches...)

chunk_downloader_test.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,38 @@ func TestWithArrowBatchesWhenQueryReturnsNoRowsWhenUsingNativeGoSQLInterface(t *
4141
})
4242
}
4343

44-
func TestWithArrowBatchesWhenQueryReturnsNoRows(t *testing.T) {
44+
func TestWithArrowBatchesWhenQueryReturnsRowsAndReadingRows(t *testing.T) {
4545
runDBTest(t, func(dbt *DBTest) {
4646
rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1")
4747
defer rows.Close()
4848
assertFalseF(t, rows.Next())
4949
})
5050
}
5151

52+
func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingRows(t *testing.T) {
53+
runDBTest(t, func(dbt *DBTest) {
54+
rows := dbt.mustQueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0")
55+
defer rows.Close()
56+
assertFalseF(t, rows.Next())
57+
})
58+
}
59+
60+
func TestWithArrowBatchesWhenQueryReturnsNoRowsAndReadingArrowBatches(t *testing.T) {
61+
runDBTest(t, func(dbt *DBTest) {
62+
var rows driver.Rows
63+
var err error
64+
err = dbt.conn.Raw(func(x any) error {
65+
rows, err = x.(driver.QueryerContext).QueryContext(WithArrowBatches(context.Background()), "SELECT 1 WHERE 1 = 0", nil)
66+
return err
67+
})
68+
assertNilF(t, err)
69+
defer rows.Close()
70+
batches, err := rows.(SnowflakeRows).GetArrowBatches()
71+
assertNilF(t, err)
72+
assertEmptyE(t, batches)
73+
})
74+
}
75+
5276
func TestWithArrowBatchesWhenQueryReturnsSomeRowsInGivenFormatUsingNativeGoSQLInterface(t *testing.T) {
5377
for _, tc := range []struct {
5478
useJSON bool

0 commit comments

Comments
 (0)