Skip to content

Commit 082e9f8

Browse files
authored
SNOW-736353: Implement retry reason (#913)
1 parent 08a9669 commit 082e9f8

19 files changed

+306
-104
lines changed

async.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (sr *snowflakeRestful) getAsync(
105105

106106
}
107107

108-
sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init()}
108+
sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init(), currentTimeProvider: defaultTimeProvider}
109109
if respd.Success {
110110
if resType == execResultType {
111111
res.insertID = -1

authokta.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func postAuthSAML(
216216
fullURL := sr.getFullURL(authenticatorRequestPath, params)
217217

218218
logger.Infof("fullURL: %v", fullURL)
219-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true)
219+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
220220
if err != nil {
221221
return nil, err
222222
}
@@ -274,7 +274,7 @@ func postAuthOKTA(
274274
if err != nil {
275275
return nil, err
276276
}
277-
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false)
277+
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider)
278278
if err != nil {
279279
return nil, err
280280
}

chunk_downloader.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ func getChunk(
264264
if err != nil {
265265
return nil, err
266266
}
267-
return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout).execute()
267+
return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.currentTimeProvider).execute()
268268
}
269269

270270
func (scd *snowflakeChunkDownloader) startArrowBatches() error {
@@ -636,7 +636,7 @@ func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error
636636
if err != nil {
637637
return err
638638
}
639-
res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0).execute()
639+
res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, defaultTimeProvider).execute()
640640
if err != nil {
641641
return err
642642
}

client.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
// InternalClient is implemented by HTTPClient
1313
type InternalClient interface {
1414
Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error)
15-
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error)
15+
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
1616
}
1717

1818
type httpClient struct {
@@ -33,6 +33,7 @@ func (cli *httpClient) Post(
3333
headers map[string]string,
3434
body []byte,
3535
timeout time.Duration,
36-
raise4xx bool) (*http.Response, error) {
37-
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx)
36+
raise4xx bool,
37+
currentTimeProvider currentTimeProvider) (*http.Response, error) {
38+
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider)
3839
}

client_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) {
4848
t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests)
4949
}
5050

51-
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false)
51+
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider)
5252
if err != nil || resp.StatusCode != 200 {
5353
t.Fail()
5454
}

connection.go

