Skip to content

Commit b540ce5

Browse files
SNOW-878073 Refactor retry policy to support HTTP 503 & 429 (#919)
1 parent ce9ef59 commit b540ce5

17 files changed

+229
-131
lines changed

auth.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ func postAuth(
226226

227227
fullURL := sr.getFullURL(loginRequestPath, params)
228228
logger.Infof("full URL: %v", fullURL)
229-
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout, true)
229+
resp, err := sr.FuncAuthPost(ctx, client, fullURL, headers, bodyCreator, timeout)
230230
if err != nil {
231231
return nil, err
232232
}
@@ -279,6 +279,8 @@ func getHeaders() map[string]string {
279279
headers := make(map[string]string)
280280
headers[httpHeaderContentType] = headerContentTypeApplicationJSON
281281
headers[httpHeaderAccept] = headerAcceptTypeApplicationSnowflake
282+
headers[httpClientAppID] = clientType
283+
headers[httpClientAppVersion] = SnowflakeGoDriverVersion
282284
headers[httpHeaderUserAgent] = userAgent
283285
return headers
284286
}

authokta.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ func postAuthSAML(
216216
fullURL := sr.getFullURL(authenticatorRequestPath, params)
217217

218218
logger.Infof("fullURL: %v", fullURL)
219-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider, nil)
219+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, nil)
220220
if err != nil {
221221
return nil, err
222222
}
@@ -274,7 +274,7 @@ func postAuthOKTA(
274274
if err != nil {
275275
return nil, err
276276
}
277-
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, false, defaultTimeProvider, nil)
277+
resp, err := sr.FuncPost(ctx, sr, targetURL, headers, body, timeout, defaultTimeProvider, nil)
278278
if err != nil {
279279
return nil, err
280280
}

client.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
// InternalClient is implemented by HTTPClient
1313
type InternalClient interface {
1414
Get(context.Context, *url.URL, map[string]string, time.Duration) (*http.Response, error)
15-
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider) (*http.Response, error)
15+
Post(context.Context, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider) (*http.Response, error)
1616
}
1717

1818
type httpClient struct {
@@ -33,7 +33,6 @@ func (cli *httpClient) Post(
3333
headers map[string]string,
3434
body []byte,
3535
timeout time.Duration,
36-
raise4xx bool,
3736
currentTimeProvider currentTimeProvider) (*http.Response, error) {
38-
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, raise4xx, currentTimeProvider, nil)
37+
return cli.sr.FuncPost(ctx, cli.sr, url, headers, body, timeout, currentTimeProvider, nil)
3938
}

client_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestInternalClient(t *testing.T) {
4848
t.Fatalf("Expected exactly one GET request, got %v", transport.getRequests)
4949
}
5050

51-
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, false, defaultTimeProvider)
51+
resp, err = internalClient.Post(context.Background(), &url.URL{}, make(map[string]string), make([]byte, 0), 0, defaultTimeProvider)
5252
if err != nil || resp.StatusCode != 200 {
5353
t.Fail()
5454
}

connection.go

+2
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ const (
3434
httpHeaderHost = "Host"
3535
httpHeaderValueOctetStream = "application/octet-stream"
3636
httpHeaderContentEncoding = "Content-Encoding"
37+
httpClientAppID = "CLIENT_APP_ID"
38+
httpClientAppVersion = "CLIENT_APP_VERSION"
3739
)
3840

