Skip to content

Commit b9ab4b3

Browse files
authored
SNOW-1924252: Support internal flag (#1319)
1 parent 8e49966 commit b9ab4b3

5 files changed

+34
-9
lines changed

connection.go

+10-8
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ func (sc *snowflakeConn) BeginTx(
258258
return nil, driver.ErrBadConn
259259
}
260260
isDesc := isDescribeOnly(ctx)
261+
isInternal := isInternal(ctx)
261262
if _, err := sc.exec(ctx, "BEGIN", false, /* noResult */
262-
false /* isInternal */, isDesc, nil); err != nil {
263+
isInternal, isDesc, nil); err != nil {
263264
return nil, err
264265
}
265266
return &snowflakeTx{sc, ctx}, nil
@@ -318,9 +319,9 @@ func (sc *snowflakeConn) ExecContext(
318319
}
319320
noResult := isAsyncMode(ctx)
320321
isDesc := isDescribeOnly(ctx)
321-
// TODO handle isInternal
322+
isInternal := isInternal(ctx)
322323
ctx = setResultType(ctx, execResultType)
323-
data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args)
324+
data, err := sc.exec(ctx, query, noResult, isInternal, isDesc, args)
324325
if err != nil {
325326
logger.WithContext(ctx).Infof("error: %v", err)
326327
if data != nil {
@@ -407,8 +408,8 @@ func (sc *snowflakeConn) queryContextInternal(
407408
noResult := isAsyncMode(ctx)
408409
isDesc := isDescribeOnly(ctx)
409410
ctx = setResultType(ctx, queryResultType)
410-
// TODO: handle isInternal
411-
data, err := sc.exec(ctx, query, noResult, false /* isInternal */, isDesc, args)
411+
isInternal := isInternal(ctx)
412+
data, err := sc.exec(ctx, query, noResult, isInternal, isDesc, args)
412413
if err != nil {
413414
logger.WithContext(ctx).Errorf("error: %v", err)
414415
if data != nil {
@@ -475,9 +476,9 @@ func (sc *snowflakeConn) Ping(ctx context.Context) error {
475476
}
476477
noResult := isAsyncMode(ctx)
477478
isDesc := isDescribeOnly(ctx)
478-
// TODO: handle isInternal
479+
isInternal := isInternal(ctx)
479480
ctx = setResultType(ctx, execResultType)
480-
_, err := sc.exec(ctx, "SELECT 1", noResult, false, /* isInternal */
481+
_, err := sc.exec(ctx, "SELECT 1", noResult, isInternal,
481482
isDesc, []driver.NamedValue{})
482483
return err
483484
}
@@ -518,7 +519,8 @@ func (sc *snowflakeConn) QueryArrowStream(ctx context.Context, query string, bin
518519
ctx = WithArrowBatches(context.WithValue(ctx, asyncMode, false))
519520
ctx = setResultType(ctx, queryResultType)
520521
isDesc := isDescribeOnly(ctx)
521-
data, err := sc.exec(ctx, query, false, false /* isinternal */, isDesc, bindings)
522+
isInternal := isInternal(ctx)
523+
data, err := sc.exec(ctx, query, false, isInternal, isDesc, bindings)
522524
if err != nil {
523525
logger.WithContext(ctx).Errorf("error: %v", err)
524526
if data != nil {

connection_util.go

+9
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,15 @@ func isDescribeOnly(ctx context.Context) bool {
221221
return ok && d
222222
}
223223

224+
func isInternal(ctx context.Context) bool {
225+
v := ctx.Value(internalQuery)
226+
if v == nil {
227+
return false
228+
}
229+
d, ok := v.(bool)
230+
return ok && d
231+
}
232+
224233
func setResultType(ctx context.Context, resType resultType) context.Context {
225234
return context.WithValue(ctx, snowflakeResultType, resType)
226235
}

transaction.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ func (tx *snowflakeTx) execTxCommand(command txCommand) (err error) {
4646
if tx.sc == nil || tx.sc.rest == nil {
4747
return driver.ErrBadConn
4848
}
49-
_, err = tx.sc.exec(tx.ctx, txStr, false /* noResult */, false /* isInternal */, false /* describeOnly */, nil)
49+
isInternal := isInternal(tx.ctx)
50+
_, err = tx.sc.exec(tx.ctx, txStr, false /* noResult */, isInternal, false /* describeOnly */, nil)
5051
if err != nil {
5152
return
5253
}

util.go

+6
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ const (
4040

4141
const (
4242
describeOnly contextKey = "DESCRIBE_ONLY"
43+
internalQuery contextKey = "INTERNAL_QUERY"
4344
cancelRetry contextKey = "CANCEL_RETRY"
4445
streamChunkDownload contextKey = "STREAM_CHUNK_DOWNLOAD"
4546
)
@@ -173,6 +174,11 @@ func WithArrayValuesNullable(ctx context.Context) context.Context {
173174
return context.WithValue(ctx, arrayValuesNullable, true)
174175
}
175176

177+
// WithInternal sets the internal query flag.
178+
func WithInternal(ctx context.Context) context.Context {
179+
return context.WithValue(ctx, internalQuery, true)
180+
}
181+
176182
// Get the request ID from the context if specified, otherwise generate one
177183
func getOrGenerateRequestIDFromContext(ctx context.Context) UUID {
178184
requestID, ok := ctx.Value(snowflakeRequestIDKey).(UUID)

util_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -433,3 +433,10 @@ func TestFindByPrefix(t *testing.T) {
433433
assertEqualE(t, findByPrefix(nonEmpty, "dd"), -1)
434434
assertEqualE(t, findByPrefix([]string{}, "dd"), -1)
435435
}
436+
437+
func TestInternal(t *testing.T) {
438+
ctx := context.Background()
439+
assertFalseE(t, isInternal(ctx))
440+
ctx = WithInternal(ctx)
441+
assertTrueE(t, isInternal(ctx))
442+
}

0 commit comments

Comments
 (0)