diff --git a/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts b/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts index 752a0c4d..86005437 100644 --- a/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts +++ b/packages/workers-ai-provider/src/convert-to-workersai-chat-messages.ts @@ -1,7 +1,10 @@ import { type LanguageModelV1Prompt, UnsupportedFunctionalityError } from "@ai-sdk/provider"; import type { WorkersAIChatPrompt } from "./workersai-chat-prompt"; -export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): WorkersAIChatPrompt { +export function convertToWorkersAIChatMessages( + prompt: LanguageModelV1Prompt, + excludeContentWithToolCalls = false +): WorkersAIChatPrompt { const messages: WorkersAIChatPrompt = []; for (const { role, content } of prompt) { @@ -71,7 +74,7 @@ export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): W messages.push({ role: "assistant", - content: text, + content: excludeContentWithToolCalls && toolCalls.length > 0 ? undefined : text, // fix for mistral tool_calls: toolCalls.length > 0 ? toolCalls.map(({ function: { name, arguments: args } }) => ({ @@ -88,6 +91,7 @@ export function convertToWorkersAIChatMessages(prompt: LanguageModelV1Prompt): W for (const toolResponse of content) { messages.push({ role: "tool", + tool_call_id: toolResponse.toolCallId, // required by mistral name: toolResponse.toolName, content: JSON.stringify(toolResponse.result), }); diff --git a/packages/workers-ai-provider/src/workersai-chat-language-model.ts b/packages/workers-ai-provider/src/workersai-chat-language-model.ts index 530c1dfc..583821ad 100644 --- a/packages/workers-ai-provider/src/workersai-chat-language-model.ts +++ b/packages/workers-ai-provider/src/workersai-chat-language-model.ts @@ -83,7 +83,7 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 { random_seed: seed, // messages: - messages: convertToWorkersAIChatMessages(prompt), + messages: convertToWorkersAIChatMessages(prompt, this.modelId.includes('mistral')), }; switch (type) { @@ -139,7 +139,7 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 { ): Promise>> { const { args, warnings } = this.getArgs(options); - const { gateway, safePrompt, ...passthroughOptions } = this.settings; + const { gateway, safePrompt, sequentialCalls, ...passthroughOptions } = this.settings; const output = await this.config.binding.run( args.model, @@ -181,11 +181,14 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 { options: Parameters[0], ): Promise>> { const { args, warnings } = this.getArgs(options); + const { gateway, safePrompt, sequentialCalls, ...passthroughOptions } = this.settings; // [1] When the latest message is not a tool response, we use the regular generate function // and simulate it as a streamed response in order to satisfy the AI SDK's interface for // doStream... - if (args.tools?.length && lastMessageWasUser(args.messages)) { + // To allow a model to chain together tool calls, `sequentialCalls` can be set to `true`. + // This will make all responses be simulated streams. + if (args.tools?.length && (!!sequentialCalls || lastMessageWasUser(args.messages))) { const response = await this.doGenerate(options); if (response instanceof ReadableStream) { @@ -223,7 +226,6 @@ export class WorkersAIChatLanguageModel implements LanguageModelV1 { } // [2] ...otherwise, we just proceed as normal and stream the response directly from the remote model. - const { gateway, ...passthroughOptions } = this.settings; const response = await this.config.binding.run( args.model, diff --git a/packages/workers-ai-provider/src/workersai-chat-settings.ts b/packages/workers-ai-provider/src/workersai-chat-settings.ts index 2f129737..c572d6e4 100644 --- a/packages/workers-ai-provider/src/workersai-chat-settings.ts +++ b/packages/workers-ai-provider/src/workersai-chat-settings.ts @@ -7,6 +7,12 @@ export type WorkersAIChatSettings = { */ safePrompt?: boolean; + /** + * Whether to allow the model to execute calls when the last message was not a user message. Turns off streaming. + * Defaults to `false`. + */ + sequentialCalls?: boolean; + /** * Optionally set Cloudflare AI Gateway options. * @deprecated