Skip to content

Commit f1dfe51

Browse files
SNOW-942595: Retry the fetch result request for response code queryInProgressAsyncCode when using WithFetchResultByID (#1021)
Add a check for the response code queryInProgressAsyncCode(333334) in rowsForRunningQuery() in monitoring.go used for fetching results using WithFetchResultByID
1 parent e990de6 commit f1dfe51

File tree

7 files changed

+269
-47
lines changed

7 files changed

+269
-47
lines changed

async.go

+54-39
Original file line numberDiff line numberDiff line change
@@ -64,45 +64,12 @@ func (sr *snowflakeRestful) getAsync(
6464
token, _, _ := sr.TokenAccessor.GetTokens()
6565
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
6666

67-
var err error
68-
var respd execResponse
69-
retry := 0
70-
retryPattern := []int32{1, 1, 2, 3, 4, 8, 10}
71-
72-
for {
73-
resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
74-
if err != nil {
75-
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
76-
sfError.Message = err.Error()
77-
errChannel <- sfError
78-
return err
79-
}
80-
defer resp.Body.Close()
81-
82-
respd = execResponse{} // reset the response
83-
err = json.NewDecoder(resp.Body).Decode(&respd)
84-
if err != nil {
85-
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
86-
sfError.Message = err.Error()
87-
errChannel <- sfError
88-
return err
89-
}
90-
if respd.Code != queryInProgressAsyncCode {
91-
// If the query takes longer than 45 seconds to complete the results are not returned.
92-
// If the query is still in progress after 45 seconds, retry the request to the /results endpoint.
93-
// For all other scenarios continue processing results response
94-
break
95-
} else {
96-
// Sleep before retrying get result request. Exponential backoff up to 5 seconds.
97-
// Once 5 second backoff is reached it will keep retrying with this sleeptime.
98-
sleepTime := time.Millisecond * time.Duration(500*retryPattern[retry])
99-
logger.WithContext(ctx).Infof("Query execution still in progress. Sleep for %v ms", sleepTime)
100-
time.Sleep(sleepTime)
101-
}
102-
if retry < len(retryPattern)-1 {
103-
retry++
104-
}
105-
67+
respd, err := getQueryResultWithRetriesForAsyncMode(ctx, sr, URL, headers, timeout)
68+
if err != nil {
69+
logger.WithContext(ctx).Errorf("error: %v", err)
70+
sfError.Message = err.Error()
71+
errChannel <- sfError
72+
return err
10673
}
10774

10875
sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init(), currentTimeProvider: defaultTimeProvider}
@@ -166,3 +133,51 @@ func (sr *snowflakeRestful) getAsync(
166133
}
167134
return nil
168135
}
136+
137+
func getQueryResultWithRetriesForAsyncMode(
138+
ctx context.Context,
139+
sr *snowflakeRestful,
140+
URL *url.URL,
141+
headers map[string]string,
142+
timeout time.Duration) (*execResponse, error) {
143+
var respd *execResponse
144+
retry := 0
145+
retryPattern := []int32{1, 1, 2, 3, 4, 8, 10}
146+
retryPatternIndex := 0
147+
148+
for {
149+
logger.WithContext(ctx).Debugf("Retry count for get query result request in async mode: %v", retry)
150+
151+
resp, err := sr.FuncGet(ctx, sr, URL, headers, timeout)
152+
if err != nil {
153+
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
154+
return respd, err
155+
}
156+
defer resp.Body.Close()
157+
158+
respd = &execResponse{} // reset the response
159+
err = json.NewDecoder(resp.Body).Decode(&respd)
160+
if err != nil {
161+
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
162+
return respd, err
163+
}
164+
if respd.Code != queryInProgressAsyncCode {
165+
// If the query takes longer than 45 seconds to complete the results are not returned.
166+
// If the query is still in progress after 45 seconds, retry the request to the /results endpoint.
167+
// For all other scenarios continue processing results response
168+
break
169+
} else {
170+
// Sleep before retrying get result request. Exponential backoff up to 5 seconds.
171+
// Once 5 second backoff is reached it will keep retrying with this sleeptime.
172+
sleepTime := time.Millisecond * time.Duration(500*retryPattern[retryPatternIndex])
173+
logger.WithContext(ctx).Infof("Query execution still in progress. Response code: %v, message: %v Sleep for %v ms", respd.Code, respd.Message, sleepTime)
174+
time.Sleep(sleepTime)
175+
retry++
176+
177+
if retryPatternIndex < len(retryPattern)-1 {
178+
retryPatternIndex++
179+
}
180+
}
181+
}
182+
return respd, nil
183+
}

async_test.go

+31
Original file line numberDiff line numberDiff line change
@@ -196,3 +196,34 @@ func TestLongRunningAsyncQuery(t *testing.T) {
196196
}
197197
}
198198
}
199+
200+
func TestLongRunningAsyncQueryFetchResultByID(t *testing.T) {
201+
runDBTest(t, func(dbt *DBTest) {
202+
queryIDChan := make(chan string, 1)
203+
ctx := WithAsyncMode(context.Background())
204+
ctx = WithQueryIDChan(ctx, queryIDChan)
205+
206+
// Run a long running query asynchronously
207+
go dbt.mustExecContext(ctx, "CALL SYSTEM$WAIT(50, 'SECONDS')")
208+
209+
// Get the query ID without waiting for the query to finish
210+
queryID := <-queryIDChan
211+
assertNotNilF(t, queryID, "expected a nonempty query ID")
212+
213+
ctx = WithFetchResultByID(ctx, queryID)
214+
rows := dbt.mustQueryContext(ctx, "")
215+
defer rows.Close()
216+
217+
var v string
218+
assertTrueF(t, rows.Next())
219+
err := rows.Scan(&v)
220+
assertNilF(t, err, fmt.Sprintf("failed to get result. err: %v", err))
221+
assertNotNilF(t, v, "should have returned a result")
222+
223+
expected := "waited 50 seconds"
224+
if v != expected {
225+
t.Fatalf("unexpected result returned. expected: %v, but got: %v", expected, v)
226+
}
227+
assertFalseF(t, rows.NextResultSet())
228+
})
229+
}

