Skip to content

Commit 082e9f8

Browse files
authored
SNOW-736353: Implement retry reason (#913)
1 parent 08a9669 commit 082e9f8

19 files changed

+306
-104
lines changed

async.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (sr *snowflakeRestful) getAsync(
105105

106106
}
107107

108-
sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init()}
108+
sc := &snowflakeConn{rest: sr, cfg: cfg, queryContextCache: (&queryContextCache{}).init(), currentTimeProvider: defaultTimeProvider}
109109
if respd.Success {
110110
if resType == execResultType {
111111
res.insertID = -1

authokta.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func postAuthSAML(
216216
fullURL := sr.getFullURL(authenticatorRequestPath, params)
217217

218218
logger.Infof("fullURL: %v", fullURL)
219-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true)
219+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider)
220220
if err != nil {
221221
return nil, err
222222
}
@@ -274,7 +274,7 @@ func postAuthOKTA(
274274
if err != nil {
275275
return nil, err
276276
}
277-
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false)
277+
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider)
278278
if err != nil {
279279
return nil, err
280280
}

chunk_downloader.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ func getChunk(
264264
if err != nil {
265265
return nil, err
266266
}
267-
return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout).execute()
267+
return newRetryHTTP(ctx, sc.rest.Client, http.NewRequest, u, headers, timeout, sc.currentTimeProvider).execute()
268268
}
269269

270270
func (scd *snowflakeChunkDownloader) startArrowBatches() error {
@@ -636,7 +636,7 @@ func (f *httpStreamChunkFetcher) fetch(URL string, rows chan<- []*string) error
636636
if err != nil {
637637
return err
638638
}
639-
res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0).execute()
639+
res, err := newRetryHTTP(context.Background(), f.client, http.NewRequest, fullURL, f.headers, 0, defaultTimeProvider).execute()
640640
if err != nil {
641641
return err
642642
}

client.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
// InternalClient is implemented by HTTPClient
1313
type InternalClient interface {
1414
Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error)
15-
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool) (*http.Response, error)
15+
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
1616
}
1717

1818
type httpClient struct {
@@ -33,6 +33,7 @@ func (cli *httpClient) Post(
3333
headers map[string]string,
3434
body []byte,
3535
timeout time.Duration,
36-
raise4xx bool) (*http.Response, error) {
37-
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx)
36+
raise4xx bool,
37+
currentTimeProvider currentTimeProvider) (*http.Response, error) {
38+
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider)
3839
}

client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) {
4848
t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests)
4949
}
5050

51-
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false)
51+
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider)
5252
if err != nil || resp.StatusCode != 200 {
5353
t.Fail()
5454
}

connection.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,14 @@ const (
6161
const privateLinkSuffix = "privatelink.snowflakecomputing.com"
6262

6363
type snowflakeConn struct {
64-
ctx context.Context
65-
cfg *Config
66-
rest *snowflakeRestful
67-
SequenceCounter uint64
68-
telemetry *snowflakeTelemetry
69-
internal InternalClient
70-
queryContextCache *queryContextCache
64+
ctx context.Context
65+
cfg *Config
66+
rest *snowflakeRestful
67+
SequenceCounter uint64
68+
telemetry *snowflakeTelemetry
69+
internal InternalClient
70+
queryContextCache *queryContextCache
71+
currentTimeProvider currentTimeProvider
7172
}
7273

7374
var (
@@ -727,10 +728,11 @@ func (scd *snowflakeArrowStreamChunkDownloader) GetBatches() (out []ArrowStreamB
727728

728729
func buildSnowflakeConn(ctx context.Context, config Config) (*snowflakeConn, error) {
729730
sc := &snowflakeConn{
730-
SequenceCounter: 0,
731-
ctx: ctx,
732-
cfg: &config,
733-
queryContextCache: (&queryContextCache{}).init(),
731+
SequenceCounter: 0,
732+
ctx: ctx,
733+
cfg: &config,
734+
queryContextCache: (&queryContextCache{}).init(),
735+
currentTimeProvider: defaultTimeProvider,
734736
}
735737
var st http.RoundTripper = SnowflakeTransport
736738
if sc.cfg.Transporter == nil {

connection_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ func TestGetQueryResultUsesTokenFromTokenAccessor(t *testing.T) {
130130
TokenAccessor: ta,
131131
}
132132
sc := &snowflakeConn{
133-
cfg: &Config{Params: map[string]*string{}},
134-
rest: sr,
133+
cfg: &Config{Params: map[string]*string{}},
134+
rest: sr,
135+
currentTimeProvider: defaultTimeProvider,
135136
}
136137
if _, err := sc.getQueryResultResp(context.Background(), ""); err != nil {
137138
t.Fatalf("err: %v", err)

heartbeat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error {
6262

6363
fullURL := hc.restful.getFullURL(heartBeatPath, params)
6464
timeout := hc.restful.RequestTimeout
65-
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false)
65+
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider)
6666
if err != nil {
6767
return err
6868
}

monitoring.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"fmt"
1010
"net/url"
1111
"strconv"
12-
"time"
1312
)
1413

1514
const urlQueriesResultFmt = "/queries/%s/result"
@@ -208,7 +207,7 @@ func (sc *snowflakeConn) getQueryResultResp(
208207
paramsMutex.Unlock()
209208
param := make(url.Values)
210209
param.Add(requestIDKey, getOrGenerateRequestIDFromContext(ctx).String())
211-
param.Add("clientStartTime", strconv.FormatInt(time.Now().Unix(), 10))
210+
param.Add("clientStartTime", strconv.FormatInt(sc.currentTimeProvider.currentTime(), 10))
212211
param.Add(requestGUIDKey, NewUUID().String())
213212
token, _, _ := sc.rest.TokenAccessor.GetTokens()
214213
if token != "" {

ocsp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ func checkOCSPCacheServer(
358358
ocspS *ocspStatus) {
359359
var respd map[string][]interface{}
360360
headers := make(map[string]string)
361-
res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout).execute()
361+
res, err := newRetryHTTP(ctx, client, req, ocspServerHost, headers, totalTimeout, defaultTimeProvider).execute()
362362
if err != nil {
363363
logger.Errorf("failed to get OCSP cache from OCSP Cache Server. %v\n", err)
364364
return nil, &ocspStatus{
@@ -413,7 +413,7 @@ func retryOCSP(
413413
}
414414
res, err := newRetryHTTP(
415415
ctx, client, req, ocspHost, headers,
416-
totalTimeout*time.Duration(multiplier)).doPost().setBody(reqBody).execute()
416+
totalTimeout*time.Duration(multiplier), defaultTimeProvider).doPost().setBody(reqBody).execute()
417417
if err != nil {
418418
return ocspRes, ocspResBytes, &ocspStatus{
419419
code: ocspFailedSubmit,

0 commit comments

Comments
 (0)