diff --git a/pkg/networking/middleware/retry_middleware.go b/pkg/networking/middleware/retry_middleware.go index b52570dcf..dad97d901 100644 --- a/pkg/networking/middleware/retry_middleware.go +++ b/pkg/networking/middleware/retry_middleware.go @@ -42,7 +42,7 @@ var statusCodesToRetryLUT = map[int]retryLogic{ } var errRetryNecessary = errors.New("retry with backoff") -var errRetryAfterHeaderError = errors.New("retry-after is too much in the future") +var errRetryDelayMaxExceeded = errors.New("suggested retry delay exceeds maximum allowed wait") type RetryMiddleware struct { nextRoundtripper http.RoundTripper @@ -109,7 +109,7 @@ func (rm RetryMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { response.Header.Set(retryCountHeaderKey, fmt.Sprintf("%d", actualAttempts)) } - // errors from the next round tripper cannot not be retried + // errors from the next round tripper cannot be retried if err != nil { return response, backoff.Permanent(err) } @@ -133,13 +133,21 @@ func (rm RetryMiddleware) RoundTrip(req *http.Request) (*http.Response, error) { backoffMethod.InitialInterval = time.Duration(retryAfterSeconds) * time.Second finalResponse, finalError = backoff.Retry(req.Context(), op, backoff.WithBackOff(backoffMethod)) - // if retries fail to resolve the issue, we need to unset the locally used error type to not return it from the RoundTripper - if errors.Is(finalError, errRetryNecessary) { + finalError = rm.filterRetryError(finalError, actualAttempts) + return finalResponse, finalError +} + +// filterRetryError strips sentinel errors used only inside the retry loop so callers receive the last HTTP response. +func (rm RetryMiddleware) filterRetryError(err error, actualAttempts int) error { + if errors.Is(err, errRetryNecessary) { rm.logger.Warn().Msgf("Retry ultimately failed after %d attempts", actualAttempts) - finalError = nil + return nil } - - return finalResponse, finalError + if errors.Is(err, errRetryDelayMaxExceeded) { + rm.logger.Warn().Msg("Suggested retry delay from Retry-After or X-RateLimit-Reset exceeds maximum allowed wait; returning last HTTP response") + return nil + } + return err } func getMaxRetryAttempts(response *http.Response, maxAttempts int) int { @@ -172,12 +180,23 @@ func shouldRetry(response *http.Response, attempts int, maxAttempts int) error { // try to read retry-after header if available if headerRetryAfterValue := response.Header.Get("Retry-After"); len(headerRetryAfterValue) > 0 { - fixRetryDelay = parseRetryAfterHeader(headerRetryAfterValue) + fixRetryDelay = parseRetryDelay(headerRetryAfterValue) + + // if the fix retry delay is too big, we rather fail permanently than blocking too long + if fixRetryDelay > maxRetryAfter { + return backoff.Permanent(errRetryDelayMaxExceeded) + } } - // if the fix retry delay is too big, we rather fail permanently than blocking too long - if fixRetryDelay > maxRetryAfter { - return backoff.Permanent(errRetryAfterHeaderError) + if fixRetryDelay == 0 { + // try to read X-RateLimit-Reset header if available + // according to envoy docs: number of seconds until reset of the current time-window + if headerXRateLimitResetValue := response.Header.Get("X-RateLimit-Reset"); len(headerXRateLimitResetValue) > 0 { + fixRetryDelay = parseRetryDelay(headerXRateLimitResetValue) + } + if fixRetryDelay > maxRetryAfter { + return backoff.Permanent(errRetryDelayMaxExceeded) + } } // if a retry after is defined, this is the time to wait for @@ -192,7 +211,7 @@ func shouldRetry(response *http.Response, attempts int, maxAttempts int) error { return nil } -func parseRetryAfterHeader(headerRetryAfterValue string) time.Duration { +func parseRetryDelay(headerRetryAfterValue string) time.Duration { // Retry-After: 1230 if tmp, err := strconv.ParseInt(headerRetryAfterValue, 10, 64); err == nil { return time.Duration(tmp) * time.Second diff --git a/pkg/networking/middleware/retry_middleware_test.go b/pkg/networking/middleware/retry_middleware_test.go index b052e3f7b..6a5e83187 100644 --- a/pkg/networking/middleware/retry_middleware_test.go +++ b/pkg/networking/middleware/retry_middleware_test.go @@ -15,6 +15,7 @@ import ( "github.com/rs/zerolog" "github.com/snyk/error-catalog-golang-public/snyk" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/cenkalti/backoff/v5" @@ -156,6 +157,101 @@ func TestNewRetryMiddleware(t *testing.T) { assert.Equal(t, fmt.Sprintf("%d", expectedAttempts), response.Header.Get(retryCountHeaderKey)) }) + t.Run("Happy path, 429 with only X-Ratelimit-Reset then success", func(t *testing.T) { + attemptCount := 0 + + //nolint:unparam // error is always nil but signature must match http.RoundTripper + customRTFn := func(req *http.Request) (*http.Response, error) { + attemptCount++ + headers := http.Header{} + + switch attemptCount { + case 1: + headers.Set("X-Ratelimit-Reset", "0") + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: headers, + Request: req, + }, nil + default: + return &http.Response{ + StatusCode: http.StatusOK, + Header: headers, + Request: req, + }, nil + } + } + + rt := &failRoundtripper{t: t, roundTripFn: &customRTFn} + config := configuration.NewWithOpts() + config.Set(ConfigurationKeyRequestAttempts, 3) + config.Set(configurationKeyRetryAfter, 1) + + sut := NewRetryMiddleware(config, &logger, rt) + resp, err := sut.RoundTrip(httptest.NewRequest(http.MethodGet, "/", bytes.NewReader(expectedBody))) + + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.GreaterOrEqual(t, attemptCount, 2) + }) + + t.Run("429 with Retry-After beyond max wait returns response without leaking internal error", func(t *testing.T) { + const hugeRetryAfter = "126144000" // 4 years in seconds; exceeds maxRetryAfter + + //nolint:unparam // error is always nil but signature must match http.RoundTripper + customRTFn := func(req *http.Request) (*http.Response, error) { + h := http.Header{} + h.Set("Retry-After", hugeRetryAfter) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: h, + Request: req, + }, nil + } + + rt := &failRoundtripper{t: t, roundTripFn: &customRTFn} + config := configuration.NewWithOpts() + config.Set(ConfigurationKeyRequestAttempts, 3) + config.Set(configurationKeyRetryAfter, 1) + + sut := NewRetryMiddleware(config, &logger, rt) + resp, err := sut.RoundTrip(httptest.NewRequest(http.MethodGet, "/", nil)) + + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + require.Equal(t, hugeRetryAfter, resp.Header.Get("Retry-After")) + }) + + t.Run("429 with X-RateLimit-Reset beyond max wait returns response without leaking internal error", func(t *testing.T) { + const hugeReset = "126144000" + + //nolint:unparam // error is always nil but signature must match http.RoundTripper + customRTFn := func(req *http.Request) (*http.Response, error) { + h := http.Header{} + h.Set("X-RateLimit-Reset", hugeReset) + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: h, + Request: req, + }, nil + } + + rt := &failRoundtripper{t: t, roundTripFn: &customRTFn} + config := configuration.NewWithOpts() + config.Set(ConfigurationKeyRequestAttempts, 3) + config.Set(configurationKeyRetryAfter, 1) + + sut := NewRetryMiddleware(config, &logger, rt) + resp, err := sut.RoundTrip(httptest.NewRequest(http.MethodGet, "/", nil)) + + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + require.Equal(t, hugeReset, resp.Header.Get("X-RateLimit-Reset")) + }) + t.Run("Unhappy path, retries didn't resolve the issue", func(t *testing.T) { var expectedAttempts = 3 failureRoundtripper := &failRoundtripper{ @@ -242,7 +338,7 @@ func Test_shouldRetry(t *testing.T) { { name: "Retryable status code (429) with Retry-After header too far in the future (4years)", response: newResponse(http.StatusTooManyRequests, http.Header{"Retry-After": []string{"126144000"}}), - expectedErrorIs: &backoff.PermanentError{Err: errRetryAfterHeaderError}, + expectedErrorIs: &backoff.PermanentError{Err: errRetryDelayMaxExceeded}, attempts: 0, maxAttempts: 1, }, @@ -346,7 +442,73 @@ func Test_shouldRetry(t *testing.T) { } } -func Test_parseRetryAfterHeader(t *testing.T) { +func Test_shouldRetry_rateLimitResetHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response *http.Response + expectedErrorIs error + expectedRetryable *backoff.RetryAfterError + attempts int + maxAttempts int + }{ + { + name: "Retryable status code (429) with only X-Ratelimit-Reset header", + response: func() *http.Response { + h := http.Header{} + h.Set("X-Ratelimit-Reset", "5") + return newResponse(http.StatusTooManyRequests, h) + }(), + expectedRetryable: &backoff.RetryAfterError{Duration: 5 * time.Second}, + attempts: 0, + maxAttempts: 1, + }, + { + name: "Retryable status code (429) Retry-After takes precedence over X-RateLimit-Reset", + response: func() *http.Response { + h := http.Header{} + h.Set("Retry-After", "3") + h.Set("X-RateLimit-Reset", "10") + return newResponse(http.StatusTooManyRequests, h) + }(), + expectedRetryable: &backoff.RetryAfterError{Duration: 3 * time.Second}, + attempts: 0, + maxAttempts: 1, + }, + { + name: "Retryable status code (429) with X-RateLimit-Reset header too far in the future (4years)", + response: func() *http.Response { + h := http.Header{} + h.Set("X-RateLimit-Reset", "126144000") + return newResponse(http.StatusTooManyRequests, h) + }(), + expectedErrorIs: errRetryDelayMaxExceeded, + attempts: 0, + maxAttempts: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := shouldRetry(tt.response, tt.attempts, tt.maxAttempts) + + assert.NotNil(t, err) + if tt.expectedErrorIs != nil { + require.True(t, errors.Is(err, tt.expectedErrorIs), `Expected error to be equal to "%v" (%T), got "%v" (%T)`, tt.expectedErrorIs, tt.expectedErrorIs, err, err) + } + if tt.expectedRetryable != nil { + var actualRetryableErr *backoff.RetryAfterError + require.ErrorAs(t, err, &actualRetryableErr) + require.Equal(t, tt.expectedRetryable, actualRetryableErr, "RetryAfter duration mismatch") + } + }) + } +} + +func Test_parseRetryDelay(t *testing.T) { tests := []struct { name string input string @@ -381,7 +543,7 @@ func Test_parseRetryAfterHeader(t *testing.T) { for _, testcase := range tests { t.Run(testcase.name, func(t *testing.T) { - actualOutput := parseRetryAfterHeader(testcase.input) + actualOutput := parseRetryDelay(testcase.input) timeDistance := (testcase.output - actualOutput) / time.Second t.Logf("Time distance: %v", timeDistance) assert.Equal(t, 0.0, math.Abs(float64(timeDistance)))