cmd/fetchresultbyid/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
fetchresultbyid

cmd/fetchresultbyid/Makefile

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
include ../../gosnowflake.mak
2+
CMD_TARGET=fetchresultbyid
3+
4+
## Install
5+
install: cinstall
6+
7+
## Run
8+
run: crun
9+
10+
## Lint
11+
lint: clint
12+
13+
## Format source codes
14+
fmt: cfmt
15+
16+
.PHONY: install run lint fmt
+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"flag"
8+
"log"
9+
"strings"
10+
11+
sf "github.com/snowflakedb/gosnowflake"
12+
)
13+
14+
func main() {
15+
if !flag.Parsed() {
16+
flag.Parse()
17+
}
18+
19+
cfg, err := sf.GetConfigFromEnv([]*sf.ConfigParam{
20+
{Name: "Account", EnvName: "SNOWFLAKE_TEST_ACCOUNT", FailOnMissing: true},
21+
{Name: "User", EnvName: "SNOWFLAKE_TEST_USER", FailOnMissing: true},
22+
{Name: "Password", EnvName: "SNOWFLAKE_TEST_PASSWORD", FailOnMissing: true},
23+
{Name: "Host", EnvName: "SNOWFLAKE_TEST_HOST", FailOnMissing: false},
24+
{Name: "Port", EnvName: "SNOWFLAKE_TEST_PORT", FailOnMissing: false},
25+
{Name: "Protocol", EnvName: "SNOWFLAKE_TEST_PROTOCOL", FailOnMissing: false},
26+
})
27+
if err != nil {
28+
log.Fatalf("failed to create Config, err: %v", err)
29+
}
30+
31+
dsn, err := sf.DSN(cfg)
32+
if err != nil {
33+
log.Fatalf("failed to create DSN from Config: %v, err: %v", cfg, err)
34+
}
35+
36+
db, err := sql.Open("snowflake", dsn)
37+
if err != nil {
38+
log.Fatalf("failed to connect. %v, err: %v", dsn, err)
39+
}
40+
defer db.Close()
41+
42+
log.Println("Lets simulate running synchronous query and fetching the result by the query ID using the WithFetchResultByID context")
43+
sqlRows := fetchResultByIDSync(db, "SELECT 1")
44+
printSQLRowsResult(sqlRows)
45+
46+
log.Println("Lets simulate running long query asynchronously and fetching result by query ID using a channel provided in the WithQueryIDChan context")
47+
sqlRows = fetchResultByIDAsync(db, "CALL SYSTEM$WAIT(10, 'SECONDS')")
48+
printSQLRowsResult(sqlRows)
49+
}
50+
51+
func fetchResultByIDSync(db *sql.DB, query string) *sql.Rows {
52+
ctx := context.Background()
53+
conn, err := db.Conn(ctx)
54+
if err != nil {
55+
log.Fatalf("failed to get Conn. err: %v", err)
56+
}
57+
defer conn.Close()
58+
59+
var rows1 driver.Rows
60+
var queryID string
61+
62+
// Get the query ID using raw connection
63+
err = conn.Raw(func(x any) error {
64+
log.Printf("Executing query: %v\n", query)
65+
rows1, err = x.(driver.QueryerContext).QueryContext(ctx, query, nil)
66+
if err != nil {
67+
return err
68+
}
69+
70+
queryID = rows1.(sf.SnowflakeRows).GetQueryID()
71+
log.Printf("Query ID retrieved from GetQueryID(): %v\n", queryID)
72+
return nil
73+
})
74+
if err != nil {
75+
log.Fatalf("unable to run the query. err: %v", err)
76+
}
77+
78+
// Update the Context object to specify the query ID
79+
ctx = sf.WithFetchResultByID(ctx, queryID)
80+
81+
// Execute an empty string query
82+
rows2, err := db.QueryContext(ctx, "")
83+
if err != nil {
84+
log.Fatal(err)
85+
}
86+
87+
return rows2
88+
}
89+
90+
func fetchResultByIDAsync(db *sql.DB, query string) *sql.Rows {
91+
// Make a channel to receive the query ID
92+
queryIDChan := make(chan string, 1)
93+
94+
// Enable asynchronous mode
95+
ctx := sf.WithAsyncMode(context.Background())
96+
97+
// Pass the channel to receive the query ID
98+
ctx = sf.WithQueryIDChan(ctx, queryIDChan)
99+
100+
// Run a long running query asynchronously and without retrieving the result
101+
log.Printf("Executing query: %v\n", query)
102+
go db.ExecContext(ctx, query)
103+
104+
// Get the query ID without waiting for the query to finish
105+
queryID := <-queryIDChan
106+
log.Printf("Query ID retrieved from the channel: %v\n", queryID)
107+
108+
// Update the Context object to specify the query ID
109+
ctx = sf.WithFetchResultByID(ctx, queryID)
110+
111+
// Execute an empty string query
112+
rows, err := db.QueryContext(ctx, "")
113+
if err != nil {
114+
log.Fatal(err)
115+
}
116+
117+
return rows
118+
}
119+
120+
func printSQLRowsResult(rows *sql.Rows) {
121+
log.Print("Printing the results: \n")
122+
123+
cols, err := rows.Columns()
124+
if err != nil {
125+
log.Fatalf("failed to get columns. err: %v", err)
126+
}
127+
log.Println(strings.Join(cols, ", "))
128+
129+
var val string
130+
for rows.Next() {
131+
err := rows.Scan(&val)
132+
if err != nil {
133+
log.Fatalf("failed to scan rows. err: %v", err)
134+
}
135+
log.Printf("%v\n", val)
136+
}
137+
}

