Skip to content

Commit 732f3d2

Browse files
committed
fix: only disable parallel tool calls when injectable tools are present
Signed-off-by: Danny Kopping <danny@coder.com>
1 parent 250e790 commit 732f3d2

File tree

11 files changed

+134
-9
lines changed

11 files changed

+134
-9
lines changed

fixtures/openai/chatcompletions/single_builtin_tool.txtar

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ LLM (https://llm.datasette.io/) configured with a simple "read_file" tool.
99
}
1010
],
1111
"model": "gpt-4.1",
12+
"parallel_tool_calls": true,
1213
"tools": [
1314
{
1415
"type": "function",

intercept/chatcompletions/base.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ func (i *interceptionBase) injectTools() {
107107
return
108108
}
109109

110+
// Disable parallel tool calls when injectable tools are present to simplify the inner agentic loop.
111+
// Only set when there are tools in the request, otherwise it causes a 400 Bad Request.
112+
if i.HasInjectableTools() && len(i.req.Tools) > 0 {
113+
i.req.ParallelToolCalls = openai.Bool(false)
114+
}
115+
110116
// Inject tools.
111117
for _, tool := range i.mcpProxy.ListTools() {
112118
fn := openai.ChatCompletionToolUnionParam{
@@ -171,6 +177,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *err
171177
}
172178
}
173179

180+
// HasInjectableTools reports whether i.mcpProxy is non-nil and its ListTools
// call returns at least one tool. The nil check is evaluated first, so
// ListTools is not called when no proxy is set.
func (i *interceptionBase) HasInjectableTools() bool {
181+
return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0
182+
}
183+
174184
func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
175185
return openai.CompletionUsage{
176186
CompletionTokens: ref.CompletionTokens + in.CompletionTokens,

intercept/chatcompletions/streaming.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
9797
_ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes.
9898
}()
9999

100-
// TODO: implement parallel tool calls.
101-
// TODO: don't send if not supported by model (i.e. o4-mini).
102-
if len(i.req.Tools) > 0 { // If no tools are specified but this setting is set, it'll cause a 400 Bad Request.
103-
i.req.ParallelToolCalls = openai.Bool(false)
104-
}
105-
106100
// Force responses to only have one choice.
107101
// It's unnecessary to generate multiple responses, and would complicate our stream processing logic if
108102
// multiple choices were returned.

intercept/interceptor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,7 @@ type Interceptor interface {
3434
// by the model, so any single tool call ID is sufficient to identify the
3535
// parent interception.
3636
CorrelatingToolCallID() *string
37+
// HasInjectableTools returns true if an [mcp.ServerProxier] has been provided
38+
// and contains tools which must be injected into requests.
39+
HasInjectableTools() bool
3740
}

intercept/messages/base.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err
315315
}
316316
}
317317

318+
// HasInjectableTools reports whether i.mcpProxy is non-nil and its ListTools
// call returns at least one tool. The nil check is evaluated first, so
// ListTools is not called when no proxy is set.
func (i *interceptionBase) HasInjectableTools() bool {
319+
return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0
320+
}
321+
318322
// accumulateUsage accumulates usage statistics from source into dest.
319323
// It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any].
320324
// The function uses reflection to handle the differences between the types:

intercept/responses/base.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon
326326
}
327327
}
328328

329+
// HasInjectableTools reports whether i.mcpProxy is non-nil and its ListTools
// call returns at least one tool. The nil check is evaluated first, so
// ListTools is not called when no proxy is set.
func (i *responsesInterceptionBase) HasInjectableTools() bool {
330+
return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0
331+
}
332+
329333
// responseCopier helper struct to send original response to the client
330334
type responseCopier struct {
331335
buff deltaBuffer

intercept/responses/blocking.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
5959
}
6060

6161
i.injectTools()
62-
i.disableParallelToolCalls()
6362

6463
var (
6564
response *responses.Response

intercept/responses/injected_tools.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ func (i *responsesInterceptionBase) injectTools() {
2525
return
2626
}
2727

28+
// If there are injectable tools, disable parallel tool calls.
29+
i.disableParallelToolCalls()
30+
2831
// Inject tools.
2932
for _, tool := range tools {
3033
var params map[string]any

intercept/responses/streaming.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
7070
}
7171

7272
i.injectTools()
73-
i.disableParallelToolCalls()
7473

7574
events := eventstream.NewEventStream(ctx, i.logger.Named("sse-sender"), nil)
7675
go events.Start(w, r)

internal/integrationtest/bridge_test.go

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,83 @@ func TestAnthropicToolChoiceParallelDisabled(t *testing.T) {
13081308
}
13091309
}
13101310

