Skip to content

Commit 1043498

Browse files
authored
SNOW-895537: Send query context with request (#904)
1 parent 213a5ac commit 1043498

File tree

5 files changed

+136
-110
lines changed

5 files changed

+136
-110
lines changed

connection.go

+26
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,18 @@ func (sc *snowflakeConn) exec(
8585
var err error
8686
counter := atomic.AddUint64(&sc.SequenceCounter, 1) // query sequence counter
8787

88+
queryContext, err := buildQueryContext(sc.queryContextCache)
89+
if err != nil {
90+
logger.Errorf("error while building query context: %v", err)
91+
}
8892
req := execRequest{
8993
SQLText: query,
9094
AsyncExec: noResult,
9195
Parameters: map[string]interface{}{},
9296
IsInternal: isInternal,
9397
DescribeOnly: describeOnly,
9498
SequenceID: counter,
99+
QueryContext: queryContext,
95100
}
96101
if key := ctx.Value(multiStatementCount); key != nil {
97102
req.Parameters[string(multiStatementCount)] = key
@@ -173,6 +178,27 @@ func extractQueryContext(data *execResponse) (queryContext, error) {
173178
return queryContext, err
174179
}
175180

181+
func buildQueryContext(qcc *queryContextCache) (requestQueryContext, error) {
182+
rqc := requestQueryContext{}
183+
if qcc == nil || len(qcc.entries) == 0 {
184+
logger.Debugf("empty qcc")
185+
return rqc, nil
186+
}
187+
for _, qce := range qcc.entries {
188+
contextData := contextData{}
189+
if qce.Context == "" {
190+
contextData.Base64Data = qce.Context
191+
}
192+
rqc.Entries = append(rqc.Entries, requestQueryContextEntry{
193+
ID: qce.ID,
194+
Priority: qce.Priority,
195+
Timestamp: qce.Timestamp,
196+
Context: contextData,
197+
})
198+
}
199+
return rqc, nil
200+
}
201+
176202
func (sc *snowflakeConn) Begin() (driver.Tx, error) {
177203
return sc.BeginTx(sc.ctx, driver.TxOptions{})
178204
}

driver_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ func runSnowflakeConnTest(t *testing.T, test func(sct *SCTest)) {
387387
}
388388

389389
sct := &SCTest{t, sc}
390+
390391
test(sct)
391392
}
392393

htap.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@ type queryContext struct {
1616
}
1717

1818
type queryContextEntry struct {
19-
ID int `json:"id"`
20-
Timestamp int64 `json:"timestamp"`
21-
Priority int `json:"priority"`
22-
Context any `json:"context,omitempty"`
19+
ID int `json:"id"`
20+
Timestamp int64 `json:"timestamp"`
21+
Priority int `json:"priority"`
22+
Context string `json:"context,omitempty"`
2323
}
2424

2525
type queryContextCache struct {

htap_test.go

+89-106
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,13 @@
11
package gosnowflake
22

33
import (
4-
"encoding/json"
4+
"database/sql/driver"
5+
"fmt"
56
"reflect"
6-
"strings"
77
"testing"
8+
"time"
89
)
910

10-
func TestMarshallAndDecodeOpaqueContext(t *testing.T) {
11-
testcases := []struct {
12-
json string
13-
qc queryContextEntry
14-
}{
15-
{
16-
json: `{
17-
"id": 1,
18-
"timestamp": 2,
19-
"priority": 3
20-
}`,
21-
qc: queryContextEntry{1, 2, 3, nil},
22-
},
23-
{
24-
json: `{
25-
"id": 1,
26-
"timestamp": 2,
27-
"priority": 3,
28-
"context": "abc"
29-
}`,
30-
qc: queryContextEntry{1, 2, 3, "abc"},
31-
},
32-
{
33-
json: `{
34-
"id": 1,
35-
"timestamp": 2,
36-
"priority": 3,
37-
"context": {
38-
"val": "abc"
39-
}
40-
}`,
41-
qc: queryContextEntry{1, 2, 3, map[string]interface{}{"val": "abc"}},
42-
},
43-
{
44-
json: `{
45-
"id": 1,
46-
"timestamp": 2,
47-
"priority": 3,
48-
"context": [
49-
"abc"
50-
]
51-
}`,
52-
qc: queryContextEntry{1, 2, 3, []any{"abc"}},
53-
},
54-
{
55-
json: `{
56-
"id": 1,
57-
"timestamp": 2,
58-
"priority": 3,
59-
"context": [
60-
{
61-
"val": "abc"
62-
}
63-
]
64-
}`,
65-
qc: queryContextEntry{1, 2, 3, []any{map[string]interface{}{"val": "abc"}}},
66-
},
67-
}
68-
69-
for _, tc := range testcases {
70-
t.Run(trimWhitespaces(tc.json), func(t *testing.T) {
71-
var qc queryContextEntry
72-
73-
err := json.NewDecoder(strings.NewReader(tc.json)).Decode(&qc)
74-
if err != nil {
75-
t.Fatalf("failed to decode json. %v", err)
76-
}
77-
78-
if !reflect.DeepEqual(tc.qc, qc) {
79-
t.Errorf("failed to decode json. expected: %v, got: %v", tc.qc, qc)
80-
}
81-
82-
bytes, err := json.Marshal(qc)
83-
if err != nil {
84-
t.Fatalf("failed to encode json. %v", err)
85-
}
86-
87-
resultJSON := string(bytes)
88-
if resultJSON != trimWhitespaces(tc.json) {
89-
t.Errorf("failed to encode json. epxected: %v, got: %v", trimWhitespaces(tc.json), resultJSON)
90-
}
91-
})
92-
}
93-
}
94-
95-
func trimWhitespaces(s string) string {
96-
return strings.ReplaceAll(
97-
strings.ReplaceAll(
98-
strings.ReplaceAll(s, "\t", ""),
99-
" ", ""),
100-
"\n", "",
101-
)
102-
}
103-
10411
func TestSortingByPriority(t *testing.T) {
10512
qcc := (&queryContextCache{}).init()
10613
sc := htapTestSnowflakeConn()
@@ -302,9 +209,9 @@ func containsNewEntries(entriesAfter []queryContextEntry, entriesBefore []queryC
302209
}
303210

304211
func TestPruneBySessionValue(t *testing.T) {
305-
qce1 := queryContextEntry{1, 1, 1, nil}
306-
qce2 := queryContextEntry{2, 2, 2, nil}
307-
qce3 := queryContextEntry{3, 3, 3, nil}
212+
qce1 := queryContextEntry{1, 1, 1, ""}
213+
qce2 := queryContextEntry{2, 2, 2, ""}
214+
qce3 := queryContextEntry{3, 3, 3, ""}
308215

309216
testcases := []struct {
310217
size string
@@ -352,12 +259,12 @@ func TestPruneBySessionValue(t *testing.T) {
352259
}
353260

354261
func TestPruneByDefaultValue(t *testing.T) {
355-
qce1 := queryContextEntry{1, 1, 1, nil}
356-
qce2 := queryContextEntry{2, 2, 2, nil}
357-
qce3 := queryContextEntry{3, 3, 3, nil}
358-
qce4 := queryContextEntry{4, 4, 4, nil}
359-
qce5 := queryContextEntry{5, 5, 5, nil}
360-
qce6 := queryContextEntry{6, 6, 6, nil}
262+
qce1 := queryContextEntry{1, 1, 1, ""}
263+
qce2 := queryContextEntry{2, 2, 2, ""}
264+
qce3 := queryContextEntry{3, 3, 3, ""}
265+
qce4 := queryContextEntry{4, 4, 4, ""}
266+
qce5 := queryContextEntry{5, 5, 5, ""}
267+
qce6 := queryContextEntry{6, 6, 6, ""}
361268

362269
sc := &snowflakeConn{
363270
cfg: &Config{
@@ -383,7 +290,7 @@ func TestPruneByDefaultValue(t *testing.T) {
383290
}
384291

385292
func TestNoQcesClearsCache(t *testing.T) {
386-
qce1 := queryContextEntry{1, 1, 1, nil}
293+
qce1 := queryContextEntry{1, 1, 1, ""}
387294

388295
sc := &snowflakeConn{
389296
cfg: &Config{
@@ -426,3 +333,79 @@ func TestQueryContextCacheDisabled(t *testing.T) {
426333
}
427334
})
428335
}
336+
337+
func TestHybridTablesE2E(t *testing.T) {
338+
if runningOnGithubAction() && !runningOnAWS() {
339+
t.Skip("HTAP is enabled only on AWS")
340+
}
341+
runID := time.Now().UnixMilli()
342+
testDb1 := fmt.Sprintf("hybrid_db_test_%v", runID)
343+
testDb2 := fmt.Sprintf("hybrid_db_test_%v_2", runID)
344+
runSnowflakeConnTest(t, func(sct *SCTest) {
345+
dbQuery := sct.mustQuery("SELECT CURRENT_DATABASE()", nil)
346+
defer dbQuery.Close()
347+
currentDb := make([]driver.Value, 1)
348+
dbQuery.Next(currentDb)
349+
defer func() {
350+
sct.mustExec(fmt.Sprintf("USE DATABASE %v", currentDb[0]), nil)
351+
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb1), nil)
352+
sct.mustExec(fmt.Sprintf("DROP DATABASE IF EXISTS %v", testDb2), nil)
353+
}()
354+
355+
t.Run("Run tests on first database", func(t *testing.T) {
356+
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb1), nil)
357+
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table (id INT PRIMARY KEY, text VARCHAR)", nil)
358+
359+
sct.mustExec("INSERT INTO test_hybrid_table VALUES (1, 'a')", nil)
360+
rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
361+
defer rows.Close()
362+
row := make([]driver.Value, 2)
363+
rows.Next(row)
364+
if row[0] != "1" || row[1] != "a" {
365+
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
366+
}
367+
368+
sct.mustExec("INSERT INTO test_hybrid_table VALUES (2, 'b')", nil)
369+
rows2 := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
370+
defer rows2.Close()
371+
rows2.Next(row)
372+
if row[0] != "1" || row[1] != "a" {
373+
t.Errorf("expected 1, got %v and expected a, got %v", row[0], row[1])
374+
}
375+
rows2.Next(row)
376+
if row[0] != "2" || row[1] != "b" {
377+
t.Errorf("expected 2, got %v and expected b, got %v", row[0], row[1])
378+
}
379+
if len(sct.sc.queryContextCache.entries) != 2 {
380+
t.Errorf("expected two entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
381+
}
382+
})
383+
t.Run("Run tests on second database", func(t *testing.T) {
384+
sct.mustExec(fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %v", testDb2), nil)
385+
sct.mustExec("CREATE HYBRID TABLE test_hybrid_table_2 (id INT PRIMARY KEY, text VARCHAR)", nil)
386+
sct.mustExec("INSERT INTO test_hybrid_table_2 VALUES (3, 'c')", nil)
387+
388+
rows := sct.mustQuery("SELECT * FROM test_hybrid_table_2", nil)
389+
defer rows.Close()
390+
row := make([]driver.Value, 2)
391+
rows.Next(row)
392+
if row[0] != "3" || row[1] != "c" {
393+
t.Errorf("expected 3, got %v and expected c, got %v", row[0], row[1])
394+
}
395+
if len(sct.sc.queryContextCache.entries) != 3 {
396+
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
397+
}
398+
})
399+
t.Run("Run tests on first database again", func(t *testing.T) {
400+
sct.mustExec(fmt.Sprintf("USE DATABASE %v", testDb1), nil)
401+
402+
sct.mustExec("INSERT INTO test_hybrid_table VALUES (4, 'd')", nil)
403+
404+
rows := sct.mustQuery("SELECT * FROM test_hybrid_table", nil)
405+
defer rows.Close()
406+
if len(sct.sc.queryContextCache.entries) != 3 {
407+
t.Errorf("expected three entries in query context cache, got: %v", sct.sc.queryContextCache.entries)
408+
}
409+
})
410+
})
411+
}

query.go

+16
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,22 @@ type execRequest struct {
2828
Parameters map[string]interface{} `json:"parameters,omitempty"`
2929
Bindings map[string]execBindParameter `json:"bindings,omitempty"`
3030
BindStage string `json:"bindStage,omitempty"`
31+
QueryContext requestQueryContext `json:"queryContextDTO,omitempty"`
32+
}
33+
34+
type requestQueryContext struct {
35+
Entries []requestQueryContextEntry `json:"entries,omitempty"`
36+
}
37+
38+
type requestQueryContextEntry struct {
39+
Context contextData `json:"context,omitempty"`
40+
ID int `json:"id"`
41+
Priority int `json:"priority"`
42+
Timestamp int64 `json:"timestamp,omitempty"`
43+
}
44+
45+
type contextData struct {
46+
Base64Data string `json:"base64Data,omitempty"`
3147
}
3248

3349
type execResponseRowType struct {

0 commit comments

Comments
 (0)