+13-11
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,14 @@ const (
6161
const privateLinkSuffix = "privatelink.snowflakecomputing.com"
6262

6363
type snowflakeConn struct {
64-
ctx context.Context
65-
cfg *Config
66-
rest *snowflakeRestful
67-
SequenceCounter uint64
68-
telemetry *snowflakeTelemetry
69-
internal InternalClient
70-
queryContextCache *queryContextCache
64+
ctx context.Context
65+
cfg *Config
66+
rest *snowflakeRestful
67+
SequenceCounter uint64
68+
telemetry *snowflakeTelemetry
69+
internal InternalClient
70+
queryContextCache *queryContextCache
71+
currentTimeProvider currentTimeProvider
7172
}
7273

7374
var (
@@ -727,10 +728,11 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB
727728

728729
func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) {
729730
sc := &snowflakeConn{
730-
SequenceCounter: 0,
731-
ctx: ctx,
732-
cfg: &config,
733-
queryContextCache: (&queryContextCache{}).init(),
731+
SequenceCounter: 0,
732+
ctx: ctx,
733+
cfg: &config,
734+
queryContextCache: (&queryContextCache{}).init(),
735+
currentTimeProvider: defaultTimeProvider,
734736
}
735737
var st http.RoundTripper = SnowflakeTransport
736738
if sc.cfg.Transporter == nil {

connection_test.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ func TestGetQueryResultUsesTokenFromTokenAccessor(t *testing.T) {
130130
TokenAccessor: ta,
131131
}
132132
sc := &snowflakeConn{
133-
cfg: &Config{Params: map[string]*string{}},
134-
rest: sr,
133+
cfg: &Config{Params: map[string]*string{}},
134+
rest: sr,
135+
currentTimeProvider: defaultTimeProvider,
135136
}
136137
if _, err := sc.getQueryResultResp(context.Background(), ""); err != nil {
137138
t.Fatalf("err: %v", err)

heartbeat.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error {
6262

6363
fullURL := hc.restful.getFullURL(heartBeatPath, params)
6464
timeout := hc.restful.RequestTimeout
65-
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false)
65+
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider)
6666
if err != nil {
6767
return err
6868
}

monitoring.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"fmt"
1010
"net/url"
1111
"strconv"
12-
"time"
1312
)
1413

1514
const urlQueriesResultFmt = "/queries/%s/result"
@@ -208,7 +207,7 @@ func (sc *snowflakeConn) getQueryResultResp(
208207
paramsMutex.Unlock()
209208
param := make(url.Values)
210209
param.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
211-
param.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10))
210+
param.Add("clientStartTime", strconv.FormatInt(sc.currentTimeProvider.currentTime(), 10))
212211
param.Add(requestGUIDKey, NewUUID().String())
213212
token, _, _ := sc.rest.TokenAccessor.GetTokens()
214213
if token != "" {

ocsp.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ func checkOCSPCacheServer(
358358
ocspS *ocspStatus) {
359359
var respd map[string][]interface{}
360360
headers := make(map[string]string)
361-
res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout).execute()
361+
res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider).execute()
362362
if err != nil {
363363
logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v\n", err)
364364
return nil, &ocspStatus{
@@ -413,7 +413,7 @@ func retryOCSP(
413413
}
414414
res, err := newRetryHTTP(
415415
ctx, client, req, ocspHost, headers,
416-
totalTimeout*time.Duration(multiplier)).doPost().setBody(reqBody).execute()
416+
totalTimeout*time.Duration(multiplier), defaultTimeProvider).doPost().setBody(reqBody).execute()
417417
if err != nil {
418418
return ocspRes, ocspResBytes, &ocspStatus{
419419
code: ocspFailedSubmit,

ocsp_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ func TestOCSPRetry(t *testing.T) {
319319
}
320320
res, b, st := retryOCSP(
321321
context.TODO(),
322-
client, fakeRequestFunc,
322+
client, emptyRequest,
323323
dummyOCSPHost,
324324
make(map[string]string), []byte{0}, certs[len(certs)-1], 10*time.Second)
325325
if st.err == nil {

restful.go

+10-10
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ const (
4343

4444
type (
4545
funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
46-
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error)
46+
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
4747
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error)
4848
bodyCreatorType func() ([]byte, error)
4949
)
@@ -162,9 +162,10 @@ func postRestful(
162162
headers map[string]string,
163163
body []byte,
164164
timeout time.Duration,
165-
raise4XX bool) (
165+
raise4XX bool,
166+
currentTimeProvider currentTimeProvider) (
166167
*http.Response, error) {
167-
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).
168+
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider).
168169
doPost().
169170
setBody(body).
170171
doRaise4XX(raise4XX).
@@ -178,7 +179,7 @@ func getRestful(
178179
headers map[string]string,
179180
timeout time.Duration) (
180181
*http.Response, error) {
181-
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout).execute()
182+
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider).execute()
182183
}
183184

184185
func postAuthRestful(
@@ -190,7 +191,7 @@ func postAuthRestful(
190191
timeout time.Duration,
191192
raise4XX bool) (
192193
*http.Response, error) {
193-
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout).
194+
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider).
194195
doPost().
195196
setBodyCreator(bodyCreator).
196197
doRaise4XX(raise4XX).
@@ -233,7 +234,6 @@ func postRestfulQueryHelper(
233234
data *execResponse, err error) {
234235
logger.Infof("params: %v", params)
235236
params.Add(requestIDKey, requestID.String())
236-
params.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10))
237237
params.Add(requestGUIDKey, NewUUID().String())
238238
token, _, _ := sr.TokenAccessor.GetTokens()
239239
if token != "" {
@@ -242,7 +242,7 @@ func postRestfulQueryHelper(
242242

243243
var resp *http.Response
244244
fullURL := sr.getFullURL(queryRequestPath, params)
245-
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true)
245+
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
246246
if err != nil {
247247
return nil, err
248248
}
@@ -334,7 +334,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati
334334
token, _, _ := sr.TokenAccessor.GetTokens()
335335
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
336336

337-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false)
337+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider)
338338
if err != nil {
339339
return err
340340
}
@@ -393,7 +393,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time
393393
return err
394394
}
395395

396-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false)
396+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider)
397397
if err != nil {
398398
return err
399399
}
@@ -465,7 +465,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time
465465
return err
466466
}
467467

468-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false)
468+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider)
469469
if err != nil {
470470
return err
471471
}

restful_test.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
"time"
1616
)
1717

18-
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
18+
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
1919
return &http.Response{
2020
StatusCode: http.StatusOK,
2121
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
@@ -29,14 +29,14 @@ func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[stri
2929
}, errors.New("failed to run post method")
3030
}
3131

32-
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
32+
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
3333
return &http.Response{
3434
StatusCode: http.StatusOK,
3535
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
3636
}, nil
3737
}
3838

39-
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
39+
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
4040
return &http.Response{
4141
StatusCode: http.StatusBadGateway,
4242
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
@@ -50,7 +50,7 @@ func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.UR
5050
}, nil
5151
}
5252

53-
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
53+
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
5454
return &http.Response{
5555
StatusCode: http.StatusForbidden,
5656
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
@@ -71,7 +71,7 @@ func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.UR
7171
}, nil
7272
}
7373

74-
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
74+
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
7575
dd := &execResponseData{}
7676
er := &execResponse{
7777
Data: *dd,
@@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR
9090
}, nil
9191
}
9292

93-
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
93+
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
9494
dd := &execResponseData{}
9595
er := &execResponse{
9696
Data: *dd,
@@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map
130130
}, nil
131131
}
132132

133-
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
133+
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
134134
dd := &execResponseData{}
135135
er := &execResponse{
136136
Data: *dd,
@@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID,
157157
if err != nil {
158158
return err
159159
}
160-
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false)
160+
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider)
161161
if err != nil {
162162
return err
163163
}
@@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) {
462462
accessor := getSimpleTokenAccessor()
463463
oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
464464
newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
465-
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool) (*http.Response, error) {
465+
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider) (*http.Response, error) {
466466
if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
467467
t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
468468
}

0 commit comments

Comments
 (0)