Skip to content

Commit 4f89e5b

Browse files
authored
SNOW-1854661 Detect JSON response in Arrow batches mode and return error (#1277)
1 parent f8baf23 commit 4f89e5b

File tree

5 files changed

+50
-0
lines changed

5 files changed

+50
-0
lines changed

async.go

+1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ func (sr *snowflakeRestful) getAsync(
119119
rows.errChannel <- err
120120
return err
121121
}
122+
rows.format = respd.Data.QueryResultFormat
122123
rows.errChannel <- nil // mark query status complete
123124
}
124125
} else {

chunk_test.go

+39
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"context"
88
"database/sql/driver"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"io"
1213
"math/rand"
@@ -533,6 +534,44 @@ func TestWithArrowBatchesAsync(t *testing.T) {
533534
})
534535
}
535536

537+
func TestWithArrowBatchesButReturningJSON(t *testing.T) {
538+
testWithArrowBatchesButReturningJSON(t, false)
539+
}
540+
541+
func TestWithArrowBatchesButReturningJSONAsync(t *testing.T) {
542+
testWithArrowBatchesButReturningJSON(t, true)
543+
}
544+
545+
func testWithArrowBatchesButReturningJSON(t *testing.T, async bool) {
546+
runSnowflakeConnTest(t, func(sct *SCTest) {
547+
requestID := NewUUID()
548+
pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
549+
defer pool.AssertSize(t, 0)
550+
ctx := WithArrowAllocator(context.Background(), pool)
551+
ctx = WithArrowBatches(ctx)
552+
ctx = WithRequestID(ctx, requestID)
553+
if async {
554+
ctx = WithAsyncMode(ctx)
555+
}
556+
557+
sct.mustExec(forceJSON, nil)
558+
rows := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
559+
defer rows.Close()
560+
_, err := rows.(SnowflakeRows).GetArrowBatches()
561+
assertNotNilF(t, err)
562+
var se *SnowflakeError
563+
errors.As(err, &se)
564+
assertEqualE(t, se.Message, errJSONResponseInArrowBatchesMode)
565+
566+
ctx = WithRequestID(context.Background(), requestID)
567+
rows2 := sct.mustQueryContext(ctx, "SELECT 'hello'", nil)
568+
defer rows2.Close()
569+
scanValues := make([]driver.Value, 1)
570+
assertNilF(t, rows2.Next(scanValues))
571+
assertEqualE(t, scanValues[0], "hello")
572+
})
573+
}
574+
536575
func TestQueryArrowStream(t *testing.T) {
537576
runSnowflakeConnTest(t, func(sct *SCTest) {
538577
numrows := 50000 // approximately 10 ArrowBatch objects

connection.go

+1
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,7 @@ func (sc *snowflakeConn) queryContextInternal(
435435
rows.sc = sc
436436
rows.queryID = data.Data.QueryID
437437
rows.ctx = ctx
438+
rows.format = data.Data.QueryResultFormat
438439

439440
if isMultiStmt(&data.Data) {
440441
// handleMultiQuery is responsible to fill rows with childResults

errors.go

+1
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ const (
308308
errMsgFailedToParseTomlFile = "failed to parse toml file. the params %v occurred error with value %v"
309309
errMsgFailedToFindDSNInTomlFile = "failed to find DSN in toml file."
310310
errMsgInvalidPermissionToTomlFile = "file permissions different than read/write for user. Your Permission: %v"
311+
errJSONResponseInArrowBatchesMode = "arrow batches enabled, but the response is not Arrow based"
311312
)
312313

313314
// Returned if a DNS doesn't include account parameter.

rows.go

+8
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ type snowflakeRows struct {
4646
errChannel chan error
4747
location *time.Location
4848
ctx context.Context
49+
format string
4950
}
5051

5152
func (rows *snowflakeRows) getLocation() *time.Location {
@@ -168,6 +169,13 @@ func (rows *snowflakeRows) GetArrowBatches() ([]*ArrowBatch, error) {
168169
return nil, err
169170
}
170171

172+
if rows.format != "arrow" {
173+
return nil, (&SnowflakeError{
174+
QueryID: rows.queryID,
175+
Message: errJSONResponseInArrowBatchesMode,
176+
}).exceptionTelemetry(rows.sc)
177+
}
178+
171179
return rows.ChunkDownloader.getArrowBatches(), nil
172180
}
173181

0 commit comments

Comments
 (0)