Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions fixtures/anthropic/single_builtin_tool.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@ Claude Code has builtin tools to (e.g.) explore the filesystem.
{
"model": "claude-sonnet-4-20250514",
"max_tokens": 1024,
"tools": [
{
"name": "Read",
"description": "Read the contents of a file at the given path.",
"input_schema": {
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "The absolute path to the file to read"
}
},
"required": ["file_path"]
}
}
],
"messages": [
{
"role": "user",
Expand Down
9 changes: 8 additions & 1 deletion intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,13 @@ func (i *interceptionBase) newErrorResponse(err error) map[string]any {
}

func (i *interceptionBase) injectTools() {
if i.req == nil || i.mcpProxy == nil {
if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() {
return
}

// Disable parallel tool calls when injectable tools are present to simplify the inner agentic loop.
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was somehow missed previously: we weren't disabling parallel tool calls in blocking mode :thinking_face:

i.req.ParallelToolCalls = openai.Bool(false)

// Inject tools.
for _, tool := range i.mcpProxy.ListTools() {
fn := openai.ChatCompletionToolUnionParam{
Expand Down Expand Up @@ -171,6 +174,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *err
}
}

// hasInjectableTools reports whether an MCP proxy is configured and
// currently lists at least one tool.
func (i *interceptionBase) hasInjectableTools() bool {
	if i.mcpProxy == nil {
		// Without a proxy there is nothing to list.
		return false
	}
	return len(i.mcpProxy.ListTools()) != 0
}

func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
return openai.CompletionUsage{
CompletionTokens: ref.CompletionTokens + in.CompletionTokens,
Expand Down
6 changes: 0 additions & 6 deletions intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,6 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
_ = events.Shutdown(streamCtx) // Catch-all in case it doesn't get shutdown after stream completes.
}()

// TODO: implement parallel tool calls.
// TODO: don't send if not supported by model (i.e. o4-mini).
if len(i.req.Tools) > 0 { // If no tools are specified but this setting is set, it'll cause a 400 Bad Request.
i.req.ParallelToolCalls = openai.Bool(false)
}

// Force responses to only have one choice.
// It's unnecessary to generate multiple responses, and would complicate our stream processing logic if
// multiple choices were returned.
Expand Down
17 changes: 10 additions & 7 deletions intercept/messages/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,19 +101,15 @@ func (s *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
}

func (i *interceptionBase) injectTools() {
if i.req == nil || i.mcpProxy == nil {
if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() {
return
}

tools := i.mcpProxy.ListTools()
if len(tools) == 0 {
// No injected tools: no need to influence parallel tool calling.
return
}
i.disableParallelToolCalls()

// Inject tools.
var injectedTools []anthropic.ToolUnionParam
for _, tool := range tools {
for _, tool := range i.mcpProxy.ListTools() {
injectedTools = append(injectedTools, anthropic.ToolUnionParam{
OfTool: &anthropic.ToolParam{
InputSchema: anthropic.ToolInputSchemaParam{
Expand All @@ -137,7 +133,9 @@ func (i *interceptionBase) injectTools() {
if err != nil {
i.logger.Warn(context.Background(), "failed to set inject tools in request payload", slog.Error(err))
}
}

func (i *interceptionBase) disableParallelToolCalls() {
// Note: Parallel tool calls are disabled to avoid tool_use/tool_result block mismatches.
// https://github.com/coder/aibridge/issues/2
toolChoiceType := i.req.ToolChoice.GetType()
Expand All @@ -163,6 +161,7 @@ func (i *interceptionBase) injectTools() {
case string(constant.ValueOf[constant.None]()):
// No-op; if tool_choice=none then tools are not used at all.
}
var err error
i.payload, err = sjson.SetBytes(i.payload, "tool_choice", i.req.ToolChoice)
if err != nil {
i.logger.Warn(context.Background(), "failed to set tool_choice in request payload", slog.Error(err))
Expand Down Expand Up @@ -315,6 +314,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *Err
}
}

// hasInjectableTools reports whether an MCP proxy is configured and
// currently lists at least one tool.
func (i *interceptionBase) hasInjectableTools() bool {
	if i.mcpProxy == nil {
		// Without a proxy there is nothing to list.
		return false
	}
	return len(i.mcpProxy.ListTools()) != 0
}

// accumulateUsage accumulates usage statistics from source into dest.
// It handles both [anthropic.Usage] and [anthropic.MessageDeltaUsage] types through [any].
// The function uses reflection to handle the differences between the types:
Expand Down
4 changes: 4 additions & 0 deletions intercept/responses/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,10 @@ func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, respon
}
}

// hasInjectableTools reports whether an MCP proxy is configured and
// currently lists at least one tool.
func (i *responsesInterceptionBase) hasInjectableTools() bool {
	if i.mcpProxy == nil {
		// Without a proxy there is nothing to list.
		return false
	}
	return len(i.mcpProxy.ListTools()) != 0
}

// responseCopier helper struct to send original response to the client
type responseCopier struct {
buff deltaBuffer
Expand Down
1 change: 0 additions & 1 deletion intercept/responses/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
}

i.injectTools()
i.disableParallelToolCalls()

var (
response *responses.Response
Expand Down
19 changes: 7 additions & 12 deletions intercept/responses/injected_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,14 @@ import (
)

func (i *responsesInterceptionBase) injectTools() {
if i.req == nil || i.mcpProxy == nil {
if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() {
return
}

tools := i.mcpProxy.ListTools()
if len(tools) == 0 {
return
}
i.disableParallelToolCalls()

// Inject tools.
for _, tool := range tools {
for _, tool := range i.mcpProxy.ListTools() {
var params map[string]any

if tool.Params != nil {
Expand Down Expand Up @@ -67,12 +64,10 @@ func (i *responsesInterceptionBase) injectTools() {
// TODO: implement parallel tool calls.
func (i *responsesInterceptionBase) disableParallelToolCalls() {
// Disable parallel tool calls to simplify inner agentic loop; best-effort.
if len(i.req.Tools) > 0 {
var err error
i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false)
if err != nil {
i.logger.Warn(context.Background(), "failed to disable parallel_tool_calls", slog.Error(err))
}
var err error
i.reqPayload, err = sjson.SetBytes(i.reqPayload, "parallel_tool_calls", false)
if err != nil {
i.logger.Warn(context.Background(), "failed to disable parallel_tool_calls", slog.Error(err))
}
}

Expand Down
1 change: 0 additions & 1 deletion intercept/responses/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
}

i.injectTools()
i.disableParallelToolCalls()

events := eventstream.NewEventStream(ctx, i.logger.Named("sse-sender"), nil)
go events.Start(w, r)
Expand Down
Loading
Loading