3941
const (

driver_ocsp_test.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -634,8 +634,8 @@ func TestOCSPFailClosedResponder404(t *testing.T) {
634634
if !ok {
635635
t.Fatalf("failed to extract error URL Error: %v", err)
636636
}
637-
if !strings.Contains(urlErr.Err.Error(), "HTTP Status: 404") {
638-
t.Fatalf("the root cause is not timeout: %v", urlErr.Err)
637+
if !strings.Contains(urlErr.Err.Error(), "404 Not Found") {
638+
t.Fatalf("the root cause is not timeout: %v", urlErr.Err)
639639
}
640640
}
641641

dsn.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919
)
2020

2121
const (
22-
defaultClientTimeout = 900 * time.Second // Timeout for network round trip + read out http response
22+
defaultClientTimeout = 300 * time.Second // Timeout for network round trip + read out http response
2323
defaultJWTClientTimeout = 10 * time.Second // Timeout for network round trip + read out http response but used for JWT auth
2424
defaultLoginTimeout = 60 * time.Second // Timeout for retry for login EXCLUDING clientTimeout
2525
defaultRequestTimeout = 0 * time.Second // Timeout for retry for request EXCLUDING clientTimeout

dsn_test.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -1119,20 +1119,20 @@ func TestDSN(t *testing.T) {
11191119
User: "u",
11201120
Password: "p",
11211121
Account: "a.b.c",
1122-
ClientTimeout: 300 * time.Second,
1122+
ClientTimeout: 400 * time.Second,
11231123
JWTClientTimeout: 60 * time.Second,
11241124
},
1125-
dsn: "u:[email protected]:443?clientTimeout=300&jwtClientTimeout=60&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
1125+
dsn: "u:[email protected]:443?clientTimeout=400&jwtClientTimeout=60&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
11261126
},
11271127
{
11281128
cfg: &Config{
11291129
User: "u",
11301130
Password: "p",
11311131
Account: "a.b.c",
1132-
ClientTimeout: 300 * time.Second,
1132+
ClientTimeout: 400 * time.Second,
11331133
JWTExpireTimeout: 30 * time.Second,
11341134
},
1135-
dsn: "u:[email protected]:443?clientTimeout=300&jwtTimeout=30&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
1135+
dsn: "u:[email protected]:443?clientTimeout=400&jwtTimeout=30&ocspFailOpen=true&region=b.c&validateDefaultParameters=true",
11361136
},
11371137
{
11381138
cfg: &Config{

heartbeat.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ func (hc *heartbeat) heartbeatMain() error {
6262

6363
fullURL := hc.restful.getFullURL(heartBeatPath, params)
6464
timeout := hc.restful.RequestTimeout
65-
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, false, defaultTimeProvider, nil)
65+
resp, err := hc.restful.FuncPost(context.Background(), hc.restful, fullURL, headers, nil, timeout, defaultTimeProvider, nil)
6666
if err != nil {
6767
return err
6868
}

restful.go

+7-11
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ const (
4343

4444
type (
4545
funcGetType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, time.Duration) (*http.Response, error)
46-
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, bool, currentTimeProvider, *Config) (*http.Response, error)
47-
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration, bool) (*http.Response, error)
46+
funcPostType func(context.Context, *snowflakeRestful, *url.URL, map[string]string, []byte, time.Duration, currentTimeProvider, *Config) (*http.Response, error)
47+
funcAuthPostType func(context.Context, *http.Client, *url.URL, map[string]string, bodyCreatorType, time.Duration) (*http.Response, error)
4848
bodyCreatorType func() ([]byte, error)
4949
)
5050

@@ -162,14 +162,12 @@ func postRestful(
162162
headers map[string]string,
163163
body []byte,
164164
timeout time.Duration,
165-
raise4XX bool,
166165
currentTimeProvider currentTimeProvider,
167166
cfg *Config) (
168167
*http.Response, error) {
169168
return newRetryHTTP(ctx, sr.Client, http.NewRequest, fullURL, headers, timeout, currentTimeProvider, cfg).
170169
doPost().
171170
setBody(body).
172-
doRaise4XX(raise4XX).
173171
execute()
174172
}
175173

@@ -189,13 +187,11 @@ func postAuthRestful(
189187
fullURL *url.URL,
190188
headers map[string]string,
191189
bodyCreator bodyCreatorType,
192-
timeout time.Duration,
193-
raise4XX bool) (
190+
timeout time.Duration) (
194191
*http.Response, error) {
195192
return newRetryHTTP(ctx, client, http.NewRequest, fullURL, headers, timeout, defaultTimeProvider, nil).
196193
doPost().
197194
setBodyCreator(bodyCreator).
198-
doRaise4XX(raise4XX).
199195
execute()
200196
}
201197

@@ -243,7 +239,7 @@ func postRestfulQueryHelper(
243239

244240
var resp *http.Response
245241
fullURL := sr.getFullURL(queryRequestPath, params)
246-
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, true, defaultTimeProvider, cfg)
242+
resp, err = sr.FuncPost(ctx, sr, fullURL, headers, body, timeout, defaultTimeProvider, cfg)
247243
if err != nil {
248244
return nil, err
249245
}
@@ -335,7 +331,7 @@ func closeSession(ctx context.Context, sr *snowflakeRestful, timeout time.Durati
335331
token, _, _ := sr.TokenAccessor.GetTokens()
336332
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, token)
337333

338-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, false, defaultTimeProvider, nil)
334+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, nil, 5*time.Second, defaultTimeProvider, nil)
339335
if err != nil {
340336
return err
341337
}
@@ -394,7 +390,7 @@ func renewRestfulSession(ctx context.Context, sr *snowflakeRestful, timeout time
394390
return err
395391
}
396392

397-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, false, defaultTimeProvider, nil)
393+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqBody, timeout, defaultTimeProvider, nil)
398394
if err != nil {
399395
return err
400396
}
@@ -466,7 +462,7 @@ func cancelQuery(ctx context.Context, sr *snowflakeRestful, requestID UUID, time
466462
return err
467463
}
468464