doc.go

+26
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,32 @@ For execs:
216216
217217
```
218218
219+
# Fetch Results by Query ID
220+
221+
The result of your query can be retrieved by setting the query ID in the WithFetchResultByID context.
222+
```
223+
224+
// Get the query ID using raw connection as mentioned above:
225+
err := conn.Raw(func(x any) error {
226+
rows1, err = x.(driver.QueryerContext).QueryContext(ctx, "SELECT 1", nil)
227+
queryID = rows1.(sf.SnowflakeRows).GetQueryID()
228+
return nil
229+
}
230+
231+
// Update the Context object to specify the query ID
232+
fetchResultByIDCtx = sf.WithFetchResultByID(ctx, queryID)
233+
234+
// Execute an empty string query
235+
rows2, err := db.QueryContext(fetchResultByIDCtx, "")
236+
237+
// Retrieve the results as usual
238+
for rows2.Next() {
239+
err = rows2.Scan(...)
240+
...
241+
}
242+
243+
```
244+
219245
# Canceling Query by CtrlC
220246
221247
From 0.5.0, a signal handling responsibility has moved to the applications. If you want to cancel a

monitoring.go

+4-8
Original file line numberDiff line numberDiff line change
@@ -214,15 +214,10 @@ func (sc *snowflakeConn) getQueryResultResp(
214214
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
215215
}
216216
url := sc.rest.getFullURL(resultPath, &param)
217-
res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)
217+
218+
respd, err := getQueryResultWithRetriesForAsyncMode(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)
218219
if err != nil {
219-
logger.WithContext(ctx).Errorf("failed to get response. err: %v", err)
220-
return nil, err
221-
}
222-
defer res.Body.Close()
223-
var respd *execResponse
224-
if err = json.NewDecoder(res.Body).Decode(&respd); err != nil {
225-
logger.WithContext(ctx).Errorf("failed to decode JSON. err: %v", err)
220+
logger.WithContext(ctx).Errorf("error: %v", err)
226221
return nil, err
227222
}
228223
return respd, nil
@@ -238,6 +233,7 @@ func (sc *snowflakeConn) rowsForRunningQuery(
238233
logger.WithContext(ctx).Errorf("error: %v", err)
239234
return err
240235
}
236+
241237
if !resp.Success {
242238
code, err := strconv.Atoi(resp.Code)
243239
if err != nil {

0 commit comments

Comments
 (0)