diff --git a/pydantic_ai_slim/pydantic_ai/models/openrouter.py b/pydantic_ai_slim/pydantic_ai/models/openrouter.py
index 3f910d76b6..07550b9a03 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openrouter.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openrouter.py
@@ -2,9 +2,9 @@
 
 from collections.abc import Iterable
 from dataclasses import dataclass, field
-from typing import Any, Literal, cast
+from typing import Annotated, Any, Literal, TypeAlias, cast
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Discriminator
 from typing_extensions import TypedDict, assert_never, override
 
 from ..exceptions import ModelHTTPError
@@ -22,9 +22,13 @@
 try:
     from openai import APIError, AsyncOpenAI
     from openai.types import chat, completion_usage
-    from openai.types.chat import chat_completion, chat_completion_chunk
+    from openai.types.chat import chat_completion, chat_completion_chunk, chat_completion_message_function_tool_call
 
-    from .openai import OpenAIChatModel, OpenAIChatModelSettings, OpenAIStreamedResponse
+    from .openai import (
+        OpenAIChatModel,
+        OpenAIChatModelSettings,
+        OpenAIStreamedResponse,
+    )
 except ImportError as _import_error:
     raise ImportError(
         'Please install `openai` to use the OpenRouter model, '
@@ -341,6 +345,27 @@ def _into_reasoning_detail(thinking_part: ThinkingPart) -> _OpenRouterReasoningD
             assert_never(data.type)
 
 
+class _OpenRouterFunction(chat_completion_message_function_tool_call.Function):
+    arguments: str | None  # type: ignore[reportIncompatibleVariableOverride]
+    """
+    The arguments to call the function with, as generated by the model in JSON
+    format. Note that the model does not always generate valid JSON, and may
+    hallucinate parameters not defined by your function schema. Validate the
+    arguments in your code before calling your function.
+    """
+
+
+class _OpenRouterChatCompletionMessageFunctionToolCall(chat.ChatCompletionMessageFunctionToolCall):
+    function: _OpenRouterFunction  # type: ignore[reportIncompatibleVariableOverride]
+    """The function that the model called."""
+
+
+_OpenRouterChatCompletionMessageToolCallUnion: TypeAlias = Annotated[
+    _OpenRouterChatCompletionMessageFunctionToolCall | chat.ChatCompletionMessageCustomToolCall,
+    Discriminator(discriminator='type'),
+]
+
+
 class _OpenRouterCompletionMessage(chat.ChatCompletionMessage):
     """Wrapped chat completion message with OpenRouter specific attributes."""
 
@@ -350,6 +375,9 @@ class _OpenRouterCompletionMessage(chat.ChatCompletionMessage):
     reasoning_details: list[_OpenRouterReasoningDetail] | None = None
     """The reasoning details associated with the message, if any."""
 
+    tool_calls: list[_OpenRouterChatCompletionMessageToolCallUnion] | None = None  # type: ignore[reportIncompatibleVariableOverride]
+    """The tool calls generated by the model, such as function calls."""
+
 
 class _OpenRouterChoice(chat_completion.Choice):
     """Wraps OpenAI chat completion choice with OpenRouter specific attributes."""
diff --git a/tests/models/cassettes/test_openrouter/test_openrouter_tool_optional_parameters.yaml b/tests/models/cassettes/test_openrouter/test_openrouter_tool_optional_parameters.yaml
new file mode 100644
index 0000000000..a826576f09
--- /dev/null
+++ b/tests/models/cassettes/test_openrouter/test_openrouter_tool_optional_parameters.yaml
@@ -0,0 +1,86 @@
+interactions:
+- request:
+    headers:
+      accept:
+      - application/json
+      accept-encoding:
+      - gzip, deflate
+      connection:
+      - keep-alive
+      content-length:
+      - '362'
+      content-type:
+      - application/json
+      host:
+      - openrouter.ai
+    method: POST
+    parsed_body:
+      messages:
+      - content: Can you find me any education content?
+        role: user
+      model: anthropic/claude-sonnet-4.5
+      stream: false
+      tool_choice: auto
+      tools:
+      - function:
+          description: ''
+          name: find_education_content
+          parameters:
+            properties:
+              title:
+                anyOf:
+                - type: string
+                - type: 'null'
+                default: null
+            type: object
+        type: function
+    uri: https://openrouter.ai/api/v1/chat/completions
+  response:
+    headers:
+      access-control-allow-origin:
+      - '*'
+      connection:
+      - keep-alive
+      content-length:
+      - '611'
+      content-type:
+      - application/json
+      permissions-policy:
+      - payment=(self "https://checkout.stripe.com" "https://connect-js.stripe.com" "https://js.stripe.com" "https://*.js.stripe.com"
+        "https://hooks.stripe.com")
+      referrer-policy:
+      - no-referrer, strict-origin-when-cross-origin
+      transfer-encoding:
+      - chunked
+      vary:
+      - Accept-Encoding
+    parsed_body:
+      choices:
+      - finish_reason: tool_calls
+        index: 0
+        logprobs: null
+        message:
+          content: I'll search for education content for you.
+          reasoning: null
+          refusal: null
+          role: assistant
+          tool_calls:
+          - function:
+              name: find_education_content
+            id: toolu_vrtx_015QAXScZzRDPttiPoc34AdD
+            index: 0
+            type: function
+        native_finish_reason: tool_calls
+      created: 1764308342
+      id: gen-1764308342-FInFdBZR9TF8jmnOwZGZ
+      model: anthropic/claude-sonnet-4.5
+      object: chat.completion
+      provider: Google
+      usage:
+        completion_tokens: 48
+        prompt_tokens: 568
+        total_tokens: 616
+    status:
+      code: 200
+      message: OK
+version: 1
diff --git a/tests/models/test_openrouter.py b/tests/models/test_openrouter.py
index d6243bdfe5..c52ec6e546 100644
--- a/tests/models/test_openrouter.py
+++ b/tests/models/test_openrouter.py
@@ -358,3 +358,51 @@ async def test_openrouter_map_messages_reasoning(allow_model_requests: None, ope
             }
         ]
     )
