Skip to content

Commit b21bd7a

Browse files
authored
Automatically wait for retries in gh eval (#75)
2 parents 79f1655 + ffabf58 commit b21bd7a

File tree

3 files changed

+210
-44
lines changed

3 files changed

+210
-44
lines changed

cmd/eval/eval.go

Lines changed: 53 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"errors"
88
"fmt"
99
"strings"
10+
"time"
1011

1112
"github.com/MakeNowJust/heredoc"
1213
"github.com/github/gh-models/internal/azuremodels"
@@ -80,6 +81,8 @@ func NewEvalCommand(cfg *command.Config) *cobra.Command {
8081
8182
By default, results are displayed in a human-readable format. Use the --json flag
8283
to output structured JSON data for programmatic use or integration with CI/CD pipelines.
84+
This command will automatically retry on rate limiting errors, waiting for the specified
85+
duration before retrying the request.
8386
8487
See https://docs.github.com/github-models/use-github-models/storing-prompts-in-github-repositories#supported-file-format for more information.
8588
`),
@@ -327,36 +330,65 @@ func (h *evalCommandHandler) templateString(templateStr string, data map[string]
327330
return prompt.TemplateString(templateStr, data)
328331
}
329332

330-
func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) {
331-
req := h.evalFile.BuildChatCompletionOptions(messages)
332-
333-
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
334-
if err != nil {
335-
return "", err
336-
}
333+
// callModelWithRetry makes an API call with automatic retry on rate limiting
334+
func (h *evalCommandHandler) callModelWithRetry(ctx context.Context, req azuremodels.ChatCompletionOptions) (string, error) {
335+
const maxRetries = 3
337336

338-
// For non-streaming requests, we should get a single response
339-
var content strings.Builder
340-
for {
341-
completion, err := resp.Reader.Read()
337+
for attempt := 0; attempt <= maxRetries; attempt++ {
338+
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
342339
if err != nil {
343-
if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") {
344-
break
340+
var rateLimitErr *azuremodels.RateLimitError
341+
if errors.As(err, &rateLimitErr) {
342+
if attempt < maxRetries {
343+
if !h.jsonOutput {
344+
h.cfg.WriteToOut(fmt.Sprintf(" Rate limited, waiting %v before retry (attempt %d/%d)...\n",
345+
rateLimitErr.RetryAfter, attempt+1, maxRetries+1))
346+
}
347+
348+
// Wait for the specified duration
349+
select {
350+
case <-ctx.Done():
351+
return "", ctx.Err()
352+
case <-time.After(rateLimitErr.RetryAfter):
353+
continue
354+
}
355+
}
356+
return "", fmt.Errorf("rate limit exceeded after %d attempts: %w", attempt+1, err)
345357
}
358+
// For non-rate-limit errors, return immediately
346359
return "", err
347360
}
348361

349-
for _, choice := range completion.Choices {
350-
if choice.Delta != nil && choice.Delta.Content != nil {
351-
content.WriteString(*choice.Delta.Content)
362+
var content strings.Builder
363+
for {
364+
completion, err := resp.Reader.Read()
365+
if err != nil {
366+
if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") {
367+
break
368+
}
369+
return "", err
352370
}
353-
if choice.Message != nil && choice.Message.Content != nil {
354-
content.WriteString(*choice.Message.Content)
371+
372+
for _, choice := range completion.Choices {
373+
if choice.Delta != nil && choice.Delta.Content != nil {
374+
content.WriteString(*choice.Delta.Content)
375+
}
376+
if choice.Message != nil && choice.Message.Content != nil {
377+
content.WriteString(*choice.Message.Content)
378+
}
355379
}
356380
}
381+
382+
return strings.TrimSpace(content.String()), nil
357383
}
358384

359-
return strings.TrimSpace(content.String()), nil
385+
// This should never be reached, but just in case
386+
return "", errors.New("unexpected error calling model")
387+
}
388+
389+
func (h *evalCommandHandler) callModel(ctx context.Context, messages []azuremodels.ChatMessage) (string, error) {
390+
req := h.evalFile.BuildChatCompletionOptions(messages)
391+
return h.callModelWithRetry(ctx, req)
360392
}
361393

362394
func (h *evalCommandHandler) runEvaluators(ctx context.Context, testCase map[string]interface{}, response string) ([]EvaluationResult, error) {
@@ -437,7 +469,6 @@ func (h *evalCommandHandler) runStringEvaluator(name string, eval prompt.StringE
437469
}
438470

439471
func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, eval prompt.LLMEvaluator, testCase map[string]interface{}, response string) (EvaluationResult, error) {
440-
// Template the evaluation prompt
441472
evalData := make(map[string]interface{})
442473
for k, v := range testCase {
443474
evalData[k] = v
@@ -449,7 +480,6 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e
449480
return EvaluationResult{}, fmt.Errorf("failed to template evaluation prompt: %w", err)
450481
}
451482

452-
// Prepare messages for evaluation
453483
var messages []azuremodels.ChatMessage
454484
if eval.SystemPrompt != "" {
455485
messages = append(messages, azuremodels.ChatMessage{
@@ -462,40 +492,19 @@ func (h *evalCommandHandler) runLLMEvaluator(ctx context.Context, name string, e
462492
Content: util.Ptr(promptContent),
463493
})
464494

465-
// Call the evaluation model
466495
req := azuremodels.ChatCompletionOptions{
467496
Messages: messages,
468497
Model: eval.ModelID,
469498
Stream: false,
470499
}
471500

472-
resp, err := h.client.GetChatCompletionStream(ctx, req, h.org)
501+
evalResponseText, err := h.callModelWithRetry(ctx, req)
473502
if err != nil {
474503
return EvaluationResult{}, fmt.Errorf("failed to call evaluation model: %w", err)
475504
}
476505

477-
var evalResponse strings.Builder
478-
for {
479-
completion, err := resp.Reader.Read()
480-
if err != nil {
481-
if errors.Is(err, context.Canceled) || strings.Contains(err.Error(), "EOF") {
482-
break
483-
}
484-
return EvaluationResult{}, err
485-
}
486-
487-
for _, choice := range completion.Choices {
488-
if choice.Delta != nil && choice.Delta.Content != nil {
489-
evalResponse.WriteString(*choice.Delta.Content)
490-
}
491-
if choice.Message != nil && choice.Message.Content != nil {
492-
evalResponse.WriteString(*choice.Message.Content)
493-
}
494-
}
495-
}
496-
497506
// Match response to choices
498-
evalResponseText := strings.TrimSpace(strings.ToLower(evalResponse.String()))
507+
evalResponseText = strings.TrimSpace(strings.ToLower(evalResponseText))
499508
for _, choice := range eval.Choices {
500509
if strings.Contains(evalResponseText, strings.ToLower(choice.Choice)) {
501510
return EvaluationResult{

internal/azuremodels/azure_client.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"io"
1111
"net/http"
1212
"slices"
13+
"strconv"
1314
"strings"
15+
"time"
1416

1517
"github.com/cli/go-gh/v2/pkg/api"
1618
"github.com/github/gh-models/internal/modelkey"
@@ -259,6 +261,42 @@ func (c *AzureClient) handleHTTPError(resp *http.Response) error {
259261
return err
260262
}
261263

264+
case http.StatusTooManyRequests:
265+
// Handle rate limiting
266+
retryAfter := time.Duration(0)
267+
268+
// Check for x-ratelimit-timeremaining header (in seconds)
269+
if timeRemainingStr := resp.Header.Get("x-ratelimit-timeremaining"); timeRemainingStr != "" {
270+
if seconds, parseErr := strconv.Atoi(timeRemainingStr); parseErr == nil {
271+
retryAfter = time.Duration(seconds) * time.Second
272+
}
273+
}
274+
275+
// Fall back to standard Retry-After header if x-ratelimit-timeremaining is not available
276+
if retryAfter == 0 {
277+
if retryAfterStr := resp.Header.Get("Retry-After"); retryAfterStr != "" {
278+
if seconds, parseErr := strconv.Atoi(retryAfterStr); parseErr == nil {
279+
retryAfter = time.Duration(seconds) * time.Second
280+
}
281+
}
282+
}
283+
284+
// Default to 60 seconds if no retry-after information is provided
285+
if retryAfter == 0 {
286+
retryAfter = 60 * time.Second
287+
}
288+
289+
body, _ := io.ReadAll(resp.Body)
290+
message := "rate limit exceeded"
291+
if len(body) > 0 {
292+
message = string(body)
293+
}
294+
295+
return &RateLimitError{
296+
RetryAfter: retryAfter,
297+
Message: strings.TrimSpace(message),
298+
}
299+
262300
default:
263301
_, err = sb.WriteString("unexpected response from the server: " + resp.Status)
264302
if err != nil {
@@ -286,3 +324,13 @@ func (c *AzureClient) handleHTTPError(resp *http.Response) error {
286324

287325
return errors.New(sb.String())
288326
}
327+
328+
// RateLimitError represents a rate limiting error from the API
329+
type RateLimitError struct {
330+
RetryAfter time.Duration
331+
Message string
332+
}
333+
334+
func (e *RateLimitError) Error() string {
335+
return fmt.Sprintf("rate limited: %s (retry after %v)", e.Message, e.RetryAfter)
336+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package azuremodels
2+
3+
import (
4+
"net/http"
5+
"strings"
6+
"testing"
7+
"time"
8+
)
9+
10+
func TestRateLimitError(t *testing.T) {
11+
err := &RateLimitError{
12+
RetryAfter: 30 * time.Second,
13+
Message: "Too many requests",
14+
}
15+
16+
expected := "rate limited: Too many requests (retry after 30s)"
17+
if err.Error() != expected {
18+
t.Errorf("Expected error message %q, got %q", expected, err.Error())
19+
}
20+
}
21+
22+
func TestHandleHTTPError_RateLimit(t *testing.T) {
23+
client := &AzureClient{}
24+
25+
tests := []struct {
26+
name string
27+
statusCode int
28+
headers map[string]string
29+
expectedRetryAfter time.Duration
30+
}{
31+
{
32+
name: "Rate limit with x-ratelimit-timeremaining header",
33+
statusCode: http.StatusTooManyRequests,
34+
headers: map[string]string{
35+
"x-ratelimit-timeremaining": "45",
36+
},
37+
expectedRetryAfter: 45 * time.Second,
38+
},
39+
{
40+
name: "Rate limit with Retry-After header",
41+
statusCode: http.StatusTooManyRequests,
42+
headers: map[string]string{
43+
"Retry-After": "60",
44+
},
45+
expectedRetryAfter: 60 * time.Second,
46+
},
47+
{
48+
name: "Rate limit with both headers - x-ratelimit-timeremaining takes precedence",
49+
statusCode: http.StatusTooManyRequests,
50+
headers: map[string]string{
51+
"x-ratelimit-timeremaining": "30",
52+
"Retry-After": "90",
53+
},
54+
expectedRetryAfter: 30 * time.Second,
55+
},
56+
{
57+
name: "Rate limit with no headers - default to 60s",
58+
statusCode: http.StatusTooManyRequests,
59+
headers: map[string]string{},
60+
expectedRetryAfter: 60 * time.Second,
61+
},
62+
}
63+
64+
for _, tt := range tests {
65+
t.Run(tt.name, func(t *testing.T) {
66+
resp := &http.Response{
67+
StatusCode: tt.statusCode,
68+
Header: make(http.Header),
69+
Body: &mockReadCloser{reader: strings.NewReader("rate limit exceeded")},
70+
}
71+
72+
for key, value := range tt.headers {
73+
resp.Header.Set(key, value)
74+
}
75+
76+
err := client.handleHTTPError(resp)
77+
78+
var rateLimitErr *RateLimitError
79+
if !isRateLimitError(err, &rateLimitErr) {
80+
t.Fatalf("Expected RateLimitError, got %T: %v", err, err)
81+
}
82+
83+
if rateLimitErr.RetryAfter != tt.expectedRetryAfter {
84+
t.Errorf("Expected RetryAfter %v, got %v", tt.expectedRetryAfter, rateLimitErr.RetryAfter)
85+
}
86+
})
87+
}
88+
}
89+
90+
// Helper function to check if error is a RateLimitError (mimics errors.As)
91+
func isRateLimitError(err error, target **RateLimitError) bool {
92+
if rateLimitErr, ok := err.(*RateLimitError); ok {
93+
*target = rateLimitErr
94+
return true
95+
}
96+
return false
97+
}
98+
99+
type mockReadCloser struct {
100+
reader *strings.Reader
101+
}
102+
103+
func (m *mockReadCloser) Read(p []byte) (n int, err error) {
104+
return m.reader.Read(p)
105+
}
106+
107+
func (m *mockReadCloser) Close() error {
108+
return nil
109+
}

0 commit comments

Comments
 (0)