Skip to content

Commit d56c0f2

Browse files
authored
SNOW-981533 Separate retry strategy for auth endpoints and the remaining ones (#982)
1 parent ceea09f commit d56c0f2

File tree

4 files changed

+49
-28
lines changed

4 files changed

+49
-28
lines changed

monitoring.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ func (sc *snowflakeConn) checkQueryStatus(
136136
if tok, _, _ := sc.rest.TokenAccessor.GetTokens(); tok != "" {
137137
headers[headerAuthorizationKey] = fmt.Sprintf(headerSnowflakeToken, tok)
138138
}
139-
resultPath := fmt.Sprintf("/monitoring/queries/%s", qid)
139+
resultPath := fmt.Sprintf("%s/%s", monitoringQueriesPath, qid)
140140
url := sc.rest.getFullURL(resultPath, &param)
141141

142142
res, err := sc.rest.FuncGet(ctx, sc.rest, url, headers, sc.rest.RequestTimeout)

restful.go

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ const (
3737
tokenRequestPath = "/session/token-request"
3838
abortRequestPath = "/queries/v1/abort-request"
3939
authenticatorRequestPath = "/session/authenticator-request"
40+
monitoringQueriesPath = "/monitoring/queries"
4041
sessionRequestPath = "/session"
4142
heartBeatPath = "/session/heartbeat"
4243
)

retry.go

+39-15
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,17 @@ import (
2020
type waitAlgo struct {
2121
mutex *sync.Mutex // required for *rand.Rand usage
2222
random *rand.Rand
23+
base time.Duration // base wait time
24+
cap time.Duration // maximum wait time
2325
}
2426

2527
var random *rand.Rand
2628
var defaultWaitAlgo *waitAlgo
2729

28-
var endpointsEligibleForRetry = []string{
30+
var authEndpoints = []string{
2931
loginRequestPath,
3032
tokenRequestPath,
3133
authenticatorRequestPath,
32-
queryRequestPath,
33-
abortRequestPath,
34-
sessionRequestPath,
3534
}
3635

3736
var clientErrorsStatusCodesEligibleForRetry = []int{
@@ -43,7 +42,7 @@ var clientErrorsStatusCodesEligibleForRetry = []int{
4342

4443
func init() {
4544
random = rand.New(rand.NewSource(time.Now().UnixNano()))
46-
defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random}
45+
defaultWaitAlgo = &waitAlgo{mutex: &sync.Mutex{}, random: random, base: 5 * time.Second, cap: 160 * time.Second}
4746
}
4847

4948
const (
@@ -205,12 +204,30 @@ func isQueryRequest(url *url.URL) bool {
205204
}
206205

207206
// jitter backoff in seconds
208-
func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, currWaitTime float64) float64 {
207+
func (w *waitAlgo) calculateWaitBeforeRetryForAuthRequest(attempt int, currWaitTimeDuration time.Duration) time.Duration {
209208
w.mutex.Lock()
210209
defer w.mutex.Unlock()
211-
jitterAmount := w.getJitter(currWaitTime)
212-
jitteredSleepTime := chooseRandomFromRange(currWaitTime+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
213-
return jitteredSleepTime
210+
currWaitTimeInSeconds := currWaitTimeDuration.Seconds()
211+
jitterAmount := w.getJitter(currWaitTimeInSeconds)
212+
jitteredSleepTime := chooseRandomFromRange(currWaitTimeInSeconds+jitterAmount, math.Pow(2, float64(attempt))+jitterAmount)
213+
return time.Duration(jitteredSleepTime * float64(time.Second))
214+
}
215+
216+
func (w *waitAlgo) calculateWaitBeforeRetry(attempt int, sleep time.Duration) time.Duration {
217+
w.mutex.Lock()
218+
defer w.mutex.Unlock()
219+
t := 3*sleep - w.base
220+
switch {
221+
case t > 0:
222+
return durationMin(w.cap, randSecondDuration(t)+w.base)
223+
case t < 0:
224+
return durationMin(w.cap, randSecondDuration(-t)+3*sleep)
225+
}
226+
return w.base
227+
}
228+
229+
func randSecondDuration(n time.Duration) time.Duration {
230+
return time.Duration(random.Int63n(int64(n/time.Second))) * time.Second
214231
}
215232

216233
func (w *waitAlgo) getJitter(currWaitTime float64) float64 {
@@ -284,7 +301,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
284301
totalTimeout := r.timeout
285302
logger.WithContext(r.ctx).Infof("retryHTTP.totalTimeout: %v", totalTimeout)
286303
retryCounter := 0
287-
sleepTime := 1.0 // seconds
304+
sleepTime := time.Duration(time.Second)
288305
clientStartTime := strconv.FormatInt(r.currentTimeProvider.currentTime(), 10)
289306

290307
var requestGUIDReplacer requestGUIDReplacer
@@ -324,12 +341,16 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
324341
}
325342
// uses exponential jitter backoff
326343
retryCounter++
327-
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime)
344+
if isLoginRequest(req) {
345+
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(retryCounter, sleepTime)
346+
} else {
347+
sleepTime = defaultWaitAlgo.calculateWaitBeforeRetry(retryCounter, sleepTime)
348+
}
328349

329350
if totalTimeout > 0 {
330351
logger.WithContext(r.ctx).Infof("to timeout: %v", totalTimeout)
331352
// if any timeout is set
332-
totalTimeout -= time.Duration(sleepTime * float64(time.Second))
353+
totalTimeout -= sleepTime
333354
if totalTimeout <= 0 || retryCounter > r.maxRetryCount {
334355
if err != nil {
335356
return nil, err
@@ -360,7 +381,7 @@ func (r *retryHTTP) execute() (res *http.Response, err error) {
360381
logger.WithContext(r.ctx).Infof("sleeping %v. to timeout: %v. retrying", sleepTime, totalTimeout)
361382
logger.WithContext(r.ctx).Infof("retry count: %v, retry reason: %v", retryCounter, retryReason)
362383

363-
await := time.NewTimer(time.Duration(sleepTime * float64(time.Second)))
384+
await := time.NewTimer(sleepTime)
364385
select {
365386
case <-await.C:
366387
// retry the request
@@ -378,10 +399,13 @@ func isRetryableError(req *http.Request, res *http.Response, err error) (bool, e
378399
if res == nil || req == nil {
379400
return false, err
380401
}
381-
isRetryableURL := contains(endpointsEligibleForRetry, req.URL.Path)
382-
return isRetryableURL && isRetryableStatus(res.StatusCode), err
402+
return isRetryableStatus(res.StatusCode), err
383403
}
384404

385405
func isRetryableStatus(statusCode int) bool {
386406
return (statusCode >= 500 && statusCode < 600) || contains(clientErrorsStatusCodesEligibleForRetry, statusCode)
387407
}
408+
409+
func isLoginRequest(req *http.Request) bool {
410+
return contains(authEndpoints, req.URL.Path)
411+
}

retry_test.go

+8-12
Original file line numberDiff line numberDiff line change
@@ -493,12 +493,6 @@ func TestIsRetryable(t *testing.T) {
493493
err: nil,
494494
expected: false,
495495
},
496-
{
497-
req: &http.Request{URL: &url.URL{Path: heartBeatPath}},
498-
res: &http.Response{StatusCode: http.StatusBadRequest},
499-
err: nil,
500-
expected: false,
501-
},
502496
{
503497
req: &http.Request{URL: &url.URL{Path: loginRequestPath}},
504498
res: &http.Response{StatusCode: http.StatusNotFound},
@@ -525,10 +519,12 @@ func TestIsRetryable(t *testing.T) {
525519
}
526520

527521
for _, tc := range tcs {
528-
result, _ := isRetryableError(tc.req, tc.res, tc.err)
529-
if result != tc.expected {
530-
t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
531-
}
522+
t.Run(fmt.Sprintf("req %v, resp %v", tc.req, tc.res), func(t *testing.T) {
523+
result, _ := isRetryableError(tc.req, tc.res, tc.err)
524+
if result != tc.expected {
525+
t.Fatalf("expected %v, got %v; request: %v, response: %v", tc.expected, result, tc.req, tc.res)
526+
}
527+
})
532528
}
533529
}
534530

@@ -605,8 +601,8 @@ func TestCalculateRetryWait(t *testing.T) {
605601

606602
for _, tc := range tcs {
607603
t.Run(fmt.Sprintf("attmept: %v", tc.attempt), func(t *testing.T) {
608-
result := defaultWaitAlgo.calculateWaitBeforeRetry(tc.attempt, tc.currWaitTime)
609-
assertBetweenE(t, result, tc.minSleepTime, tc.maxSleepTime)
604+
result := defaultWaitAlgo.calculateWaitBeforeRetryForAuthRequest(tc.attempt, time.Duration(tc.currWaitTime*float64(time.Second)))
605+
assertBetweenE(t, result.Seconds(), tc.minSleepTime, tc.maxSleepTime)
610606
})
611607
}
612608
}

0 commit comments

Comments
 (0)