+
+
+async def test_openrouter_tool_optional_parameters(allow_model_requests: None, openrouter_api_key: str) -> None:
+    provider = OpenRouterProvider(api_key=openrouter_api_key)
+
+    class FindEducationContentFilters(BaseModel):
+        title: str | None = None
+
+    model = OpenRouterModel('anthropic/claude-sonnet-4.5', provider=provider)
+    response = await model_request(
+        model,
+        [ModelRequest.user_text_prompt('Can you find me any education content?')],
+        model_request_parameters=ModelRequestParameters(
+            function_tools=[
+                ToolDefinition(
+                    name='find_education_content',
+                    description='',
+                    parameters_json_schema=FindEducationContentFilters.model_json_schema(),
+                )
+            ],
+            allow_text_output=True,  # Allow model to either use tools or respond directly
+        ),
+    )
+
+    assert len(response.parts) == 2
+
+    tool_call_part = response.parts[1]
+    assert isinstance(tool_call_part, ToolCallPart)
+    assert tool_call_part.tool_call_id == snapshot('toolu_vrtx_015QAXScZzRDPttiPoc34AdD')
+    assert tool_call_part.tool_name == 'find_education_content'
+    assert tool_call_part.args == snapshot(None)
+
+    mapped_messages = await model._map_messages([response], None)  # type: ignore[reportPrivateUsage]
+    tool_call_message = mapped_messages[0]
+    assert tool_call_message['role'] == 'assistant'
+    assert tool_call_message.get('content') == snapshot("I'll search for education content for you.")
+    assert tool_call_message.get('tool_calls') == snapshot(
+        [
+            {
+                'id': 'toolu_vrtx_015QAXScZzRDPttiPoc34AdD',
+                'type': 'function',
+                'function': {
+                    'name': 'find_education_content',
+                    'arguments': '{}',
+                },
+            }
+        ]
+    )