diff --git a/apidump_integration_test.go b/apidump_integration_test.go index 9caaac6..29db3f2 100644 --- a/apidump_integration_test.go +++ b/apidump_integration_test.go @@ -25,7 +25,6 @@ import ( "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/stretchr/testify/require" - "golang.org/x/tools/txtar" ) func openaiCfgWithAPIDump(url, key, dumpDir string) config.OpenAI { @@ -92,16 +91,9 @@ func TestAPIDump(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] - // Setup mock upstream server. - srv := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(srv.Close) + fix := fixtures.Parse(t, tc.fixture) + srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) // Create temp dir for API dumps. dumpDir := t.TempDir() @@ -117,7 +109,7 @@ func TestAPIDump(t *testing.T) { } mockSrv.Start() - req := tc.createRequestFunc(t, mockSrv.URL, reqBody) + req := tc.createRequestFunc(t, mockSrv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -163,7 +155,7 @@ func TestAPIDump(t *testing.T) { require.NoError(t, err) // Compare requests semantically (key order may differ). - require.JSONEq(t, string(dumpBody), string(reqBody), "request body JSON should match semantically") + require.JSONEq(t, string(dumpBody), string(fix.Request()), "request body JSON should match semantically") // Verify response dump contains expected HTTP response format. respDumpData, err := os.ReadFile(respDumpFile) @@ -177,7 +169,7 @@ func TestAPIDump(t *testing.T) { require.NoError(t, err) // Compare responses semantically (key order may differ). 
- expectedRespBody := files[fixtureNonStreamingResponse] + expectedRespBody := fix.NonStreaming() require.JSONEq(t, string(expectedRespBody), string(dumpRespBody), "response body JSON should match semantically") recorderClient.VerifyAllInterceptionsEnded(t) diff --git a/bridge_integration_test.go b/bridge_integration_test.go index ea4c0ae..068474c 100644 --- a/bridge_integration_test.go +++ b/bridge_integration_test.go @@ -1,7 +1,6 @@ package aibridge_test import ( - "bufio" "bytes" "context" "encoding/json" @@ -13,7 +12,6 @@ import ( "net/http/httptest" "strings" "sync" - "sync/atomic" "testing" "time" @@ -43,19 +41,11 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/trace" "go.uber.org/goleak" - "golang.org/x/tools/txtar" ) var testTracer = otel.Tracer("forTesting") const ( - fixtureRequest = "request" - fixtureStreamingResponse = "streaming" - fixtureNonStreamingResponse = "non-streaming" - fixtureStreamingToolResponse = "streaming/tool-call" - fixtureNonStreamingToolResponse = "non-streaming/tool-call" - fixtureResponse = "response" - apiKey = "api-key" userID = "ae235cc1-9f8f-417d-a636-a7b170bac62e" ) @@ -91,31 +81,15 @@ func TestAnthropicMessages(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.AntSingleBuiltinTool) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] - - // Add the stream param to the request. 
- newBody, err := setJSON(reqBody, "stream", tc.streaming) - require.NoError(t, err) - reqBody = newBody - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - srv := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(srv.Close) - recorderClient := &testutil.MockRecorder{} + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil)} + providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} b, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) @@ -127,6 +101,8 @@ func TestAnthropicMessages(t *testing.T) { mockSrv.Start() // Make API call to aibridge for Anthropic /v1/messages + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -180,10 +156,6 @@ func TestAWSBedrockIntegration(t *testing.T) { t.Run("invalid config", func(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.AntSingleBuiltinTool) - files := filesMap(arc) - reqBody := files[fixtureRequest] - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) @@ -210,7 +182,7 @@ func TestAWSBedrockIntegration(t *testing.T) { } mockSrv.Start() - req := createAnthropicMessagesReq(t, mockSrv.URL, reqBody) + req := createAnthropicMessagesReq(t, mockSrv.URL, fixtures.Request(t, fixtures.AntSingleBuiltinTool)) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -227,60 +199,11 @@ func TestAWSBedrockIntegration(t *testing.T) { 
t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.AntSingleBuiltinTool) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - - reqBody := files[fixtureRequest] - - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - var receivedModelName string - var requestCount int - - // Create a mock server that intercepts requests to capture model name and return fixtures. - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount++ - t.Logf("Mock server received request #%d: %s %s (streaming=%v)", requestCount, r.Method, r.URL.Path, streaming) - t.Logf("Request headers: %v", r.Header) - - // AWS Bedrock encodes the model name in the URL path: /model/{model-id}/invoke or /model/{model-id}/invoke-with-response-stream. - // Extract the model name from the path. - pathParts := strings.Split(r.URL.Path, "/") - if len(pathParts) >= 3 && pathParts[1] == "model" { - receivedModelName = pathParts[2] - t.Logf("Extracted model name from path: %s", receivedModelName) - } - - // Return appropriate fixture response. 
- var respBody []byte - if streaming { - respBody = files[fixtureStreamingResponse] - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - } else { - respBody = files[fixtureNonStreamingResponse] - w.Header().Set("Content-Type", "application/json") - } - - w.WriteHeader(http.StatusOK) - _, _ = w.Write(respBody) - })) - - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) // We define region here to validate that with Region & BaseURL defined, the latter takes precedence. bedrockCfg := &config.AWSBedrock{ @@ -289,14 +212,13 @@ func TestAWSBedrockIntegration(t *testing.T) { AccessKeySecret: "test-secret-key", Model: "danthropic", // This model should override the request's given one. SmallFastModel: "danthropic-mini", // Unused but needed for validation. - BaseURL: srv.URL, // Use the mock server. + BaseURL: upstream.URL, // Use the mock server. } recorderClient := &testutil.MockRecorder{} - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge( - ctx, []aibridge.Provider{provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), bedrockCfg)}, + ctx, []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg)}, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) @@ -309,6 +231,8 @@ func TestAWSBedrockIntegration(t *testing.T) { // Make API call to aibridge for Anthropic /v1/messages, which will be routed via AWS Bedrock. // We override the AWS Bedrock client to route requests through our mock server. 
+ reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) req := createAnthropicMessagesReq(t, mockBridgeSrv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -324,8 +248,18 @@ func TestAWSBedrockIntegration(t *testing.T) { // Verify that Bedrock-specific model name was used in the request to the mock server // and the interception data. - require.Equal(t, requestCount, 1) - require.Equal(t, bedrockCfg.Model, receivedModelName) + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + + // The Anthropic SDK's Bedrock middleware extracts "model" and "stream" + // from the JSON body and encodes them in the URL path. + // See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248 + pathParts := strings.Split(received[0].Path, "/") + require.True(t, len(pathParts) >= 3 && pathParts[1] == "model", "unexpected path: %s", received[0].Path) + require.Equal(t, bedrockCfg.Model, pathParts[2]) + require.False(t, gjson.GetBytes(received[0].Body, "model").Exists(), "model should be stripped from body") + require.False(t, gjson.GetBytes(received[0].Body, "stream").Exists(), "stream should be stripped from body") + interceptions := recorderClient.RecordedInterceptions() require.Len(t, interceptions, 1) require.Equal(t, interceptions[0].Model, bedrockCfg.Model) @@ -361,31 +295,15 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.OaiChatSingleBuiltinTool) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] - - // Add the stream param to the request. 
- newBody, err := setJSON(reqBody, "stream", tc.streaming) - require.NoError(t, err) - reqBody = newBody - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - srv := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(srv.Close) - recorderClient := &testutil.MockRecorder{} + fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(srv.URL, apiKey))} + providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) @@ -395,7 +313,10 @@ func TestOpenAIChatCompletions(t *testing.T) { return aibcontext.AsActor(ctx, userID, nil) } mockSrv.Start() + // Make API call to aibridge for OpenAI /v1/chat/completions + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) client := &http.Client{} @@ -463,35 +384,13 @@ func TestOpenAIChatCompletions(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureStreamingToolResponse) - - reqBody := files[fixtureRequest] - - // Add the stream param to the request. 
- newBody, err := setJSON(reqBody, "stream", true) - require.NoError(t, err) - reqBody = newBody - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Setup mock server with response mutator for multi-turn interaction. - srv := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte { - if reqCount == 1 { - // First request gets the tool call response - return resp - } - // Second request gets final response - return files[fixtureStreamingToolResponse] - }) - t.Cleanup(srv.Close) + // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) recorderClient := &testutil.MockRecorder{} @@ -501,7 +400,7 @@ func TestOpenAIChatCompletions(t *testing.T) { require.NoError(t, mcpMgr.Init(ctx)) logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(srv.URL, apiKey))} + providers := []aibridge.Provider{provider.NewOpenAI(openaiCfg(upstream.URL, apiKey))} b, err := aibridge.NewRequestBridge(t.Context(), providers, recorderClient, mcpMgr, logger, nil, testTracer) require.NoError(t, err) @@ -512,6 +411,9 @@ func TestOpenAIChatCompletions(t *testing.T) { } mockSrv.Start() + // Add the stream param to the request. 
+ reqBody, err := sjson.SetBytes(fix.Request(), "stream", true) + require.NoError(t, err) req := createOpenAIChatCompletionsReq(t, mockSrv.URL, reqBody) client := &http.Client{} @@ -699,42 +601,27 @@ func TestSimple(t *testing.T) { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] - - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - - // Given: a mock API server and a Bridge through which the requests will flow. ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - srv := newMockServer(ctx, t, files, func(r *http.Request) { - require.Equal(t, tc.expectedPath, r.URL.Path) - }, nil) - t.Cleanup(srv.Close) + + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(t, srv.URL+tc.basePath, recorderClient) + b, err := tc.configureFunc(t, upstream.URL+tc.basePath, recorderClient) require.NoError(t, err) - mockSrv := httptest.NewUnstartedServer(b) t.Cleanup(mockSrv.Close) + mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { return aibcontext.AsActor(ctx, userID, nil) } mockSrv.Start() + // When: calling the "API server" with the fixture's request body. 
+ reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) req := tc.createRequest(t, mockSrv.URL, reqBody) req.Header.Set("User-Agent", tc.userAgent) client := &http.Client{} @@ -743,6 +630,11 @@ func TestSimple(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) defer resp.Body.Close() + // Then: I expect the upstream request to have the correct path. + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + require.Equal(t, tc.expectedPath, received[0].Path) + // Then: I expect a non-empty response. bodyBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -780,11 +672,6 @@ func TestSimple(t *testing.T) { } } -func setJSON(in []byte, key string, val bool) ([]byte, error) { - out, err := sjson.Set(string(in), key, val) - return []byte(out), err -} - func TestFallthrough(t *testing.T) { t.Parallel() @@ -863,33 +750,10 @@ func TestFallthrough(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Contains(t, files, fixtureResponse) - expectedPath := tc.expectedUpstreamPath - - var receivedHeaders *http.Header - respBody := files[fixtureResponse] - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != expectedPath { - t.Errorf("unexpected request path: %q", r.URL.Path) - t.FailNow() - } - - receivedHeaders = &r.Header - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(respBody) - })) - t.Cleanup(upstream.Close) - + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, t.Context(), testutil.NewFixtureResponse(fix)) recorderClient := &testutil.MockRecorder{} - - upstreamURL := upstream.URL + tc.basePath - provider, bridge := tc.configureFunc(upstreamURL, recorderClient) + provider, bridge := tc.configureFunc(upstream.URL+tc.basePath, 
recorderClient) bridgeSrv := httptest.NewUnstartedServer(bridge) bridgeSrv.Config.BaseContext = func(_ net.Listener) context.Context { @@ -907,9 +771,12 @@ func TestFallthrough(t *testing.T) { require.Equal(t, http.StatusOK, resp.StatusCode) - // Ensure that the API key was sent. - require.NotNil(t, receivedHeaders) - require.Contains(t, receivedHeaders.Get(provider.AuthHeader()), apiKey) + // Verify upstream received the request at the expected path + // with the API key header. + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + require.Equal(t, tc.expectedUpstreamPath, received[0].Path) + require.Contains(t, received[0].Header.Get(provider.AuthHeader()), apiKey) gotBytes, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -918,7 +785,7 @@ func TestFallthrough(t *testing.T) { var got any var exp any require.NoError(t, json.Unmarshal(gotBytes, &got)) - require.NoError(t, json.Unmarshal(respBody, &exp)) + require.NoError(t, json.Unmarshal(fix.NonStreaming(), &exp)) require.EqualValues(t, exp, got) }) } @@ -1151,19 +1018,10 @@ func TestOpenAIInjectedTools(t *testing.T) { // upstream request contains the assistant's tool_use and user's tool_result messages // appended by the inner agentic loop. If the raw payload is not kept in sync with // the structured messages, the second request will be identical to the first. 
-func anthropicToolResultValidator(t *testing.T) func(*http.Request) { +func anthropicToolResultValidator(t *testing.T) func(*http.Request, []byte) { t.Helper() - var reqNum atomic.Uint32 - return func(r *http.Request) { - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - r.Body = io.NopCloser(bytes.NewReader(raw)) - - if reqNum.Add(1) != 2 { - return - } - + return func(_ *http.Request, raw []byte) { messages := gjson.GetBytes(raw, "messages").Array() // After the agentic loop the messages must contain at minimum: @@ -1202,19 +1060,10 @@ func anthropicToolResultValidator(t *testing.T) func(*http.Request) { // openaiChatToolResultValidator returns a request validator that asserts the second // upstream request contains the assistant's tool_calls and a role=tool result message // appended by the inner agentic loop. -func openaiChatToolResultValidator(t *testing.T) func(*http.Request) { +func openaiChatToolResultValidator(t *testing.T) func(*http.Request, []byte) { t.Helper() - var reqNum atomic.Uint32 - return func(r *http.Request) { - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - r.Body = io.NopCloser(bytes.NewReader(raw)) - - if reqNum.Add(1) != 2 { - return - } - + return func(_ *http.Request, raw []byte) { messages := gjson.GetBytes(raw, "messages").Array() // After the agentic loop the messages must contain at minimum: @@ -1239,48 +1088,20 @@ func openaiChatToolResultValidator(t *testing.T) func(*http.Request) { } // setupInjectedToolTest abstracts the common aspects required for the Test*InjectedTools tests. 
-func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request, requestValidatorFn func(*http.Request)) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { +func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configureFn configureFunc, createRequestFn func(*testing.T, string, []byte) *http.Request, toolRequestValidatorFn func(*http.Request, []byte)) (*testutil.MockRecorder, *callAccumulator, map[string]mcp.ServerProxier, *http.Response) { t.Helper() - arc := txtar.Parse(fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 5) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - require.Contains(t, files, fixtureStreamingToolResponse) - require.Contains(t, files, fixtureNonStreamingToolResponse) - - reqBody := files[fixtureRequest] - - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Setup mock server with response mutator for multi-turn interaction. - mockSrv := newMockServer(ctx, t, files, requestValidatorFn, func(reqCount uint32, resp []byte) []byte { - if reqCount == 1 { - return resp // First request gets the normal response (with tool call). - } - - if reqCount > 2 { - // This should not happen in single injected tool tests. - return resp - } + fix := fixtures.Parse(t, fixture) - // Second request gets the tool response. - if streaming { - return files[fixtureStreamingToolResponse] - } - return files[fixtureNonStreamingToolResponse] - }) - t.Cleanup(mockSrv.Close) + // Setup mock server for multi-turn interaction. 
+ // First request → tool call response, second → tool response. + firstResp := testutil.NewFixtureResponse(fix) + toolResp := testutil.NewFixtureToolResponse(fix) + toolResp.OnRequest = toolRequestValidatorFn + upstream := testutil.NewMockUpstream(t, ctx, firstResp, toolResp) recorderClient := &testutil.MockRecorder{} @@ -1290,7 +1111,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu // Configure the bridge with injected tools. mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) require.NoError(t, mcpMgr.Init(ctx)) - b, err := configureFn(mockSrv.URL, recorderClient, mcpMgr) + b, err := configureFn(upstream.URL, recorderClient, mcpMgr) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1301,6 +1122,10 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu bridgeSrv.Start() t.Cleanup(bridgeSrv.Close) + // Add the stream param to the request. + reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + req := createRequestFn(t, bridgeSrv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -1312,7 +1137,7 @@ func setupInjectedToolTest(t *testing.T, fixture []byte, streaming bool, configu // We must ALWAYS have 2 calls to the bridge for injected tool tests. 
require.Eventually(t, func() bool { - return mockSrv.callCount.Load() == 2 + return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) return recorderClient, acc, mcpProxiers, resp @@ -1379,28 +1204,10 @@ func TestErrorHandling(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 3) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - reqBody := files[fixtureRequest] - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", streaming) - require.NoError(t, err) - reqBody = newBody - - // Setup mock server. - mockResp := files[fixtureStreamingResponse] - if !streaming { - mockResp = files[fixtureNonStreamingResponse] - } - mockSrv := newMockHTTPReflector(ctx, t, mockResp) - t.Cleanup(mockSrv.Close) + // Setup mock server. Error fixtures contain raw HTTP + // responses that may cause the bridge to retry. + fix := fixtures.Parse(t, tc.fixture) + mockSrv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) recorderClient := &testutil.MockRecorder{} @@ -1415,6 +1222,10 @@ func TestErrorHandling(t *testing.T) { bridgeSrv.Start() t.Cleanup(bridgeSrv.Close) + // Add the stream param to the request. 
+ reqBody, err := sjson.SetBytes(fix.Request(), "stream", streaming) + require.NoError(t, err) + req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) @@ -1489,24 +1300,14 @@ func TestErrorHandling(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Len(t, files, 2) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - - reqBody := files[fixtureRequest] - // Setup mock server. - mockSrv := newMockServer(ctx, t, files, nil, nil) - mockSrv.statusCode = http.StatusInternalServerError - t.Cleanup(mockSrv.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream.StatusCode = http.StatusInternalServerError recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(mockSrv.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) + b, err := tc.configureFunc(upstream.URL, recorderClient, mcp.NewServerProxyManager(nil, testTracer)) require.NoError(t, err) // Invoke request to mocked API via aibridge. 
@@ -1517,7 +1318,7 @@ func TestErrorHandling(t *testing.T) { bridgeSrv.Start() t.Cleanup(bridgeSrv.Close) - req := tc.createRequestFunc(t, bridgeSrv.URL, reqBody) + req := tc.createRequestFunc(t, bridgeSrv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) t.Cleanup(func() { _ = resp.Body.Close() }) require.NoError(t, err) @@ -1579,49 +1380,18 @@ func TestStableRequestEncoding(t *testing.T) { mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) require.NoError(t, mcpMgr.Init(ctx)) - arc := txtar.Parse(tc.fixture) - t.Logf("%s: %s", t.Name(), arc.Comment) - - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureNonStreamingResponse) - - var ( - reference []byte - reqCount atomic.Int32 - ) - - // Create a mock server that captures and compares request bodies. - mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - reqCount.Add(1) - - // Capture the raw request body. - raw, err := io.ReadAll(r.Body) - defer r.Body.Close() - require.NoError(t, err) - require.NotEmpty(t, raw) - - // Store the first instance as the reference value. - if reference == nil { - reference = raw - } else { - // Compare all subsequent requests to the reference. - assert.JSONEq(t, string(reference), string(raw)) - } + fix := fixtures.Parse(t, tc.fixture) - // Return a valid API response. - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureNonStreamingResponse]) - })) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx + // Create a mock upstream that serves the same blocking response for each request. + count := 10 + responses := make([]testutil.UpstreamResponse, count) + for i := range count { + responses[i] = testutil.NewFixtureResponse(fix) } - mockSrv.Start() - t.Cleanup(mockSrv.Close) + upstream := testutil.NewMockUpstream(t, ctx, responses...) 
recorder := &testutil.MockRecorder{} - bridge, err := tc.configureFunc(mockSrv.URL, recorder, mcpMgr) + bridge, err := tc.configureFunc(upstream.URL, recorder, mcpMgr) require.NoError(t, err) // Invoke request to mocked API via aibridge. @@ -1633,9 +1403,8 @@ func TestStableRequestEncoding(t *testing.T) { t.Cleanup(bridgeSrv.Close) // Make multiple requests and verify they all have identical payloads. - count := 10 for range count { - req := tc.createRequestFunc(t, bridgeSrv.URL, files[fixtureRequest]) + req := tc.createRequestFunc(t, bridgeSrv.URL, fix.Request()) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1643,7 +1412,13 @@ func TestStableRequestEncoding(t *testing.T) { _ = resp.Body.Close() } - require.EqualValues(t, count, reqCount.Load()) + // All upstream request bodies should be identical. + received := upstream.ReceivedRequests() + require.Len(t, received, count) + reference := string(received[0].Body) + for _, r := range received[1:] { + assert.JSONEq(t, reference, string(r.Body)) + } }) } } @@ -1738,45 +1513,12 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { } require.NoError(t, mcpMgr.Init(ctx)) - arc := txtar.Parse(fixtures.AntSimple) - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureNonStreamingResponse) - - // Prepare request body with tool_choice set. - var reqJSON map[string]any - require.NoError(t, json.Unmarshal(files[fixtureRequest], &reqJSON)) - if tc.toolChoice != nil { - reqJSON["tool_choice"] = tc.toolChoice - } - reqBody, err := json.Marshal(reqJSON) - require.NoError(t, err) - - var receivedRequest map[string]any - - // Create a mock server that captures the request body sent upstream. - mockSrv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Capture the raw request body. 
- raw, err := io.ReadAll(r.Body) - defer r.Body.Close() - require.NoError(t, err) - - require.NoError(t, json.Unmarshal(raw, &receivedRequest)) - - // Return a valid API response. - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureNonStreamingResponse]) - })) - mockSrv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - mockSrv.Start() - t.Cleanup(mockSrv.Close) + fix := fixtures.Parse(t, fixtures.AntSimple) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) recorder := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(mockSrv.URL, apiKey), nil)} + providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} bridge, err := aibridge.NewRequestBridge(ctx, providers, recorder, mcpMgr, logger, nil, testTracer) require.NoError(t, err) @@ -1788,6 +1530,10 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { bridgeSrv.Start() t.Cleanup(bridgeSrv.Close) + // Prepare request body with tool_choice set. + reqBody, err := sjson.SetBytes(fix.Request(), "tool_choice", tc.toolChoice) + require.NoError(t, err) + req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -1796,7 +1542,10 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { _ = resp.Body.Close() // Verify tool_choice in the upstream request. 
- require.NotNil(t, receivedRequest) + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + var receivedRequest map[string]any + require.NoError(t, json.Unmarshal(received[0].Body, &receivedRequest)) toolChoice, ok := receivedRequest["tool_choice"].(map[string]any) require.True(t, ok, "expected tool_choice in upstream request") @@ -1825,39 +1574,21 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) { func TestThinkingAdaptiveIsPreserved(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.AntSimple) - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) + fix := fixtures.Parse(t, fixtures.AntSimple) for _, streaming := range []bool{true, false} { t.Run(fmt.Sprintf("streaming=%v", streaming), func(t *testing.T) { t.Parallel() - // Inject adaptive thinking into the fixture request. - reqBody, err := sjson.SetBytes(files[fixtureRequest], "thinking", map[string]string{"type": "adaptive"}) - require.NoError(t, err) - reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) - require.NoError(t, err) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - var receivedRequest []byte - // Create a mock server that captures the request body sent upstream. 
- srv := newMockServer(ctx, t, files, func(r *http.Request) { - raw, err := io.ReadAll(r.Body) - require.NoError(t, err) - r.Body = io.NopCloser(bytes.NewReader(raw)) - receivedRequest = raw - }, nil) - t.Cleanup(srv.Close) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) recorderClient := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug) - providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(srv.URL, apiKey), nil)} + providers := []aibridge.Provider{provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil)} bridge, err := aibridge.NewRequestBridge(ctx, providers, recorderClient, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err) @@ -1868,6 +1599,12 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { bridgeSrv.Start() t.Cleanup(bridgeSrv.Close) + // Inject adaptive thinking into the fixture request. + reqBody, err := sjson.SetBytes(fix.Request(), "thinking", map[string]string{"type": "adaptive"}) + require.NoError(t, err) + reqBody, err = sjson.SetBytes(reqBody, "stream", streaming) + require.NoError(t, err) + req := createAnthropicMessagesReq(t, bridgeSrv.URL, reqBody) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) @@ -1876,8 +1613,9 @@ func TestThinkingAdaptiveIsPreserved(t *testing.T) { _ = resp.Body.Close() // Verify the thinking field was preserved in the upstream request. - require.NotEmpty(t, receivedRequest) - assert.Equal(t, "adaptive", gjson.GetBytes(receivedRequest, "thinking.type").Str) + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + assert.Equal(t, "adaptive", gjson.GetBytes(received[0].Body, "thinking.type").Str) }) } } @@ -1929,26 +1667,11 @@ func TestEnvironmentDoNotLeak(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // NOTE: Cannot use t.Parallel() here because t.Setenv requires sequential execution. 
- arc := txtar.Parse(tc.fixture) - files := filesMap(arc) - reqBody := files[fixtureRequest] - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Track headers received by the upstream server. - var receivedHeaders http.Header - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - receivedHeaders = r.Header.Clone() - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureNonStreamingResponse]) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() - t.Cleanup(srv.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) // Set environment variables that the SDK would automatically read. // These should NOT leak into upstream requests. @@ -1957,7 +1680,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { } recorderClient := &testutil.MockRecorder{} - b, err := tc.configureFunc(srv.URL, recorderClient) + b, err := tc.configureFunc(upstream.URL, recorderClient) require.NoError(t, err) mockSrv := httptest.NewUnstartedServer(b) @@ -1967,7 +1690,7 @@ func TestEnvironmentDoNotLeak(t *testing.T) { } mockSrv.Start() - req := tc.createRequest(t, mockSrv.URL, reqBody) + req := tc.createRequest(t, mockSrv.URL, fix.Request()) client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) @@ -1975,8 +1698,9 @@ func TestEnvironmentDoNotLeak(t *testing.T) { defer resp.Body.Close() // Verify that environment values did not leak. 
- require.NotNil(t, receivedHeaders) - require.Empty(t, receivedHeaders.Get(tc.headerName)) + received := upstream.ReceivedRequests() + require.Len(t, received, 1) + require.Empty(t, received[0].Header.Get(tc.headerName)) }) } } @@ -2066,17 +1790,6 @@ func TestActorHeaders(t *testing.T) { t.Run(fmt.Sprintf("%s/streaming=%v/send-headers=%v", tc.name, tc.streaming, send), func(t *testing.T) { t.Parallel() - logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) - - arc := txtar.Parse(tc.fixture) - files := filesMap(arc) - reqBody := files[fixtureRequest] - - // Add the stream param to the request. - newBody, err := setJSON(reqBody, "stream", tc.streaming) - require.NoError(t, err) - reqBody = newBody - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) @@ -2094,6 +1807,7 @@ func TestActorHeaders(t *testing.T) { rec := &testutil.MockRecorder{} provider := tc.createProviderFn(srv.URL, apiKey, send) + logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) b, err := aibridge.NewRequestBridge(t.Context(), []aibridge.Provider{provider}, rec, mcp.NewServerProxyManager(nil, testTracer), logger, nil, testTracer) require.NoError(t, err, "failed to create handler") @@ -2110,6 +1824,10 @@ func TestActorHeaders(t *testing.T) { } mockSrv.Start() + // Add the stream param to the request. 
+ reqBody, err := sjson.SetBytes(fixtures.Request(t, tc.fixture), "stream", tc.streaming) + require.NoError(t, err) + req := tc.createRequest(t, mockSrv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -2153,20 +1871,6 @@ func calculateTotalOutputTokens(in []*recorder.TokenUsageRecord) int64 { return total } -type archiveFileMap map[string][]byte - -func filesMap(archive *txtar.Archive) archiveFileMap { - if len(archive.Files) == 0 { - return nil - } - - out := make(archiveFileMap, len(archive.Files)) - for _, f := range archive.Files { - out[f.Name] = f.Data - } - return out -} - func createAnthropicMessagesReq(t *testing.T, baseURL string, input []byte) *http.Request { t.Helper() @@ -2187,248 +1891,6 @@ func createOpenAIChatCompletionsReq(t *testing.T, baseURL string, input []byte) return req } -type mockHTTPReflector struct { - *httptest.Server -} - -func newMockHTTPReflector(ctx context.Context, t *testing.T, resp []byte) *mockHTTPReflector { - ref := &mockHTTPReflector{} - - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - mock, err := http.ReadResponse(bufio.NewReader(bytes.NewBuffer(resp)), r) - require.NoError(t, err) - defer mock.Body.Close() - - // Copy headers from the mocked response. - for key, values := range mock.Header { - for _, value := range values { - w.Header().Add(key, value) - } - } - - // Write the status code. - w.WriteHeader(mock.StatusCode) - - // Copy the body. - _, err = io.Copy(w, mock.Body) - require.NoError(t, err) - })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - - srv.Start() - t.Cleanup(srv.Close) - - ref.Server = srv - return ref -} - -// TODO: replace this with mockHTTPReflector. 
-type mockServer struct { - *httptest.Server - - callCount atomic.Uint32 - - statusCode int -} - -func newMockServer(ctx context.Context, t *testing.T, files archiveFileMap, requestValidatorFn func(*http.Request), responseMutatorFn func(reqCount uint32, resp []byte) []byte) *mockServer { - t.Helper() - - ms := &mockServer{} - srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if requestValidatorFn != nil { - requestValidatorFn(r) - } - - statusCode := http.StatusOK - if ms.statusCode != 0 { - statusCode = ms.statusCode - } - - ms.callCount.Add(1) - - body, err := io.ReadAll(r.Body) - defer r.Body.Close() - require.NoError(t, err) - - // Validate request body based on endpoint. - var validationErr error - if strings.Contains(r.URL.Path, "/chat/completions") { - validationErr = validateOpenAIChatCompletionRequest(body) - } else if strings.Contains(r.URL.Path, "/responses") { - validationErr = validateOpenAIResponsesRequest(body) - } else if strings.Contains(r.URL.Path, "/messages") { - validationErr = validateAnthropicMessagesRequest(body) - } - - // If validation failed, return error response - if validationErr != nil { - // Return HTTP error response - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusBadRequest) - errResp := map[string]any{ - "error": map[string]any{ - "message": fmt.Sprintf("Request #%d validation failed: %v", ms.callCount.Load(), validationErr), - "type": "invalid_request_error", - }, - } - json.NewEncoder(w).Encode(errResp) - - // Mark test as failed with detailed message - t.Errorf("Request #%d validation failed: %v\n\nRequest body:\n%s", - ms.callCount.Load(), validationErr, string(body)) - return - } - - type msg struct { - Stream bool `json:"stream"` - } - var reqMsg msg - require.NoError(t, json.Unmarshal(body, &reqMsg)) - - if !reqMsg.Stream && !strings.HasSuffix(r.URL.Path, "invoke-with-response-stream") { - resp := files[fixtureNonStreamingResponse] - if 
responseMutatorFn != nil { - resp = responseMutatorFn(ms.callCount.Load(), resp) - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(statusCode) - w.Write(resp) - return - } - - w.Header().Set("Content-Type", "text/event-stream") - w.Header().Set("Cache-Control", "no-cache") - w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") - - resp := files[fixtureStreamingResponse] - if responseMutatorFn != nil { - resp = responseMutatorFn(ms.callCount.Load(), resp) - } - - scanner := bufio.NewScanner(bytes.NewReader(resp)) - flusher, ok := w.(http.Flusher) - if !ok { - http.Error(w, "Streaming unsupported", http.StatusInternalServerError) - return - } - - for scanner.Scan() { - line := scanner.Text() - - fmt.Fprintf(w, "%s\n", line) - flusher.Flush() - } - - if err := scanner.Err(); err != nil { - http.Error(w, fmt.Sprintf("Error reading fixture: %v", err), http.StatusInternalServerError) - return - } - })) - - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - - srv.Start() - t.Cleanup(srv.Close) - - ms.Server = srv - return ms -} - -// validateOpenAIChatCompletionRequest validates that an OpenAI chat completion request -// has all required fields. -// According to OpenAI documentation https://platform.openai.com/docs/api-reference/chat/create, -// the "model" and "messages" fields are required. -// Returns an error if validation fails. 
-func validateOpenAIChatCompletionRequest(body []byte) error { - var req openai.ChatCompletionNewParams - if err := json.Unmarshal(body, &req); err != nil { - return fmt.Errorf("request should unmarshal into ChatCompletionNewParams: %w", err) - } - - // Collect all validation errors - var errs []string - if req.Model == "" { - errs = append(errs, "model field is required but empty") - } - if len(req.Messages) == 0 { - errs = append(errs, "messages field is required but empty") - } - - if len(errs) > 0 { - return fmt.Errorf("validation failed: %s", strings.Join(errs, "; ")) - } - return nil -} - -// validateOpenAIResponsesRequest validates that an OpenAI responses request -// has all required fields. -// According to OpenAI documentation https://platform.openai.com/docs/api-reference/responses/create, -// no fields are strictly required. However, we check "model" and "input" fields -// as these should usually be set in request bodies. -// Returns an error if validation fails. -func validateOpenAIResponsesRequest(body []byte) error { - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - return fmt.Errorf("request should unmarshal into valid JSON: %w", err) - } - - // Collect all validation errors - var errs []string - - // Validate model field exists - model, hasModel := reqBody["model"].(string) - if !hasModel || model == "" { - errs = append(errs, "model field is required but empty or missing") - } - - // Validate input field exists - if _, hasInput := reqBody["input"]; !hasInput { - errs = append(errs, "input field is required but missing") - } - - if len(errs) > 0 { - return fmt.Errorf("validation failed: %s", strings.Join(errs, "; ")) - } - return nil -} - -// validateAnthropicMessagesRequest validates that an Anthropic messages request -// has all required fields. 
-// According to the Anthropic Go SDK https://github.com/anthropics/anthropic-sdk-go, -// the "model", "messages", and "max_tokens" fields are required, as indicated by -// the `required` struct tags in MessageNewParams. -// Returns an error if validation fails. -func validateAnthropicMessagesRequest(body []byte) error { - var req anthropic.MessageNewParams - if err := json.Unmarshal(body, &req); err != nil { - return fmt.Errorf("request should unmarshal into MessageNewParams: %w", err) - } - - // Collect all validation errors - var errs []string - if req.Model == "" { - errs = append(errs, "model field is required but empty") - } - if len(req.Messages) == 0 { - errs = append(errs, "messages field is required but empty") - } - if req.MaxTokens == 0 { - errs = append(errs, "max_tokens field is required but zero") - } - - if len(errs) > 0 { - return fmt.Errorf("validation failed: %s", strings.Join(errs, "; ")) - } - return nil -} - const mockToolName = "coder_list_workspaces" // callAccumulator tracks all tool invocations by name and each instance's arguments. diff --git a/fixtures/anthropic/fallthrough.txtar b/fixtures/anthropic/fallthrough.txtar index 6d9801d..94e71c4 100644 --- a/fixtures/anthropic/fallthrough.txtar +++ b/fixtures/anthropic/fallthrough.txtar @@ -1,6 +1,6 @@ API endpoints not explicitly handled will fallthrough to upstream via reverse-proxy. --- response -- +-- non-streaming -- { "data": [ { diff --git a/fixtures/fixtures.go b/fixtures/fixtures.go index 2370b4e..3c15047 100644 --- a/fixtures/fixtures.go +++ b/fixtures/fixtures.go @@ -4,6 +4,7 @@ import ( _ "embed" "testing" + "github.com/stretchr/testify/require" "golang.org/x/tools/txtar" ) @@ -126,15 +127,85 @@ var ( OaiResponsesStreamingWrongResponseFormat []byte ) -func Request(t *testing.T, fixture []byte) []byte { +// Section name constants matching the file names used in txtar fixtures. 
+const ( + fileRequest = "request" + fileStreamingResponse = "streaming" + fileNonStreamingResponse = "non-streaming" + fileStreamingToolCall = "streaming/tool-call" + fileNonStreamingToolCall = "non-streaming/tool-call" + + // Exported aliases so callers can check [Fixture.Has] before calling a + // getter that would otherwise fail the test. + SectionStreaming = fileStreamingResponse + SectionNonStreaming = fileNonStreamingResponse + SectionStreamingToolCall = fileStreamingToolCall + SectionNonStreamToolCall = fileNonStreamingToolCall +) + +// Fixture holds the named sections of a parsed txtar test fixture. +type Fixture struct { + sections map[string][]byte + t *testing.T +} + +// Has reports whether the fixture contains the named section. +func (f Fixture) Has(name string) bool { + _, ok := f.sections[name] + return ok +} + +func (f Fixture) Request() []byte { + f.t.Helper() + v, ok := f.sections[fileRequest] + require.True(f.t, ok, "fixture archive missing %q section", fileRequest) + return v +} + +func (f Fixture) Streaming() []byte { + f.t.Helper() + v, ok := f.sections[fileStreamingResponse] + require.True(f.t, ok, "fixture archive missing %q section", fileStreamingResponse) + return v +} + +func (f Fixture) NonStreaming() []byte { + f.t.Helper() + v, ok := f.sections[fileNonStreamingResponse] + require.True(f.t, ok, "fixture archive missing %q section", fileNonStreamingResponse) + return v +} + +func (f Fixture) StreamingToolCall() []byte { + f.t.Helper() + v, ok := f.sections[fileStreamingToolCall] + require.True(f.t, ok, "fixture archive missing %q section", fileStreamingToolCall) + return v +} + +func (f Fixture) NonStreamingToolCall() []byte { + f.t.Helper() + v, ok := f.sections[fileNonStreamingToolCall] + require.True(f.t, ok, "fixture archive missing %q section", fileNonStreamingToolCall) + return v +} + +// Parse parses raw txtar data into a [Fixture]. 
+func Parse(t *testing.T, data []byte) Fixture { t.Helper() - archive := txtar.Parse(fixture) + archive := txtar.Parse(data) + require.NotEmpty(t, archive.Files, "fixture archive has no files") + + sections := make(map[string][]byte, len(archive.Files)) for _, f := range archive.Files { - if f.Name == "request" { - return f.Data - } + sections[f.Name] = f.Data } - t.Fatal("request not found in fixture") - return []byte{} + return Fixture{sections: sections, t: t} +} + +// Request extracts the "request" fixture from raw txtar data. +func Request(t *testing.T, fixture []byte) []byte { + t.Helper() + return Parse(t, fixture).Request() } diff --git a/fixtures/openai/chatcompletions/fallthrough.txtar b/fixtures/openai/chatcompletions/fallthrough.txtar index 09812cb..41bcf34 100644 --- a/fixtures/openai/chatcompletions/fallthrough.txtar +++ b/fixtures/openai/chatcompletions/fallthrough.txtar @@ -1,6 +1,6 @@ API endpoints not explicitly handled will fallthrough to upstream via reverse-proxy. --- response -- +-- non-streaming -- { "object": "list", "data": [ diff --git a/internal/testutil/upstream.go b/internal/testutil/upstream.go new file mode 100644 index 0000000..bb935a8 --- /dev/null +++ b/internal/testutil/upstream.go @@ -0,0 +1,309 @@ +package testutil + +import ( + "bufio" + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + + "github.com/anthropics/anthropic-sdk-go" + "github.com/coder/aibridge/fixtures" + "github.com/openai/openai-go/v3" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// UpstreamResponse defines a single response that MockUpstream will replay +// for one incoming request. Use [NewFixtureResponse] or [NewFixtureToolResponse] to +// construct one from a parsed txtar archive. +type UpstreamResponse struct { + Streaming []byte // returned when the request has "stream": true. 
+ Blocking []byte // returned for non-streaming requests. + + // OnRequest, if non-nil, is called with the incoming request and body + // before the response is sent. Use it for per-request assertions. + OnRequest func(r *http.Request, body []byte) +} + +// NewFixtureResponse creates an UpstreamResponse from a parsed fixture archive. +// It reads whichever of 'streaming' and 'non-streaming' sections exist; +// not every fixture has both (e.g. error fixtures may only define one). +func NewFixtureResponse(fix fixtures.Fixture) UpstreamResponse { + var resp UpstreamResponse + if fix.Has(fixtures.SectionStreaming) { + resp.Streaming = fix.Streaming() + } + if fix.Has(fixtures.SectionNonStreaming) { + resp.Blocking = fix.NonStreaming() + } + return resp +} + +// NewFixtureToolResponse creates an UpstreamResponse from the tool-call fixture files. +// It reads whichever of 'streaming/tool-call' and 'non-streaming/tool-call' +// sections exist. +func NewFixtureToolResponse(fix fixtures.Fixture) UpstreamResponse { + var resp UpstreamResponse + if fix.Has(fixtures.SectionStreamingToolCall) { + resp.Streaming = fix.StreamingToolCall() + } + if fix.Has(fixtures.SectionNonStreamToolCall) { + resp.Blocking = fix.NonStreamingToolCall() + } + return resp +} + +// ReceivedRequest captures the details of a single request handled by MockUpstream. +type ReceivedRequest struct { + Method string + Path string + Header http.Header + Body []byte +} + +// MockUpstream replays txtar fixture responses, validates incoming request +// bodies, and counts calls. It stands in for a real AI provider API +// (Anthropic, OpenAI) during integration tests. +type MockUpstream struct { + *httptest.Server + + // Calls is incremented atomically on every request. + Calls atomic.Uint32 + + // StatusCode overrides the HTTP status for non-streaming responses. + // Zero means 200. + StatusCode int + + // AllowOverflow disables the strict call-count check. 
When true, + // requests beyond the last response repeat that response, and the + // cleanup assertion only verifies that at least len(responses) + // requests were made. This is useful for error-response tests where + // the bridge may retry. + AllowOverflow bool + + mu sync.Mutex + requests []ReceivedRequest + + t *testing.T + responses []UpstreamResponse +} + +// ReceivedRequests returns a copy of all requests received so far. +func (ms *MockUpstream) ReceivedRequests() []ReceivedRequest { + ms.mu.Lock() + defer ms.mu.Unlock() + return append([]ReceivedRequest(nil), ms.requests...) +} + +// NewMockUpstream creates a started httptest.Server that replays fixture +// responses. Responses are returned in order: first call → first response. +// The test fails if the number of requests doesn't match the number of +// responses (when AllowOverflow is not set, default). +// +// srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) // simple +// srv := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) // multi-turn +func NewMockUpstream(t *testing.T, ctx context.Context, responses ...UpstreamResponse) *MockUpstream { + t.Helper() + require.NotEmpty(t, responses, "at least one UpstreamResponse required") + + ms := &MockUpstream{ + t: t, + responses: responses, + } + + srv := httptest.NewUnstartedServer(http.HandlerFunc(ms.handle)) + srv.Config.BaseContext = func(_ net.Listener) context.Context { return ctx } + srv.Start() + + t.Cleanup(func() { + srv.Close() + + // Verify the number of requests matches expectations. 
+ calls := int(ms.Calls.Load()) + if ms.AllowOverflow { + require.LessOrEqual(t, len(ms.responses), calls, "too few requests, got: %v, want at least: %v", calls, len(ms.responses)) + } else { + require.Equal(t, len(ms.responses), calls, "unexpected number of requests, got: %v, want: %v", calls, len(ms.responses)) + } + }) + + ms.Server = srv + return ms +} + +func (ms *MockUpstream) handle(w http.ResponseWriter, r *http.Request) { + call := int(ms.Calls.Add(1) - 1) + + body, err := io.ReadAll(r.Body) + defer r.Body.Close() + require.NoError(ms.t, err) + + ms.mu.Lock() + ms.requests = append(ms.requests, ReceivedRequest{ + Method: r.Method, + Path: r.URL.Path, + Header: r.Header.Clone(), + Body: append([]byte(nil), body...), + }) + ms.mu.Unlock() + + validateRequest(ms.t, call, r.URL.Path, body) + + resp := ms.responseForCall(call) + if resp.OnRequest != nil { + resp.OnRequest(r, body) + } + + if isStreaming(body, r.URL.Path) { + require.NotEmpty(ms.t, resp.Streaming, "response #%d: Streaming body is empty (fixture missing streaming response?)", call+1) + if isRawHTTPResponse(resp.Streaming) { + ms.writeRawHTTPResponse(w, r, resp.Streaming) + return + } + ms.writeSSE(w, resp.Streaming) + return + } + + require.NotEmpty(ms.t, resp.Blocking, "response #%d: Blocking body is empty (fixture missing non-streaming response?)", call+1) + if isRawHTTPResponse(resp.Blocking) { + ms.writeRawHTTPResponse(w, r, resp.Blocking) + return + } + + status := cmp.Or(ms.StatusCode, http.StatusOK) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _, _ = w.Write(resp.Blocking) +} + +func (ms *MockUpstream) responseForCall(call int) UpstreamResponse { + if call >= len(ms.responses) { + if ms.AllowOverflow { + return ms.responses[len(ms.responses)-1] + } + ms.t.Fatalf("unexpected number of calls: %v, got only %v responses", call, len(ms.responses)) + } + return ms.responses[call] +} + +func isStreaming(body []byte, urlPath string) bool { + // The Anthropic SDK's 
Bedrock middleware extracts "stream"
+	// from the JSON body and encodes it in the URL path instead.
+	// See: https://github.com/anthropics/anthropic-sdk-go/blob/4d669338f2041f3c60640b6dd317c4895dc71cd4/bedrock/bedrock.go#L247-L248
+	return gjson.GetBytes(body, "stream").Bool() || strings.HasSuffix(urlPath, "invoke-with-response-stream")
+}
+
+func (ms *MockUpstream) writeSSE(w http.ResponseWriter, data []byte) {
+	ms.t.Helper()
+
+	w.Header().Set("Content-Type", "text/event-stream")
+	w.Header().Set("Cache-Control", "no-cache")
+	w.Header().Set("Connection", "keep-alive")
+
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		http.Error(w, "streaming unsupported", http.StatusInternalServerError)
+		return
+	}
+
+	// Write line-by-line to simulate SSE events arriving incrementally
+	scanner := bufio.NewScanner(bytes.NewReader(data))
+	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
+	for scanner.Scan() {
+		_, err := fmt.Fprintf(w, "%s\n", scanner.Text())
+		require.NoError(ms.t, err)
+		flusher.Flush()
+	}
+	require.NoError(ms.t, scanner.Err())
+}
+
+// isRawHTTPResponse returns true if data starts with "HTTP/", indicating
+// it contains a complete HTTP response (status line + headers + body) rather
+// than just a response body.
+func isRawHTTPResponse(data []byte) bool {
+	return bytes.HasPrefix(data, []byte("HTTP/"))
+}
+
+// writeRawHTTPResponse parses data as a complete HTTP response and replays it,
+// copying the status code, headers, and body to w. This supports error fixtures
+// that contain full HTTP responses (e.g. "HTTP/2.0 400 Bad Request\r\n...").
+func (ms *MockUpstream) writeRawHTTPResponse(w http.ResponseWriter, r *http.Request, data []byte) { + ms.t.Helper() + + resp, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(data)), r) + require.NoError(ms.t, err) + defer resp.Body.Close() + + for key, values := range resp.Header { + for _, value := range values { + w.Header().Add(key, value) + } + } + w.WriteHeader(resp.StatusCode) + + _, err = io.Copy(w, resp.Body) + require.NoError(ms.t, err) +} + +// validateRequest dispatches to provider-specific validators based on URL path +// and fails the test immediately if the request body is invalid. +func validateRequest(t *testing.T, call int, path string, body []byte) { + t.Helper() + + msgAndArgs := []any{fmt.Sprintf("request #%d validation failed\n\nBody:\n%s", call+1, body)} + switch { + case strings.Contains(path, "/chat/completions"): + validateOpenAIChatCompletion(t, body, msgAndArgs...) + case strings.Contains(path, "/responses"): + validateOpenAIResponses(t, body, msgAndArgs...) + case strings.Contains(path, "/messages"): + validateAnthropicMessages(t, body, msgAndArgs...) + } +} + +// validateOpenAIChatCompletion validates that an OpenAI chat completion request +// has all required fields. +// See https://platform.openai.com/docs/api-reference/chat/create. +func validateOpenAIChatCompletion(t *testing.T, body []byte, msgAndArgs ...any) { + t.Helper() + + var req openai.ChatCompletionNewParams + require.NoError(t, json.Unmarshal(body, &req), msgAndArgs...) + require.NotEmpty(t, req.Model, "model is required", msgAndArgs) + require.NotEmpty(t, req.Messages, "messages is required", msgAndArgs) +} + +// validateOpenAIResponses validates that an OpenAI responses request +// has all required fields. +// See https://platform.openai.com/docs/api-reference/responses/create. +func validateOpenAIResponses(t *testing.T, body []byte, msgAndArgs ...any) { + t.Helper() + + var m map[string]any + require.NoError(t, json.Unmarshal(body, &m), msgAndArgs...) 
+ require.NotEmpty(t, m["model"], msgAndArgs...) + require.Contains(t, m, "input", msgAndArgs...) +} + +// validateAnthropicMessages validates that an Anthropic messages request +// has all required fields. +// See https://github.com/anthropics/anthropic-sdk-go. +func validateAnthropicMessages(t *testing.T, body []byte, msgAndArgs ...any) { + t.Helper() + + var req anthropic.MessageNewParams + require.NoError(t, json.Unmarshal(body, &req), msgAndArgs...) + require.NotEmpty(t, req.Model, msgAndArgs...) + require.NotEmpty(t, req.Messages, msgAndArgs...) + require.NotZero(t, req.MaxTokens, msgAndArgs...) +} diff --git a/metrics_integration_test.go index 290e927..d4c0586 100644 --- a/metrics_integration_test.go +++ b/metrics_integration_test.go @@ -23,7 +23,6 @@ import ( promtest "github.com/prometheus/client_golang/prometheus/testutil" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" - "golang.org/x/tools/txtar" ) func TestMetrics_Interception(t *testing.T) { @@ -37,6 +36,7 @@ func TestMetrics_Interception(t *testing.T) { expectModel string expectRoute string expectProvider string + allowOverflow bool // error fixtures may cause retries }{ { name: "ant_simple", @@ -55,6 +55,7 @@ func TestMetrics_Interception(t *testing.T) { expectModel: "claude-sonnet-4-0", expectRoute: "/v1/messages", expectProvider: config.ProviderAnthropic, + allowOverflow: true, }, { name: "oai_chat_simple", @@ -73,6 +74,7 @@ func TestMetrics_Interception(t *testing.T) { expectModel: "gpt-4.1", expectRoute: "/v1/chat/completions", expectProvider: config.ProviderOpenAI, + allowOverflow: true, }, { name: "oai_responses_blocking_simple", @@ -91,6 +93,7 @@ func TestMetrics_Interception(t *testing.T) { expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", expectProvider: config.ProviderOpenAI, + allowOverflow: true, }, { name: "oai_responses_streaming_simple", @@
-109,6 +112,7 @@ func TestMetrics_Interception(t *testing.T) { expectModel: "gpt-4o-mini", expectRoute: "/v1/responses", expectProvider: config.ProviderOpenAI, + allowOverflow: true, }, } @@ -116,25 +120,23 @@ func TestMetrics_Interception(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - arc := txtar.Parse(tc.fixture) - files := filesMap(arc) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + upstream.AllowOverflow = tc.allowOverflow metrics := aibridge.NewMetrics(prometheus.NewRegistry()) var prov aibridge.Provider if tc.expectProvider == config.ProviderAnthropic { - prov = provider.NewAnthropic(anthropicCfg(mockAPI.URL, apiKey), nil) + prov = provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) } else { - prov = provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) + prov = provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) } srv, _ := newTestSrv(t, ctx, prov, metrics, testTracer) - req := tc.reqFunc(t, srv.URL, files[fixtureRequest]) + req := tc.reqFunc(t, srv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -152,8 +154,7 @@ func TestMetrics_Interception(t *testing.T) { func TestMetrics_InterceptionsInflight(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.AntSimple) - files := filesMap(arc) + fix := fixtures.Parse(t, fixtures.AntSimple) ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) @@ -161,16 +162,9 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { blockCh := make(chan struct{}) // Setup a mock HTTP server which blocks until the request is marked as inflight then proceeds. 
- srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { <-blockCh - mock := newMockServer(ctx, t, files, nil, nil) - defer mock.Close() - mock.Server.Config.Handler.ServeHTTP(w, r) })) - srv.Config.BaseContext = func(_ net.Listener) context.Context { - return ctx - } - srv.Start() t.Cleanup(srv.Close) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) @@ -181,7 +175,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { doneCh := make(chan struct{}) go func() { defer close(doneCh) - req := createAnthropicMessagesReq(t, bridgeSrv.URL, files[fixtureRequest]) + req := createAnthropicMessagesReq(t, bridgeSrv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) if err == nil { defer resp.Body.Close() @@ -215,14 +209,7 @@ func TestMetrics_InterceptionsInflight(t *testing.T) { func TestMetrics_PassthroughCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.OaiChatFallthrough) - files := filesMap(arc) - - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureResponse]) - })) + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) t.Cleanup(upstream.Close) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) @@ -245,20 +232,17 @@ func TestMetrics_PassthroughCount(t *testing.T) { func TestMetrics_PromptCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.OaiChatSimple) - files := filesMap(arc) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) + fix := fixtures.Parse(t, fixtures.OaiChatSimple) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) metrics := 
aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) + provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) - req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) + req := createOpenAIChatCompletionsReq(t, srv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -273,20 +257,17 @@ func TestMetrics_PromptCount(t *testing.T) { func TestMetrics_NonInjectedToolUseCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.OaiChatSingleBuiltinTool) - files := filesMap(arc) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) + fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) + provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) srv, _ := newTestSrv(t, ctx, provider, metrics, testTracer) - req := createOpenAIChatCompletionsReq(t, srv.URL, files[fixtureRequest]) + req := createOpenAIChatCompletionsReq(t, srv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -301,25 +282,17 @@ func TestMetrics_NonInjectedToolUseCount(t *testing.T) { func TestMetrics_InjectedToolUseCount(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.AntSingleInjectedTool) - files := filesMap(arc) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) // First request returns the tool invocation, the second returns the mocked response to the tool result. 
- mockAPI := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte { - if reqCount == 1 { - return resp - } - return files[fixtureNonStreamingToolResponse] - }) - t.Cleanup(mockAPI.Close) + fix := fixtures.Parse(t, fixtures.AntSingleInjectedTool) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) recorder := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) metrics := aibridge.NewMetrics(prometheus.NewRegistry()) - provider := provider.NewAnthropic(anthropicCfg(mockAPI.URL, apiKey), nil) + provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), nil) // Setup mocked MCP server & tools. mcpProxiers, _ := setupMCPServerProxiesForTest(t, testTracer) @@ -336,7 +309,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - req := createAnthropicMessagesReq(t, srv.URL, files[fixtureRequest]) + req := createAnthropicMessagesReq(t, srv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) @@ -345,7 +318,7 @@ func TestMetrics_InjectedToolUseCount(t *testing.T) { // Wait until full roundtrip has completed. 
require.Eventually(t, func() bool { - return mockAPI.callCount.Load() == 2 + return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) require.Len(t, recorder.ToolUsages(), 1) diff --git a/responses_integration_test.go b/responses_integration_test.go index f4e7299..2521253 100644 --- a/responses_integration_test.go +++ b/responses_integration_test.go @@ -26,7 +26,6 @@ import ( "github.com/coder/aibridge/recorder" "github.com/openai/openai-go/v3/responses" "github.com/stretchr/testify/require" - "golang.org/x/tools/txtar" ) type keyVal struct { @@ -331,26 +330,18 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - files := filesMap(txtar.Parse(tc.fixture)) - require.Contains(t, files, fixtureRequest) - fixtResp := fixtureNonStreamingResponse - if tc.streaming { - fixtResp = fixtureStreamingResponse - } - require.Contains(t, files, fixtResp) - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) ctx = aibcontext.AsActor(ctx, userID, nil) - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) - provider := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) + provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) srv, mockRecorder := newTestSrv(t, ctx, provider, nil, testTracer) defer srv.Close() - req := createOpenAIResponsesReq(t, srv.URL, files[fixtureRequest]) + req := createOpenAIResponsesReq(t, srv.URL, fix.Request()) req.Header.Set("User-Agent", tc.userAgent) client := &http.Client{} @@ -359,10 +350,13 @@ func TestResponsesOutputMatchesUpstream(t *testing.T) { defer resp.Body.Close() require.Equal(t, http.StatusOK, resp.StatusCode) got, err := io.ReadAll(resp.Body) - srv.Close() require.NoError(t, err) - require.Equal(t, string(files[fixtResp]), string(got)) + if tc.streaming { + require.Equal(t, 
string(fix.Streaming()), string(got)) + } else { + require.Equal(t, string(fix.NonStreaming()), string(got)) + } interceptions := mockRecorder.RecordedInterceptions() require.Len(t, interceptions, 1) @@ -865,31 +859,13 @@ func TestResponsesInjectedTool(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - files := filesMap(txtar.Parse(tc.fixture)) - require.Contains(t, files, fixtureRequest) - if tc.streaming { - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureStreamingToolResponse) - } else { - require.Contains(t, files, fixtureNonStreamingResponse) - require.Contains(t, files, fixtureNonStreamingToolResponse) - } - ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - // Setup mock server with response mutator for multi-turn interaction. - mockAPI := newMockServer(ctx, t, files, nil, func(reqCount uint32, resp []byte) []byte { - if reqCount == 1 { - return resp // First request gets the normal response (with tool call). - } - // Second request gets the tool response. - if tc.streaming { - return files[fixtureStreamingToolResponse] - } - return files[fixtureNonStreamingToolResponse] - }) - t.Cleanup(mockAPI.Close) + // Setup mock server for multi-turn interaction. + // First request → tool call response, second → tool response. + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix), testutil.NewFixtureToolResponse(fix)) // Setup MCP server proxies (with mock tools). 
mcpProxiers, mcpCalls := setupMCPServerProxiesForTest(t, testTracer) @@ -899,7 +875,7 @@ func TestResponsesInjectedTool(t *testing.T) { mcpMgr := mcp.NewServerProxyManager(mcpProxiers, testTracer) require.NoError(t, mcpMgr.Init(ctx)) - prov := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) + prov := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) mockRecorder := &testutil.MockRecorder{} logger := slogtest.Make(t, &slogtest.Options{}).Leveled(slog.LevelDebug) @@ -913,7 +889,7 @@ func TestResponsesInjectedTool(t *testing.T) { srv.Start() t.Cleanup(srv.Close) - req := createOpenAIResponsesReq(t, srv.URL, files[fixtureRequest]) + req := createOpenAIResponsesReq(t, srv.URL, fix.Request()) resp, err := http.DefaultClient.Do(req) require.NoError(t, err) defer resp.Body.Close() @@ -924,7 +900,7 @@ func TestResponsesInjectedTool(t *testing.T) { // Wait for both requests to be made (inner agentic loop). require.Eventually(t, func() bool { - return mockAPI.callCount.Load() == 2 + return upstream.Calls.Load() == 2 }, time.Second*10, time.Millisecond*50) // Verify the injected tool was invoked via MCP. @@ -956,9 +932,9 @@ func TestResponsesInjectedTool(t *testing.T) { // Verify the response is the final tool response (after agentic loop). 
if tc.streaming { - require.Equal(t, string(files[fixtureStreamingToolResponse]), string(body)) + require.Equal(t, string(fix.StreamingToolCall()), string(body)) } else { - require.Equal(t, string(files[fixtureNonStreamingToolResponse]), string(body)) + require.Equal(t, string(fix.NonStreamingToolCall()), string(body)) } }) } diff --git a/trace_integration_test.go b/trace_integration_test.go index 3fd0fbe..608a7ac 100644 --- a/trace_integration_test.go +++ b/trace_integration_test.go @@ -15,17 +15,18 @@ import ( "github.com/coder/aibridge" "github.com/coder/aibridge/config" "github.com/coder/aibridge/fixtures" + "github.com/coder/aibridge/internal/testutil" "github.com/coder/aibridge/mcp" "github.com/coder/aibridge/provider" "github.com/coder/aibridge/tracing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + "github.com/tidwall/sjson" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" sdktrace "go.opentelemetry.io/otel/sdk/trace" "go.opentelemetry.io/otel/sdk/trace/tracetest" - "golang.org/x/tools/txtar" ) // expect 'count' amount of traces named 'name' with status 'status' @@ -88,14 +89,9 @@ func TestTraceAnthropic(t *testing.T) { }, } - arc := txtar.Parse(fixtures.AntSingleBuiltinTool) + fix := fixtures.Parse(t, fixtures.AntSingleBuiltinTool) - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - require.Contains(t, files, fixtureStreamingResponse) - require.Contains(t, files, fixtureNonStreamingResponse) - - fixtureReqBody := files[fixtureRequest] + fixtureReqBody := fix.Request() for _, tc := range cases { t.Run(fmt.Sprintf("%s/streaming=%v", t.Name(), tc.streaming), func(t *testing.T) { @@ -107,19 +103,17 @@ func TestTraceAnthropic(t *testing.T) { tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) - - mockAPI := newMockServer(ctx, t, files, nil, nil) - 
t.Cleanup(mockAPI.Close) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) var bedrockCfg *config.AWSBedrock if tc.bedrock { - bedrockCfg = testBedrockCfg(mockAPI.URL) + bedrockCfg = testBedrockCfg(upstream.URL) } - provider := provider.NewAnthropic(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + reqBody, err := sjson.SetBytes(fixtureReqBody, "stream", tc.streaming) + require.NoError(t, err) req := createAnthropicMessagesReq(t, srv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -179,27 +173,35 @@ func TestTraceAnthropicErr(t *testing.T) { } cases := []struct { - name string - streaming bool - bedrock bool - expect []expectTrace + name string + fixture []byte + streaming bool + bedrock bool + expectCode int // expected status code for non-streaming responses + expect []expectTrace }{ { - name: "anthr_non_streaming_err", - expect: expectNonStream, + name: "anthr_non_streaming_err", + fixture: fixtures.AntNonStreamError, + expectCode: http.StatusBadRequest, + expect: expectNonStream, }, { name: "anthr_streaming_err", + fixture: fixtures.AntMidStreamError, streaming: true, expect: expectStreaming, }, { - name: "bedrock_non_streaming_err", - bedrock: true, - expect: expectNonStream, + name: "bedrock_non_streaming_err", + fixture: fixtures.AntNonStreamError, + bedrock: true, + expectCode: http.StatusBadRequest, + expect: expectNonStream, }, { name: "bedrock_streaming_err", + fixture: fixtures.AntMidStreamError, streaming: true, bedrock: true, expect: expectStreaming, @@ -211,41 +213,23 @@ func TestTraceAnthropicErr(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - var arc *txtar.Archive - if tc.streaming { - arc = txtar.Parse(fixtures.AntMidStreamError) - } else { - arc = txtar.Parse(fixtures.AntNonStreamError) - } - - files := 
filesMap(arc) - require.Contains(t, files, fixtureRequest) - if tc.streaming { - require.Contains(t, files, fixtureStreamingResponse) - } else { - require.Contains(t, files, fixtureNonStreamingResponse) - } - - fixtureReqBody := files[fixtureRequest] - sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) - - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) var bedrockCfg *config.AWSBedrock if tc.bedrock { - bedrockCfg = testBedrockCfg(mockAPI.URL) + bedrockCfg = testBedrockCfg(upstream.URL) } - provider := provider.NewAnthropic(anthropicCfg(mockAPI.URL, apiKey), bedrockCfg) + provider := provider.NewAnthropic(anthropicCfg(upstream.URL, apiKey), bedrockCfg) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) req := createAnthropicMessagesReq(t, srv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -253,7 +237,7 @@ func TestTraceAnthropicErr(t *testing.T) { if tc.streaming { require.Equal(t, http.StatusOK, resp.StatusCode) } else { - require.Equal(t, http.StatusInternalServerError, resp.StatusCode) + require.Equal(t, tc.expectCode, resp.StatusCode) } defer resp.Body.Close() srv.Close() @@ -465,31 +449,18 @@ func TestTraceOpenAI(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - arc := txtar.Parse(tc.fixture) - - files := filesMap(arc) - require.Contains(t, files, fixtureRequest) - if tc.streaming { - require.Contains(t, files, fixtureStreamingResponse) - } else { - require.Contains(t, files, fixtureNonStreamingResponse) - } - - 
fixtureReqBody := files[fixtureRequest] - sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) - - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) - provider := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) + fix := fixtures.Parse(t, tc.fixture) + upstream := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + provider := provider.NewOpenAI(openaiCfg(upstream.URL, apiKey)) srv, recorder := newTestSrv(t, ctx, provider, nil, tracer) + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) req := tc.reqFunc(t, srv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -522,14 +493,14 @@ func TestTraceOpenAI(t *testing.T) { func TestTraceOpenAIErr(t *testing.T) { cases := []struct { - name string - fixture []byte - streaming bool - useMockReflector bool - expectPath string - reqFunc func(t *testing.T, baseURL string, input []byte) *http.Request - expect []expectTrace - expectCode int + name string + fixture []byte + streaming bool + allowOverflow bool + expectPath string + reqFunc func(t *testing.T, baseURL string, input []byte) *http.Request + expect []expectTrace + expectCode int }{ { name: "trace_openai_chat_streaming_error", @@ -554,7 +525,7 @@ func TestTraceOpenAIErr(t *testing.T) { streaming: false, expectPath: "/openai/v1/chat/completions", reqFunc: createOpenAIChatCompletionsReq, - expectCode: http.StatusInternalServerError, + expectCode: http.StatusBadRequest, expect: []expectTrace{ {"Intercept", 1, codes.Error}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -602,13 +573,14 @@ func TestTraceOpenAIErr(t *testing.T) { }, }, { - name: "trace_openai_responses_streaming_http_error", - fixture: fixtures.OaiResponsesStreamingHttpErr, - 
streaming: true, - useMockReflector: true, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, - expectCode: http.StatusTooManyRequests, + name: "trace_openai_responses_streaming_http_error", + fixture: fixtures.OaiResponsesStreamingHttpErr, + streaming: true, + allowOverflow: true, // 429 error causes retries + + expectPath: "/openai/v1/responses", + reqFunc: createOpenAIResponsesReq, + expectCode: http.StatusTooManyRequests, expect: []expectTrace{ {"Intercept", 1, codes.Error}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -619,13 +591,13 @@ func TestTraceOpenAIErr(t *testing.T) { }, }, { - name: "trace_openai_responses_blocking_http_error", - fixture: fixtures.OaiResponsesBlockingHttpErr, - streaming: false, - useMockReflector: true, - expectPath: "/openai/v1/responses", - reqFunc: createOpenAIResponsesReq, - expectCode: http.StatusUnauthorized, + name: "trace_openai_responses_blocking_http_error", + fixture: fixtures.OaiResponsesBlockingHttpErr, + streaming: false, + + expectPath: "/openai/v1/responses", + reqFunc: createOpenAIResponsesReq, + expectCode: http.StatusUnauthorized, expect: []expectTrace{ {"Intercept", 1, codes.Error}, {"Intercept.CreateInterceptor", 1, codes.Unset}, @@ -642,42 +614,20 @@ func TestTraceOpenAIErr(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Second*30) t.Cleanup(cancel) - files := filesMap(txtar.Parse(tc.fixture)) - require.Contains(t, files, fixtureRequest) - if tc.streaming { - require.Contains(t, files, fixtureStreamingResponse) - } else { - require.Contains(t, files, fixtureNonStreamingResponse) - } - - fixtureReqBody := files[fixtureRequest] - sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr)) tracer := tp.Tracer(t.Name()) defer func() { _ = tp.Shutdown(t.Context()) }() - reqBody, err := setJSON(fixtureReqBody, "stream", tc.streaming) - require.NoError(t, err) + fix := fixtures.Parse(t, tc.fixture) - var respBody []byte - if 
tc.streaming { - respBody = files[fixtureStreamingResponse] - } else { - respBody = files[fixtureNonStreamingResponse] - } - var prov *provider.OpenAI - if tc.useMockReflector { - mockAPI := newMockHTTPReflector(ctx, t, respBody) - t.Cleanup(mockAPI.Close) - prov = provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) - } else { - mockAPI := newMockServer(ctx, t, files, nil, nil) - t.Cleanup(mockAPI.Close) - prov = provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) - } + mockAPI := testutil.NewMockUpstream(t, ctx, testutil.NewFixtureResponse(fix)) + mockAPI.AllowOverflow = tc.allowOverflow + prov := provider.NewOpenAI(openaiCfg(mockAPI.URL, apiKey)) srv, recorder := newTestSrv(t, ctx, prov, nil, tracer) + reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) req := tc.reqFunc(t, srv.URL, reqBody) client := &http.Client{} resp, err := client.Do(req) @@ -766,15 +716,9 @@ func TestOpenAIInjectedToolsTrace(t *testing.T) { func TestTracePassthrough(t *testing.T) { t.Parallel() - arc := txtar.Parse(fixtures.OaiChatFallthrough) - files := filesMap(arc) + fix := fixtures.Parse(t, fixtures.OaiChatFallthrough) - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusOK) - _, _ = w.Write(files[fixtureResponse]) - })) - t.Cleanup(upstream.Close) + upstream := testutil.NewMockUpstream(t, t.Context(), testutil.NewFixtureResponse(fix)) sr := tracetest.NewSpanRecorder() tp := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(sr))