1311+
// TestChatCompletionsParallelToolCallsDisabled verifies that parallel_tool_calls
1312+
// is set to false only when injectable MCP tools are present and the request
1313+
// includes tools.
//
// Each case sends the same fixture request through the bridge, in streaming or
// blocking mode and with or without an MCP proxy configured, then inspects the
// single request captured by the mock upstream.
//
// NOTE(review): every case uses the same fixture, so the "request has no tools"
// branch is not exercised here — confirm it is covered elsewhere.
1314+
func TestChatCompletionsParallelToolCallsDisabled(t *testing.T) {
1315+
t.Parallel()
1316+
1317+
cases := []struct {
1318+
name string
1319+
streaming bool
1320+
withInjectedTools bool
1321+
expectParallelToolCalls bool
1322+
}{
1323+
// Streaming with injected tools: parallel_tool_calls should be forced false.
1324+
{
1325+
name: "streaming/with_injected_tools",
1326+
streaming: true,
1327+
withInjectedTools: true,
1328+
expectParallelToolCalls: false,
1329+
},
1330+
// Streaming without injected tools: parallel_tool_calls preserved.
1331+
{
1332+
name: "streaming/no_injected_tools",
1333+
streaming: true,
1334+
withInjectedTools: false,
1335+
expectParallelToolCalls: true,
1336+
},
1337+
// Blocking with injected tools: parallel_tool_calls should be forced false.
1338+
{
1339+
name: "blocking/with_injected_tools",
1340+
streaming: false,
1341+
withInjectedTools: true,
1342+
expectParallelToolCalls: false,
1343+
},
1344+
// Blocking without injected tools: parallel_tool_calls preserved.
{
1345+
name: "blocking/no_injected_tools",
1346+
streaming: false,
1347+
withInjectedTools: false,
1348+
expectParallelToolCalls: true,
1349+
},
1350+
}
1351+
1352+
for _, tc := range cases {
1353+
t.Run(tc.name, func(t *testing.T) {
1354+
t.Parallel()
1355+
1356+
ctx, cancel := context.WithTimeout(t.Context(), time.Second*30)
1357+
t.Cleanup(cancel)
1358+
1359+
// The "preserved" cases expect true upstream, so the fixture request is
// presumably sent with "parallel_tool_calls": true — confirm against the fixture.
fix := fixtures.Parse(t, fixtures.OaiChatSingleBuiltinTool)
1360+
upstream := newMockUpstream(t, ctx, newFixtureResponse(fix))
1361+
1362+
// An MCP proxy is only configured for the cases that expect injected tools.
var opts []bridgeOption
1363+
if tc.withInjectedTools {
1364+
opts = append(opts, withMCP(setupMCPForTest(t, defaultTracer)))
1365+
}
1366+
bridgeServer := newBridgeTestServer(t, ctx, upstream.URL, opts...)
1367+
1368+
// Set the request's "stream" field per test case.
reqBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
1369+
require.NoError(t, err)
1370+
1371+
resp := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIChatCompletions, reqBody)
1372+
// The response body is read in full and its contents discarded.
// NOTE(review): resp.Body is not closed in this test — confirm makeRequest
// registers a cleanup that closes it.
_, err = io.ReadAll(resp.Body)
1373+
require.NoError(t, err)
1374+
1375+
// Exactly one request is expected to have reached the mock upstream.
received := upstream.receivedRequests()
1376+
require.Len(t, received, 1)
1377+
1378+
var upstreamReq map[string]any
1379+
require.NoError(t, json.Unmarshal(received[0].Body, &upstreamReq))
1380+
1381+
// ok is false when the key is absent or its value is not a JSON boolean.
ptc, ok := upstreamReq["parallel_tool_calls"].(bool)
1382+
require.True(t, ok, "parallel_tool_calls should be present in upstream request")
1383+
assert.Equal(t, tc.expectParallelToolCalls, ptc)
1384+
})
1385+
}
1386+
}
1387+
13111388
func TestThinkingAdaptiveIsPreserved(t *testing.T) {
13121389
t.Parallel()
13131390

0 commit comments

Comments
 (0)