469-
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, false, defaultTimeProvider, nil)
465+
resp, err := sr.FuncPost(ctx, sr, fullURL, headers, reqByte, timeout, defaultTimeProvider, nil)
470466
if err != nil {
471467
return err
472468
}

restful_test.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -15,63 +15,63 @@ import (
1515
"time"
1616
)
1717

18-
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
18+
func postTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
1919
return &http.Response{
2020
StatusCode: http.StatusOK,
2121
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
2222
}, errors.New("failed to run post method")
2323
}
2424

25-
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
25+
func postAuthTestError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
2626
return &http.Response{
2727
StatusCode: http.StatusOK,
2828
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
2929
}, errors.New("failed to run post method")
3030
}
3131

32-
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
32+
func postTestSuccessButInvalidJSON(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
3333
return &http.Response{
3434
StatusCode: http.StatusOK,
3535
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
3636
}, nil
3737
}
3838

39-
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
39+
func postTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
4040
return &http.Response{
4141
StatusCode: http.StatusBadGateway,
4242
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
4343
}, nil
4444
}
4545

46-
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
46+
func postAuthTestAppBadGatewayError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
4747
return &http.Response{
4848
StatusCode: http.StatusBadGateway,
4949
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
5050
}, nil
5151
}
5252

53-
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
53+
func postTestAppForbiddenError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
5454
return &http.Response{
5555
StatusCode: http.StatusForbidden,
5656
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
5757
}, nil
5858
}
5959

60-
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
60+
func postAuthTestAppForbiddenError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
6161
return &http.Response{
6262
StatusCode: http.StatusForbidden,
6363
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
6464
}, nil
6565
}
6666

67-
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
67+
func postAuthTestAppUnexpectedError(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
6868
return &http.Response{
6969
StatusCode: http.StatusInsufficientStorage,
7070
Body: &fakeResponseBody{body: []byte{0x12, 0x34}},
7171
}, nil
7272
}
7373

74-
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
74+
func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
7575
dd := &execResponseData{}
7676
er := &execResponse{
7777
Data: *dd,
@@ -90,7 +90,7 @@ func postTestQueryNotExecuting(_ context.Context, _ *snowflakeRestful, _ *url.UR
9090
}, nil
9191
}
9292

93-
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
93+
func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
9494
dd := &execResponseData{}
9595
er := &execResponse{
9696
Data: *dd,
@@ -110,7 +110,7 @@ func postTestRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[str
110110
}, nil
111111
}
112112

113-
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration, _ bool) (*http.Response, error) {
113+
func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map[string]string, _ bodyCreatorType, _ time.Duration) (*http.Response, error) {
114114
dd := &execResponseData{}
115115
er := &execResponse{
116116
Data: *dd,
@@ -130,7 +130,7 @@ func postAuthTestAfterRenew(_ context.Context, _ *http.Client, _ *url.URL, _ map
130130
}, nil
131131
}
132132

133-
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
133+
func postTestAfterRenew(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
134134
dd := &execResponseData{}
135135
er := &execResponse{
136136
Data: *dd,
@@ -157,7 +157,7 @@ func cancelTestRetry(ctx context.Context, sr *snowflakeRestful, requestID UUID,
157157
if err != nil {
158158
return err
159159
}
160-
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, false, defaultTimeProvider, nil)
160+
resp, err := sr.FuncPost(ctx, sr, &u, getHeaders(), reqByte, timeout, defaultTimeProvider, nil)
161161
if err != nil {
162162
return err
163163
}
@@ -462,7 +462,7 @@ func TestUnitRenewRestfulSession(t *testing.T) {
462462
accessor := getSimpleTokenAccessor()
463463
oldToken, oldMasterToken, oldSessionID := "oldtoken", "oldmaster", int64(100)
464464
newToken, newMasterToken, newSessionID := "newtoken", "newmaster", int64(200)
465-
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ bool, _ currentTimeProvider, _ *Config) (*http.Response, error) {
465+
postTestSuccessWithNewTokens := func(_ context.Context, _ *snowflakeRestful, _ *url.URL, headers map[string]string, _ []byte, _ time.Duration, _ currentTimeProvider, _ *Config) (*http.Response, error) {
466466
if headers[headerAuthorizationKey] != fmt.Sprintf(headerSnowflakeToken, oldMasterToken) {
467467
t.Fatalf("authorization key doesn't match, %v vs %v", headers[headerAuthorizationKey], fmt.Sprintf(headerSnowflakeToken, oldMasterToken))
468468
}

0 commit comments

Comments
 (0)