diff --git a/examples/run-examples.py b/examples/run-examples.py index d470d59..cbfd510 100755 --- a/examples/run-examples.py +++ b/examples/run-examples.py @@ -43,6 +43,7 @@ class Sample: Sample("agent_custom_loop.py"), Sample("agent_nested.py"), Sample("streaming_tool.py"), + Sample("openai_chat_completions.py"), Sample("explicit_client.py"), Sample("multimodal_input.py"), Sample("check_connection.py"), diff --git a/examples/samples/openai_chat_completions.py b/examples/samples/openai_chat_completions.py new file mode 100644 index 0000000..770cffc --- /dev/null +++ b/examples/samples/openai_chat_completions.py @@ -0,0 +1,39 @@ +"""OpenAI Chat Completions protocol — stream text from GPT-5.5.""" + +import asyncio + +import ai +from ai.providers.openai import OpenAIChatCompletionsProtocol + +messages = [ + ai.system_message("Be concise."), + ai.user_message( + "Explain what the OpenAI Chat Completions API is in two sentences." + ), +] + + +async def main() -> None: + provider = ai.get_provider("openai") + if not provider.is_configured(): + print(f"[SKIP] {provider.name} provider is not configured") + return + + model = ai.Model("gpt-5.5", provider=provider) + + try: + async with ai.stream( + model, + messages, + protocol=OpenAIChatCompletionsProtocol(), + ) as stream: + async for event in stream: + if isinstance(event, ai.events.TextDelta): + print(event.chunk, end="", flush=True) + print() + finally: + await provider.aclose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 88ca7de..6449ff8 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -51,6 +51,7 @@ ImageParams, Model, Provider, + ProviderProtocol, Stream, VideoParams, generate, @@ -109,6 +110,7 @@ "UnsupportedProviderError", "Model", "Provider", + "ProviderProtocol", "ImageParams", "VideoParams", "Stream", diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index a0ff827..f2a2cb4 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -32,7 +32,7 @@ ids = await ai.get_provider("openai").list_models() """ -from ..providers.base import Provider +from ..providers.base import Provider, ProviderProtocol from .core.api import ( Executor, GenerateExecutor, @@ -56,6 +56,7 @@ "ImageParams", "Model", "Provider", + "ProviderProtocol", "Stream", "StreamExecutor", "StreamRequest", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 0e45932..c9129cc 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -1,8 +1,19 @@ +from __future__ import annotations + import contextlib import dataclasses from collections.abc import AsyncGenerator, AsyncIterator, Sequence from contextlib import AbstractAsyncContextManager -from typing import Any, Generic, Protocol, Self, cast, overload, runtime_checkable +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Protocol, + Self, + cast, + overload, + runtime_checkable, +) import pydantic @@ -15,6 +26,9 @@ from . import model as model_ from . import params as params_ +if TYPE_CHECKING: + from ...providers import base as provider_base + # Stream output type. Defaults to ``str``: when the stream was opened # without an ``output_type``, ``Stream.output`` returns the concatenated # message text. @@ -28,6 +42,7 @@ class StreamRequest: tools: Sequence[types.tools.Tool] | None = None output_type: type[pydantic.BaseModel] | None = None params: Any = None + protocol: provider_base.ProviderProtocol[Any] | None = None @dataclasses.dataclass(frozen=True) @@ -35,6 +50,7 @@ class GenerateRequest: model: model_.Model messages: list[types.messages.Message] params: params_.GenerateParams + protocol: provider_base.ProviderProtocol[Any] | None = None @runtime_checkable @@ -65,6 +81,7 @@ async def _do_stream( tools=request.tools, output_type=request.output_type, params=request.params, + protocol=request.protocol, ): yield ev @@ -73,6 +90,7 @@ async def _do_generate(self, request: GenerateRequest) -> types.messages.Message request.model, request.messages, request.params, + protocol=request.protocol, ) @@ -347,6 +365,7 @@ def stream( *, context: StreamContext, params: Any = None, + protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... @overload @@ -355,6 +374,7 @@ def stream[T: pydantic.BaseModel]( context: StreamContext, output_type: type[T], params: Any = None, + protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... @overload @@ -364,6 +384,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, params: Any = None, + protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... @overload @@ -374,6 +395,7 @@ def stream[T: pydantic.BaseModel]( tools: Sequence[types.tools.Tool] | None = None, output_type: type[T], params: Any = None, + protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... def stream( @@ -384,6 +406,7 @@ def stream( tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, + protocol: provider_base.ProviderProtocol[Any] | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[Any]]: """Stream an LLM response. @@ -420,6 +443,7 @@ def stream( tools=tools, output_type=output_type, params=params, + protocol=protocol, executor=executor, ) @@ -432,6 +456,7 @@ async def _stream( tools: Sequence[types.tools.Tool] | None, output_type: type[pydantic.BaseModel] | None, params: Any, + protocol: provider_base.ProviderProtocol[Any] | None, executor: StreamExecutor, ) -> AsyncIterator[Stream[Any]]: if messages and messages[-1].replay: @@ -443,13 +468,7 @@ async def _stream( ) else: prepared = integrity.prepare_messages(messages) - request = StreamRequest( - model, - prepared, - tools, - output_type, - params, - ) + request = StreamRequest(model, prepared, tools, output_type, params, protocol) s = Stream(executor._do_stream(request), output_type=output_type) try: yield s @@ -462,11 +481,12 @@ async def generate( messages: list[types.messages.Message], params: params_.GenerateParams, *, + protocol: provider_base.ProviderProtocol[Any] | None = None, executor: GenerateExecutor = _default_executor, ) -> types.messages.Message: """Generate a non-streaming response (images, video, etc.).""" messages = integrity.prepare_messages(messages) - request = GenerateRequest(model, messages, params) + request = GenerateRequest(model, messages, params, protocol) return await executor._do_generate(request) diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index fd81d96..d8902bb 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,6 +1,7 @@ """Model metadata types.""" import os +from typing import Any from ... import _modelsdev from ...errors import ConfigurationError @@ -13,8 +14,8 @@ class Model: """Lightweight reference to a model on a specific provider. * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). - * ``adapter`` — wire protocol key (e.g. ``"ai-gateway-v3"``, ``"anthropic"``). * ``provider`` — :class:`Provider` that owns this model. + * ``protocol`` — optional wire-protocol override for this model. """ def __init__( @@ -22,31 +23,32 @@ def __init__( id: str, *, provider: base.Provider, - adapter: str | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> None: self.id = id self.provider = provider - self.adapter = adapter or provider.adapter + self.protocol = protocol def __eq__(self, other: object) -> bool: return ( isinstance(other, Model) and self.id == other.id - and self.adapter == other.adapter and self.provider is other.provider + and self.protocol is other.protocol ) def __repr__(self) -> str: - return ( - f"Model(id={self.id!r}, adapter={self.adapter!r}, " - f"provider={self.provider!r})" - ) + return f"Model(id={self.id!r}, provider={self.provider!r})" def __hash__(self) -> int: - return hash((self.id, self.adapter, id(self.provider))) + return hash((self.id, id(self.provider), id(self.protocol))) -def get_model(model_id: str | None = None) -> Model: +def get_model( + model_id: str | None = None, + *, + protocol: base.ProviderProtocol[Any] | None = None, +) -> Model: """Resolve a model ID into a :class:`Model`. Args: @@ -56,6 +58,9 @@ def get_model(model_id: str | None = None) -> Model: Vercel AI Gateway. Examples: ``"openai:gpt-5"`` or ``"anthropic/claude-sonnet-4"``. When omitted, reads ``AI_SDK_DEFAULT_MODEL`` from the environment. + protocol: + Optional wire-protocol override for this model. When omitted, + the provider chooses its default protocol. Raises: Raises :class:`ai.ConfigurationError` when ``model_id`` and ``AI_SDK_DEFAULT_MODEL`` is empty or malformed. @@ -89,4 +94,4 @@ def get_model(model_id: str | None = None) -> Model: model_provider_config=model_provider_config, ) - return Model(provider_model_id, provider=provider) + return Model(provider_model_id, provider=provider, protocol=protocol) diff --git a/src/ai/providers/__init__.py b/src/ai/providers/__init__.py index 7ed6aba..5530e4d 100644 --- a/src/ai/providers/__init__.py +++ b/src/ai/providers/__init__.py @@ -2,7 +2,7 @@ from .ai_gateway import GatewayProvider from .anthropic import AnthropicCompatibleProvider -from .base import Provider, get_provider +from .base import Provider, ProviderProtocol, get_provider from .openai import OpenAICompatibleProvider __all__ = [ @@ -10,5 +10,6 @@ "GatewayProvider", "OpenAICompatibleProvider", "Provider", + "ProviderProtocol", "get_provider", ] diff --git a/src/ai/providers/ai_gateway/__init__.py b/src/ai/providers/ai_gateway/__init__.py index 3b7da49..3f9e71b 100644 --- a/src/ai/providers/ai_gateway/__init__.py +++ b/src/ai/providers/ai_gateway/__init__.py @@ -27,16 +27,15 @@ ) as s: ... -The heavy ``.protocol`` module is loaded lazily by provider methods so that -``import ai`` does not pull in extra I/O code at import time. This matters -for sandboxed runtimes (e.g. Temporal workflow workers). """ from . import errors, tools +from .protocol import GatewayV3Protocol from .provider import GatewayProvider __all__ = [ "GatewayProvider", + "GatewayV3Protocol", "errors", "tools", ] diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index b36f888..76d6546 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -13,6 +13,7 @@ from ... import types from ...models import core +from .. import base from ..anthropic import tools as anthropic_tools from ..openai import tools as openai_tools from . import client as gateway_client @@ -635,3 +636,40 @@ async def generate( return await _generate_image(gateway, model, messages, params) except client_errors.GatewayError as exc: raise errors.map_error(exc) from exc + + +class GatewayV3Protocol(base.ProviderProtocol[gateway_client.GatewayClient]): + """AI Gateway v3 wire protocol.""" + + def stream( + self, + client: gateway_client.GatewayClient, + model: core.model.Model, + messages: list[types.messages.Message], + *, + tools: Sequence[types.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, + ) -> AsyncGenerator[types.events.Event]: + _ = provider + return stream( + client, + model, + messages, + tools=tools, + output_type=output_type, + params=params, + ) + + async def generate( + self, + client: gateway_client.GatewayClient, + model: core.model.Model, + messages: list[types.messages.Message], + params: core.GenerateParams, + *, + provider: str, + ) -> types.messages.Message: + _ = provider + return await generate(client, model, messages, params) diff --git a/src/ai/providers/ai_gateway/provider.py b/src/ai/providers/ai_gateway/provider.py index b6bb6a3..131844e 100644 --- a/src/ai/providers/ai_gateway/provider.py +++ b/src/ai/providers/ai_gateway/provider.py @@ -14,6 +14,7 @@ from .. import base from . import client as gateway_client from . import errors +from . import protocol as protocol_module from .client import errors as client_errors if TYPE_CHECKING: @@ -43,11 +44,12 @@ def __init__( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: httpx.AsyncClient | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> None: super().__init__( name="ai-gateway", - adapter="ai-gateway-v3", base_url=base_url, + protocol=protocol or protocol_module.GatewayV3Protocol(), api_key=api_key, api_key_env=_API_KEY_ENV, headers=headers, @@ -82,17 +84,16 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream via the AI Gateway v3 protocol.""" - from . import protocol - - return protocol.stream( - self.client, + return super().stream( model, messages, tools=tools, output_type=output_type, params=params, + protocol=protocol, ) async def generate( @@ -100,11 +101,11 @@ async def generate( model: model_.Model, messages: list[messages_.Message], params: params_.GenerateParams, + *, + protocol: base.ProviderProtocol[Any] | None = None, ) -> messages_.Message: """Generate media via the AI Gateway v3 protocol.""" - from . import protocol - - return await protocol.generate(self.client, model, messages, params) + return await super().generate(model, messages, params, protocol=protocol) @classmethod def from_modelsdev_provider( @@ -117,6 +118,7 @@ def from_modelsdev_provider( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: httpx.AsyncClient | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> base.Provider[gateway_client.GatewayClient]: return cls( api_key=api_key, @@ -124,6 +126,7 @@ def from_modelsdev_provider( headers=headers, env=env, client=client, + protocol=protocol, ) @property diff --git a/src/ai/providers/anthropic/__init__.py b/src/ai/providers/anthropic/__init__.py index 73192c1..a183b55 100644 --- a/src/ai/providers/anthropic/__init__.py +++ b/src/ai/providers/anthropic/__init__.py @@ -22,6 +22,7 @@ """ from . import tools +from .protocol import AnthropicMessagesProtocol from .provider import AnthropicCompatibleProvider -__all__ = ["AnthropicCompatibleProvider", "tools"] +__all__ = ["AnthropicCompatibleProvider", "AnthropicMessagesProtocol", "tools"] diff --git a/src/ai/providers/anthropic/protocol.py b/src/ai/providers/anthropic/protocol.py index 9e83d63..045c8f1 100644 --- a/src/ai/providers/anthropic/protocol.py +++ b/src/ai/providers/anthropic/protocol.py @@ -16,6 +16,7 @@ from ... import types from ...models import core from ...types import events +from .. import base from . import _sdk, errors from . import tools as anthropic_tools @@ -583,3 +584,28 @@ async def stream( provider=provider, model_id=model.id, ) from exc + + +class AnthropicMessagesProtocol(base.ProviderProtocol[Any]): + """Anthropic Messages API protocol.""" + + def stream( + self, + client: anthropic.AsyncAnthropic, + model: core.model.Model, + messages: list[types.messages.Message], + *, + tools: Sequence[types.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, + ) -> AsyncGenerator[events.Event]: + return stream( + client, + model, + messages, + tools=tools, + output_type=output_type, + params=params, + provider=provider, + ) diff --git a/src/ai/providers/anthropic/provider.py b/src/ai/providers/anthropic/provider.py index c06a935..1d9623f 100644 --- a/src/ai/providers/anthropic/provider.py +++ b/src/ai/providers/anthropic/provider.py @@ -10,7 +10,8 @@ from ... import errors as ai_errors from .. import base -from . import _sdk, errors, protocol +from . import _sdk, errors +from . import protocol as protocol_module from . import tools as tools_module if TYPE_CHECKING: @@ -53,6 +54,7 @@ def __init__( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: AnthropicClient | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> None: anthropic_sdk = None if client is not None and not isinstance(client, httpx.AsyncClient): @@ -76,8 +78,8 @@ def __init__( super().__init__( name=name, - adapter="anthropic", base_url=default_base_url, + protocol=protocol or protocol_module.AnthropicMessagesProtocol(), api_key=api_key, api_key_env=api_key_env, base_url_env=base_url_env, @@ -132,16 +134,16 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream via the Anthropic messages protocol.""" - return protocol.stream( - self.sdk_client, + return super().stream( model, messages, tools=tools, output_type=output_type, params=params, - provider=self.name, + protocol=protocol, ) @classmethod @@ -155,6 +157,7 @@ def from_modelsdev_provider( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: AnthropicClient | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> base.Provider[AnthropicSDKClient]: resolved_base_url = base_url or base.provider_base_url( provider, @@ -177,6 +180,7 @@ def from_modelsdev_provider( headers=headers, env=env, client=client, + protocol=protocol, ) @property diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index 8f8d61d..a6cf250 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -24,12 +24,46 @@ ClientT = TypeVar("ClientT", default=Any) +class ProviderProtocol(Generic[ClientT]): + """Interface implemented by provider wire protocols.""" + + def stream( + self, + client: ClientT, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, + ) -> AsyncGenerator[events.Event]: + """Stream a language-model response using *client*.""" + raise NotImplementedError( + f"protocol {type(self).__name__!r} does not support stream()" + ) + + async def generate( + self, + client: ClientT, + model: model_.Model, + messages: list[messages_.Message], + params: params_.GenerateParams, + *, + provider: str, + ) -> messages_.Message: + """Generate a non-streaming response using *client*.""" + raise NotImplementedError( + f"protocol {type(self).__name__!r} does not support generate()" + ) + + class Provider(Generic[ClientT]): """Base class for model providers. A provider carries provider-specific configuration and a shared upstream client: API endpoint, authentication, and model enumeration. Model objects - hold metadata (``id``, ``adapter``) plus a back-reference to their provider. + hold metadata plus a back-reference to their provider. """ handles: ClassVar[tuple[str, ...]] = () @@ -46,8 +80,8 @@ def __init__( self, *, name: str, - adapter: str, base_url: str, + protocol: ProviderProtocol[ClientT] | None = None, api_key: str | None = None, api_key_env: str | None = None, base_url_env: str | None = None, @@ -59,8 +93,8 @@ def __init__( if type(self) is Provider: raise TypeError("Provider is a base class; implement a subclass instead") self._name = name - self._adapter = adapter self._base_url = base_url + self._protocol = protocol self._api_key = api_key self._api_key_env = api_key_env self._base_url_env = base_url_env @@ -141,11 +175,6 @@ async def aclose(self) -> None: """Close provider-owned resources, if any.""" return None - @property - def adapter(self) -> str: - """Provider protocol key used in model metadata and reprs.""" - return self._adapter - @property def config_envs(self) -> tuple[str, ...]: """Additional env vars used to configure the provider client.""" @@ -156,6 +185,13 @@ def name(self) -> str: """Human-readable provider name (for repr, error messages).""" return self._name + @property + def protocol(self) -> ProviderProtocol[ClientT]: + """Default wire protocol used by this provider.""" + if self._protocol is None: + raise RuntimeError(f"provider {self.name!r} does not have a protocol") + return self._protocol + async def list_models(self) -> list[str]: """List available model IDs from the provider API.""" raise NotImplementedError @@ -168,18 +204,37 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, + protocol: ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: """Stream a language-model response from this provider.""" - raise NotImplementedError(f"provider {self.name!r} does not support stream()") + selected_protocol = protocol or model.protocol or self.protocol + return selected_protocol.stream( + self.client, + model, + messages, + tools=tools, + output_type=output_type, + params=params, + provider=self.name, + ) async def generate( self, model: model_.Model, messages: list[messages_.Message], params: params_.GenerateParams, + *, + protocol: ProviderProtocol[Any] | None = None, ) -> messages_.Message: """Generate a non-streaming response from this provider.""" - raise NotImplementedError(f"provider {self.name!r} does not support generate()") + selected_protocol = protocol or model.protocol or self.protocol + return await selected_protocol.generate( + self.client, + model, + messages, + params, + provider=self.name, + ) async def probe(self, model: model_.Model) -> None: """Probe if provider is online and can serve given model. @@ -210,6 +265,7 @@ def from_id( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: Any | None = None, + protocol: ProviderProtocol[Any] | None = None, ) -> Provider[Any]: """Return a concrete provider for a models.dev provider ID.""" modelsdev_provider = _modelsdev.get_provider_by_id(known_id) @@ -230,6 +286,7 @@ def from_id( headers=headers, env=env, client=client, + protocol=protocol, ) raise UnsupportedProviderError(modelsdev_provider.id) @@ -245,6 +302,7 @@ def from_modelsdev_provider( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: Any | None = None, + protocol: ProviderProtocol[Any] | None = None, ) -> Provider[Any]: """Construct this provider implementation from models.dev metadata.""" raise NotImplementedError @@ -261,6 +319,7 @@ def get_provider( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: ClientT | None = None, + protocol: ProviderProtocol[ClientT] | None = None, ) -> Provider[ClientT]: """Create a provider from a models.dev provider ID.""" return Provider.from_id( @@ -270,6 +329,7 @@ def get_provider( headers=headers, env=env, client=client, + protocol=protocol, ) diff --git a/src/ai/providers/openai/__init__.py b/src/ai/providers/openai/__init__.py index 6c12119..b590d20 100644 --- a/src/ai/providers/openai/__init__.py +++ b/src/ai/providers/openai/__init__.py @@ -14,6 +14,12 @@ """ from . import tools +from .protocol import OpenAIChatCompletionsProtocol, OpenAIResponsesProtocol from .provider import OpenAICompatibleProvider -__all__ = ["OpenAICompatibleProvider", "tools"] +__all__ = [ + "OpenAIChatCompletionsProtocol", + "OpenAICompatibleProvider", + "OpenAIResponsesProtocol", + "tools", +] diff --git a/src/ai/providers/openai/protocol.py b/src/ai/providers/openai/protocol.py index e67a010..b18bde7 100644 --- a/src/ai/providers/openai/protocol.py +++ b/src/ai/providers/openai/protocol.py @@ -1,20 +1,24 @@ -"""OpenAI protocol — chat completions API. +"""OpenAI-compatible wire protocols. Message/tool conversion and streaming via the official ``openai`` SDK. -OpenAI-compatible providers own the SDK client used by this protocol. +OpenAI-compatible providers own the SDK client used by these protocols. """ from __future__ import annotations import base64 +import json from collections.abc import AsyncGenerator, Mapping, Sequence from typing import TYPE_CHECKING, Any import pydantic +from ... import errors as ai_errors from ... import types from ...models import core +from .. import base from . import _sdk, errors +from . import tools as openai_tools if TYPE_CHECKING: import openai @@ -376,3 +380,1012 @@ async def stream( provider=provider, model_id=model.id, ) from exc + + +class OpenAIChatCompletionsProtocol(base.ProviderProtocol[Any]): + """OpenAI Chat Completions protocol.""" + + def stream( + self, + client: openai.AsyncOpenAI, + model: core.model.Model, + messages: list[types.messages.Message], + *, + tools: Sequence[types.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, + ) -> AsyncGenerator[types.events.Event]: + return stream( + client, + model, + messages, + tools=tools, + output_type=output_type, + params=params, + provider=provider, + ) + + +_OPENAI_METADATA_KEY = "openai" +_RESPONSES_PROTECTED_PARAMS = frozenset({"model", "input", "stream"}) +_BUILTIN_OUTPUT_TYPES = frozenset( + { + "web_search_call", + "file_search_call", + "code_interpreter_call", + "image_generation_call", + "mcp_call", + "mcp_approval_request", + "local_shell_call", + "shell_call", + "shell_call_output", + "apply_patch_call", + "tool_search_call", + "tool_search_output", + "computer_call", + } +) + + +def _coerce_responses_params(value: Any) -> dict[str, Any]: + if value is None: + return {} + if isinstance(value, Mapping): + return dict(value) + raise TypeError("openai responses stream params must be a dict") + + +def _json_dumps(value: Any) -> str: + return json.dumps(value, separators=(",", ":"), default=str) + + +def _model_dump(value: pydantic.BaseModel) -> dict[str, Any]: + return value.model_dump(exclude_none=True) + + +def _openai_metadata(part: Any) -> dict[str, Any]: + metadata = getattr(part, "provider_metadata", None) + if not isinstance(metadata, Mapping): + return {} + openai_metadata = metadata.get(_OPENAI_METADATA_KEY) + if not isinstance(openai_metadata, Mapping): + return {} + return dict(openai_metadata) + + +def _metadata_item_id(metadata: Mapping[str, Any]) -> str | None: + value = metadata.get("item_id") or metadata.get("itemId") + return value if isinstance(value, str) else None + + +def _provider_metadata_for_item( + item: Mapping[str, Any], + **extra: Any, +) -> dict[str, Any]: + item_id = item.get("id") + data = { + "raw_item": dict(item), + **({"item_id": item_id} if isinstance(item_id, str) else {}), + **{k: v for k, v in extra.items() if v is not None}, + } + return {_OPENAI_METADATA_KEY: data} + + +def _provider_metadata_for_response(response: Mapping[str, Any]) -> dict[str, Any]: + response_id = response.get("id") + model = response.get("model") + status = response.get("status") + data = { + **({"response_id": response_id} if isinstance(response_id, str) else {}), + **({"model": model} if isinstance(model, str) else {}), + **({"status": status} if isinstance(status, str) else {}), + } + return {_OPENAI_METADATA_KEY: data} if data else {} + + +def _maybe_item_reference( + part: Any, + *, + use_item_references: bool, +) -> dict[str, Any] | None: + if not use_item_references: + return None + item_id = _metadata_item_id(_openai_metadata(part)) + if item_id is None: + return None + return {"type": "item_reference", "id": item_id} + + +def _raw_item_from_metadata(part: Any) -> dict[str, Any] | None: + raw_item = _openai_metadata(part).get("raw_item") + if isinstance(raw_item, Mapping): + return dict(raw_item) + return None + + +def _stringify_tool_result(result: Any) -> str: + if result is None: + return "" + if isinstance(result, str): + return result + return _json_dumps(result) + + +async def _file_part_to_responses( + part: types.messages.FilePart, +) -> dict[str, Any]: + media_type = "image/jpeg" if part.media_type == "image/*" else part.media_type + data = part.data + + if media_type.startswith("image/"): + return { + "type": "input_image", + "image_url": types.media.data_to_data_url(data, media_type), + } + + if media_type == "application/pdf": + if isinstance(data, str) and types.media.is_downloadable_url(data): + return {"type": "input_file", "file_url": data} + return { + "type": "input_file", + "filename": part.filename or "document.pdf", + "file_data": types.media.data_to_data_url(data, media_type), + } + + if media_type.startswith("text/"): + if isinstance(data, bytes): + text_content = data.decode("utf-8") + elif types.media.is_url(data): + text_content = data + else: + text_content = base64.b64decode(data).decode("utf-8") + return {"type": "input_text", "text": text_content} + + raise ValueError(f"Unsupported media type for OpenAI Responses: {media_type}") + + +async def _messages_to_responses( + messages: list[types.messages.Message], + *, + use_item_references: bool, +) -> list[dict[str, Any]]: + result: list[dict[str, Any]] = [] + + for msg in messages: + match msg.role: + case "system": + text = "".join( + p.text for p in msg.parts if isinstance(p, types.messages.TextPart) + ) + if text: + result.append({"role": "system", "content": text}) + + case "user": + content: list[dict[str, Any]] = [] + for part in msg.parts: + match part: + case types.messages.TextPart(text=text): + content.append({"type": "input_text", "text": text}) + case types.messages.FilePart(): + content.append(await _file_part_to_responses(part)) + result.append({"role": "user", "content": content}) + + case "assistant": + assistant_content: list[dict[str, Any]] = [] + + for part in msg.parts: + if item_reference := _maybe_item_reference( + part, + use_item_references=use_item_references, + ): + _flush_assistant_content(result, assistant_content) + result.append(item_reference) + continue + + if raw_item := _raw_item_from_metadata(part): + _flush_assistant_content(result, assistant_content) + result.append(raw_item) + continue + + match part: + case types.messages.TextPart(text=text): + assistant_content.append( + {"type": "output_text", "text": text} + ) + case types.messages.ReasoningPart(text=text): + _flush_assistant_content(result, assistant_content) + metadata = _openai_metadata(part) + encrypted_content = metadata.get( + "reasoning_encrypted_content" + ) or metadata.get("reasoningEncryptedContent") + if encrypted_content is not None: + result.append( + { + "type": "reasoning", + "summary": [ + {"type": "summary_text", "text": text} + ], + "encrypted_content": encrypted_content, + } + ) + case types.messages.ToolCallPart(): + _flush_assistant_content(result, assistant_content) + result.append( + { + "type": "function_call", + "call_id": part.tool_call_id, + "name": part.tool_name, + "arguments": part.tool_args, + } + ) + case ( + types.messages.BuiltinToolCallPart() + | types.messages.BuiltinToolReturnPart() + ): + _flush_assistant_content(result, assistant_content) + + _flush_assistant_content(result, assistant_content) + + case "tool": + for part in msg.parts: + if isinstance(part, types.messages.ToolResultPart): + result.append( + { + "type": "function_call_output", + "call_id": part.tool_call_id, + "output": _stringify_tool_result( + part.get_model_input() + ), + } + ) + + case "internal": + continue + + return result + + +def _flush_assistant_content( + result: list[dict[str, Any]], + assistant_content: list[dict[str, Any]], +) -> None: + if not assistant_content: + return + result.append({"role": "assistant", "content": list(assistant_content)}) + assistant_content.clear() + + +def _tools_to_responses( + tools: Sequence[types.tools.Tool], +) -> list[dict[str, Any]]: + result: list[dict[str, Any]] = [] + + for tool in tools: + if tool.kind == "function": + args = tool.args + if not isinstance(args, types.tools.FunctionToolArgs): + raise TypeError(f"function tool {tool.name!r} has invalid args") + result.append( + { + "type": "function", + "name": tool.name, + "description": args.description or "", + "parameters": args.params, + } + ) + continue + + args = tool.args + tool_id = getattr(type(args), "openai_id", None) + if not isinstance(args, openai_tools.OpenAIProviderArgs): + raise TypeError(f"provider tool {tool.name!r} is not an OpenAI tool") + + match tool_id: + case "openai.web_search": + result.append({"type": "web_search", **_model_dump(args)}) + case "openai.web_search_preview": + result.append({"type": "web_search_preview", **_model_dump(args)}) + case "openai.file_search": + data = _model_dump(args) + ranking = data.pop("ranking", None) + if ranking is not None: + data["ranking_options"] = ranking + result.append({"type": "file_search", **data}) + case "openai.code_interpreter": + data = _model_dump(args) + if "container" not in data: + data["container"] = {"type": "auto"} + result.append({"type": "code_interpreter", **data}) + case "openai.image_generation": + result.append({"type": "image_generation", **_model_dump(args)}) + case "openai.local_shell": + result.append({"type": "local_shell"}) + case "openai.shell": + result.append({"type": "shell", **_model_dump(args)}) + case "openai.apply_patch": + result.append({"type": "apply_patch"}) + case "openai.mcp": + result.append({"type": "mcp", **_model_dump(args)}) + case "openai.tool_search": + result.append({"type": "tool_search", **_model_dump(args)}) + case _: + raise NotImplementedError(f"unsupported OpenAI provider tool {tool_id}") + + return result + + +def _event_to_dict(event: Any) -> dict[str, Any]: + if isinstance(event, Mapping): + return dict(event) + if hasattr(event, "model_dump"): + dumped = event.model_dump(exclude_none=True, mode="json") + return dict(dumped) if isinstance(dumped, Mapping) else {} + if hasattr(event, "to_dict"): + dumped = event.to_dict() + return dict(dumped) if isinstance(dumped, Mapping) else {} + return { + key: value + for key in dir(event) + if not key.startswith("_") and not callable(value := getattr(event, key, None)) + } + + +def _usage_from_response(response: Mapping[str, Any]) -> types.usage.Usage | None: + usage = response.get("usage") + if not isinstance(usage, Mapping): + return None + + input_details = usage.get("input_tokens_details") + output_details = usage.get("output_tokens_details") + if not isinstance(input_details, Mapping): + input_details = {} + if not isinstance(output_details, Mapping): + output_details = {} + + return types.usage.Usage( + input_tokens=int(usage.get("input_tokens") or 0), + output_tokens=int(usage.get("output_tokens") or 0), + reasoning_tokens=output_details.get("reasoning_tokens"), + cache_read_tokens=input_details.get("cached_tokens"), + raw=dict(usage), + ) + + +def _image_media_type( + params: Mapping[str, Any], + tools: Sequence[types.tools.Tool], +) -> str: + for tool in tools: + if isinstance(tool.args, openai_tools.ImageGenerationArgs): + fmt = str(tool.args.output_format or "png") + return "image/jpeg" if fmt == "jpeg" else f"image/{fmt}" + text = params.get("text") + if isinstance(text, Mapping): + text_fmt = text.get("output_format") + if isinstance(text_fmt, str): + return "image/jpeg" if text_fmt == "jpeg" else f"image/{text_fmt}" + return "image/png" + + +def _state_key(item: Mapping[str, Any], data: Mapping[str, Any]) -> str: + value = item.get("id") or item.get("call_id") or data.get("item_id") + if isinstance(value, str): + return value + output_index = data.get("output_index") + return str(output_index) if output_index is not None else "" + + +def _builtin_tool_name(item: Mapping[str, Any]) -> str: + item_type = item.get("type") + match item_type: + case "web_search_call": + return "web_search" + case "file_search_call": + return "file_search" + case "code_interpreter_call": + return "code_interpreter" + case "image_generation_call": + return "image_generation" + case "mcp_call" | "mcp_approval_request": + name = item.get("name") + return f"mcp.{name}" if isinstance(name, str) else "mcp" + case "local_shell_call": + return "local_shell" + case "shell_call" | "shell_call_output": + return "shell" + case "apply_patch_call": + return "apply_patch" + case "tool_search_call" | "tool_search_output": + return "tool_search" + case "computer_call": + return "computer_use" + case _: + return str(item_type or "") + + +def _builtin_tool_call_id(item: Mapping[str, Any]) -> str: + value = item.get("call_id") or item.get("id") + return value if isinstance(value, str) else "" + + +def _builtin_tool_args(item: Mapping[str, Any]) -> str: + item_type = item.get("type") + match item_type: + case "code_interpreter_call": + return _json_dumps( + {"code": item.get("code"), "container_id": item.get("container_id")} + ) + case "mcp_call" | "mcp_approval_request": + arguments = item.get("arguments") + return arguments if isinstance(arguments, str) else _json_dumps(arguments) + case "local_shell_call" | "shell_call": + return _json_dumps({"action": item.get("action")}) + case "apply_patch_call": + return _json_dumps( + {"call_id": item.get("call_id"), "operation": item.get("operation")} + ) + case "tool_search_call": + return _json_dumps( + {"arguments": item.get("arguments"), "call_id": item.get("call_id")} + ) + case _: + return "{}" + + +def _builtin_tool_result(item: Mapping[str, Any]) -> Any: + item_type = item.get("type") + match item_type: + case "web_search_call": + return {"action": item.get("action")} + case "file_search_call": + return {"queries": item.get("queries"), "results": item.get("results")} + case "code_interpreter_call": + return { + "container_id": item.get("container_id"), + "outputs": item.get("outputs"), + } + case "image_generation_call": + return {"result": item.get("result")} + case "mcp_call": + return { + "server_label": item.get("server_label"), + "name": item.get("name"), + "arguments": item.get("arguments"), + "output": item.get("output"), + "error": item.get("error"), + } + case "mcp_approval_request": + return { + "server_label": item.get("server_label"), + "name": item.get("name"), + "arguments": item.get("arguments"), + "approval_request_id": item.get("approval_request_id") + or item.get("id"), + } + case "shell_call_output": + return {"output": item.get("output")} + case "tool_search_output": + return {"tools": item.get("tools")} + case "computer_call": + return {"status": item.get("status") or "completed"} + case _: + return None + + +def _index_key(data: Mapping[str, Any]) -> str | None: + output_index = data.get("output_index") + return str(output_index) if output_index is not None else None + + +def _lookup_state( + states_by_item: dict[str, dict[str, Any]], + states_by_index: dict[str, dict[str, Any]], + data: Mapping[str, Any], +) -> dict[str, Any] | None: + item_id = data.get("item_id") + if isinstance(item_id, str) and item_id in states_by_item: + return states_by_item[item_id] + index = _index_key(data) + if index is not None: + return states_by_index.get(index) + return None + + +async def _stream_responses( + sdk_client: openai.AsyncOpenAI, + model: core.model.Model, + messages: list[types.messages.Message], + *, + tools: Sequence[types.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, +) -> AsyncGenerator[types.events.Event]: + openai_sdk = _sdk.import_sdk(provider=provider) + stream_params = _coerce_responses_params(params) + protected = sorted(_RESPONSES_PROTECTED_PARAMS & stream_params.keys()) + if protected: + raise ValueError( + "openai responses params cannot override protocol-owned fields: " + + ", ".join(protected) + ) + + request_tools = list(tools or ()) + use_item_references = stream_params.get("store") is not False + response_input = await _messages_to_responses( + messages, + use_item_references=use_item_references, + ) + response_tools = _tools_to_responses(request_tools) if request_tools else None + + api_kwargs: dict[str, Any] = dict(stream_params) + api_kwargs.update({"model": model.id, "input": response_input, "stream": True}) + if response_tools: + api_kwargs["tools"] = response_tools + + if output_type is not None: + openai_pydantic = _sdk.import_pydantic(provider=provider) + text_config = dict(api_kwargs.get("text") or {}) + text_config["format"] = { + "type": "json_schema", + "name": output_type.__name__, + "schema": openai_pydantic.to_strict_json_schema(output_type), + "strict": True, + } + api_kwargs["text"] = text_config + + image_media_type = _image_media_type(stream_params, request_tools) + text_blocks: set[str] = set() + reasoning_blocks: set[str] = set() + reasoning_delta_blocks: set[str] = set() + reasoning_ended_blocks: set[str] = set() + function_states_by_item: dict[str, dict[str, Any]] = {} + function_states_by_index: dict[str, dict[str, Any]] = {} + builtin_states_by_item: dict[str, dict[str, Any]] = {} + builtin_states_by_index: dict[str, dict[str, Any]] = {} + usage: types.usage.Usage | None = None + response_metadata: dict[str, Any] | None = None + + try: + sdk_stream = await sdk_client.responses.create(**api_kwargs) + yield types.events.StreamStart() + + async for sdk_event in sdk_stream: + data = _event_to_dict(sdk_event) + event_type = data.get("type") + + if event_type == "response.created": + response = data.get("response") + if isinstance(response, Mapping): + response_metadata = _provider_metadata_for_response(response) + continue + + if event_type in {"response.completed", "response.incomplete"}: + response = data.get("response") + if isinstance(response, Mapping): + usage = _usage_from_response(response) or usage + response_metadata = _provider_metadata_for_response(response) + continue + + if event_type == "response.failed": + response = data.get("response") + if isinstance(response, Mapping): + usage = _usage_from_response(response) or usage + response_metadata = _provider_metadata_for_response(response) + continue + + if event_type == "error": + error = data.get("error") + if isinstance(error, Mapping): + message = error.get("message") or error.get("code") or error + else: + message = error or data + raise ai_errors.ProviderResponseError(str(message), provider=provider) + + if event_type == "response.output_item.added": + item = data.get("item") + if not isinstance(item, Mapping): + continue + item = dict(item) + item_type = item.get("type") + state_key = _state_key(item, data) + index = _index_key(data) + + if item_type == "message": + block_id = str(item.get("id") or state_key or "text") + text_blocks.add(block_id) + yield types.events.TextStart( + block_id=block_id, + provider_metadata=_provider_metadata_for_item(item), + ) + continue + + if item_type == "reasoning": + block_id = f"{item.get('id') or state_key}:0" + reasoning_blocks.add(block_id) + yield types.events.ReasoningStart( + block_id=block_id, + provider_metadata=_provider_metadata_for_item( + item, + reasoning_encrypted_content=item.get("encrypted_content"), + ), + ) + continue + + if item_type in {"function_call", "custom_tool_call"}: + tool_call_id = str(item.get("call_id") or state_key) + tool_name = str(item.get("name") or "") + new_state: dict[str, Any] = { + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "arguments": "", + "delta_emitted": False, + } + if state_key: + function_states_by_item[state_key] = new_state + if index is not None: + function_states_by_index[index] = new_state + yield types.events.ToolStart( + tool_call_id=tool_call_id, + tool_name=tool_name, + provider_metadata=_provider_metadata_for_item(item), + ) + arguments = item.get("arguments") or item.get("input") + if isinstance(arguments, str) and arguments: + new_state["arguments"] = arguments + new_state["delta_emitted"] = True + yield types.events.ToolDelta( + tool_call_id=tool_call_id, + chunk=arguments, + ) + continue + + if item_type in _BUILTIN_OUTPUT_TYPES: + tool_call_id = _builtin_tool_call_id(item) + tool_name = _builtin_tool_name(item) + new_state = { + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "arguments": "", + "delta_emitted": False, + } + if state_key: + builtin_states_by_item[state_key] = new_state + if index is not None: + builtin_states_by_index[index] = new_state + yield types.events.BuiltinToolStart( + tool_call_id=tool_call_id, + tool_name=tool_name, + provider_metadata=_provider_metadata_for_item(item), + ) + continue + + if event_type == "response.output_text.delta": + block_id = str(data.get("item_id") or "text") + if block_id not in text_blocks: + text_blocks.add(block_id) + yield types.events.TextStart(block_id=block_id) + delta = data.get("delta") + if isinstance(delta, str) and delta: + yield types.events.TextDelta(block_id=block_id, chunk=delta) + continue + + if event_type == "response.output_text.done": + continue + + if event_type in { + "response.function_call_arguments.delta", + "response.custom_tool_call_input.delta", + }: + function_state = _lookup_state( + function_states_by_item, + function_states_by_index, + data, + ) + delta = data.get("delta") + if function_state is not None and isinstance(delta, str) and delta: + function_state["arguments"] += delta + function_state["delta_emitted"] = True + yield types.events.ToolDelta( + tool_call_id=function_state["tool_call_id"], + chunk=delta, + ) + continue + + if event_type in { + "response.function_call_arguments.done", + "response.custom_tool_call_input.done", + }: + function_state = _lookup_state( + function_states_by_item, + function_states_by_index, + data, + ) + arguments = data.get("arguments") or data.get("input") + if ( + function_state is not None + and isinstance(arguments, str) + and arguments + and not function_state["delta_emitted"] + ): + function_state["arguments"] = arguments + function_state["delta_emitted"] = True + yield types.events.ToolDelta( + tool_call_id=function_state["tool_call_id"], + chunk=arguments, + ) + continue + + if event_type == "response.reasoning_summary_part.added": + block_id = f"{data.get('item_id')}:{data.get('summary_index', 0)}" + if block_id not in reasoning_blocks: + reasoning_blocks.add(block_id) + yield types.events.ReasoningStart(block_id=block_id) + continue + + if event_type in { + "response.reasoning_summary_text.delta", + "response.reasoning_text.delta", + }: + block_id = f"{data.get('item_id')}:{data.get('summary_index', 0)}" + if block_id not in reasoning_blocks: + reasoning_blocks.add(block_id) + yield types.events.ReasoningStart(block_id=block_id) + delta = data.get("delta") + if isinstance(delta, str) and delta: + reasoning_delta_blocks.add(block_id) + yield types.events.ReasoningDelta(block_id=block_id, chunk=delta) + continue + + if event_type == "response.reasoning_summary_part.done": + block_id = f"{data.get('item_id')}:{data.get('summary_index', 0)}" + reasoning_blocks.discard(block_id) + reasoning_ended_blocks.add(block_id) + yield types.events.ReasoningEnd(block_id=block_id) + continue + + if event_type == "response.image_generation_call.partial_image": + item_id = str(data.get("item_id") or "") + partial = data.get("partial_image_b64") + if isinstance(partial, str) and partial: + yield types.events.FileEvent( + block_id=f"{item_id}:partial", + media_type=image_media_type, + data=partial, + ) + continue + + if event_type == "response.code_interpreter_call_code.delta": + builtin_state = _lookup_state( + builtin_states_by_item, + builtin_states_by_index, + data, + ) + delta = data.get("delta") + if builtin_state is not None and isinstance(delta, str) and delta: + builtin_state["arguments"] += delta + continue + + if event_type == "response.output_item.done": + item = data.get("item") + if not isinstance(item, Mapping): + continue + item = dict(item) + item_type = item.get("type") + state_key = _state_key(item, data) + index = _index_key(data) + + if item_type == "message": + block_id = str(item.get("id") or state_key or "text") + if block_id in text_blocks: + text_blocks.remove(block_id) + yield types.events.TextEnd( + block_id=block_id, + provider_metadata=_provider_metadata_for_item(item), + ) + continue + + if item_type == "reasoning": + item_id = str(item.get("id") or state_key) + summaries = item.get("summary") + if isinstance(summaries, list) and summaries: + for idx, summary in enumerate(summaries): + block_id = f"{item_id}:{idx}" + if block_id in reasoning_ended_blocks: + continue + if block_id not in reasoning_blocks: + reasoning_blocks.add(block_id) + yield types.events.ReasoningStart( + block_id=block_id, + provider_metadata=_provider_metadata_for_item( + item, + reasoning_encrypted_content=item.get( + "encrypted_content" + ), + ), + ) + if block_id not in reasoning_delta_blocks and isinstance( + summary, Mapping + ): + text = summary.get("text") + if isinstance(text, str) and text: + yield types.events.ReasoningDelta( + block_id=block_id, + chunk=text, + ) + reasoning_blocks.discard(block_id) + reasoning_ended_blocks.add(block_id) + yield types.events.ReasoningEnd( + block_id=block_id, + provider_metadata=_provider_metadata_for_item( + item, + reasoning_encrypted_content=item.get( + "encrypted_content" + ), + ), + ) + else: + for block_id in list(reasoning_blocks): + if block_id.startswith(f"{item_id}:"): + reasoning_blocks.remove(block_id) + yield types.events.ReasoningEnd( + block_id=block_id, + provider_metadata=_provider_metadata_for_item( + item, + reasoning_encrypted_content=item.get( + "encrypted_content" + ), + ), + ) + continue + + if item_type in {"function_call", "custom_tool_call"}: + function_state = ( + function_states_by_item.pop(state_key, None) + if state_key + else None + ) + if function_state is None and index is not None: + function_state = function_states_by_index.get(index) + if index is not None: + function_states_by_index.pop(index, None) + tool_call_id = str(item.get("call_id") or state_key) + tool_name = str(item.get("name") or "") + arguments = item.get("arguments") or item.get("input") + if function_state is None: + yield types.events.ToolStart( + tool_call_id=tool_call_id, + tool_name=tool_name, + provider_metadata=_provider_metadata_for_item(item), + ) + if isinstance(arguments, str) and arguments: + yield types.events.ToolDelta( + tool_call_id=tool_call_id, + chunk=arguments, + ) + elif ( + isinstance(arguments, str) + and arguments + and not function_state["delta_emitted"] + ): + yield types.events.ToolDelta( + tool_call_id=function_state["tool_call_id"], + chunk=arguments, + ) + tool_call_id = function_state["tool_call_id"] + else: + tool_call_id = function_state["tool_call_id"] + yield types.events.ToolEnd( + tool_call_id=tool_call_id, + tool_call=types.messages.DUMMY_TOOL_CALL, + provider_metadata=_provider_metadata_for_item(item), + ) + continue + + if item_type in _BUILTIN_OUTPUT_TYPES: + builtin_state = ( + builtin_states_by_item.pop(state_key, None) + if state_key + else None + ) + if builtin_state is None and index is not None: + builtin_state = builtin_states_by_index.get(index) + if index is not None: + builtin_states_by_index.pop(index, None) + tool_call_id = _builtin_tool_call_id(item) + tool_name = _builtin_tool_name(item) + arguments = _builtin_tool_args(item) + if builtin_state is None: + yield types.events.BuiltinToolStart( + tool_call_id=tool_call_id, + tool_name=tool_name, + provider_metadata=_provider_metadata_for_item(item), + ) + else: + tool_call_id = builtin_state["tool_call_id"] + tool_name = builtin_state["tool_name"] + if arguments and ( + builtin_state is None or not builtin_state["delta_emitted"] + ): + yield types.events.BuiltinToolDelta( + tool_call_id=tool_call_id, + chunk=arguments, + ) + yield types.events.BuiltinToolEnd( + tool_call_id=tool_call_id, + tool_call=types.messages.BuiltinToolCallPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + tool_args=arguments, + provider_metadata=_provider_metadata_for_item(item), + ), + provider_metadata=_provider_metadata_for_item(item), + ) + + if item_type == "image_generation_call": + result = item.get("result") + if isinstance(result, str) and result: + yield types.events.FileEvent( + block_id=str(item.get("id") or tool_call_id), + media_type=image_media_type, + data=result, + provider_metadata=_provider_metadata_for_item(item), + ) + + result_payload = _builtin_tool_result(item) + if result_payload is not None: + yield types.events.BuiltinToolResult( + tool_call_id=tool_call_id, + result=types.messages.BuiltinToolReturnPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + result=result_payload, + provider_metadata=_provider_metadata_for_item(item), + ), + ) + continue + + for block_id in list(text_blocks): + yield types.events.TextEnd(block_id=block_id) + for block_id in list(reasoning_blocks): + yield types.events.ReasoningEnd(block_id=block_id) + + yield types.events.StreamEnd( + usage=usage, + provider_metadata=response_metadata, + ) + except openai_sdk.OpenAIError as exc: + raise errors.map_error(exc, provider=provider, model_id=model.id) from exc + + +class OpenAIResponsesProtocol(base.ProviderProtocol[Any]): + """OpenAI Responses API protocol.""" + + def stream( + self, + client: openai.AsyncOpenAI, + model: core.model.Model, + messages: list[types.messages.Message], + *, + tools: Sequence[types.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, + ) -> AsyncGenerator[types.events.Event]: + return _stream_responses( + client, + model, + messages, + tools=tools, + output_type=output_type, + params=params, + provider=provider, + ) + + +def default_protocol(provider: str) -> base.ProviderProtocol[Any]: + """Return the best default OpenAI-compatible protocol for *provider*.""" + if provider == "openai": + return OpenAIResponsesProtocol() + return OpenAIChatCompletionsProtocol() diff --git a/src/ai/providers/openai/provider.py b/src/ai/providers/openai/provider.py index 7777687..58f18e1 100644 --- a/src/ai/providers/openai/provider.py +++ b/src/ai/providers/openai/provider.py @@ -10,7 +10,8 @@ from ... import errors as ai_errors from .. import base -from . import _sdk, errors, protocol +from . import _sdk, errors +from . import protocol as protocol_module from . import tools as tools_module if TYPE_CHECKING: @@ -55,6 +56,7 @@ def __init__( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: OpenAIClient | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> None: openai_sdk = None if client is not None and not isinstance(client, httpx.AsyncClient): @@ -75,8 +77,8 @@ def __init__( super().__init__( name=name, - adapter="openai", base_url=default_base_url, + protocol=protocol or protocol_module.default_protocol(name), api_key=api_key, api_key_env=api_key_env, base_url_env=base_url_env, @@ -127,16 +129,16 @@ def stream( tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events.Event]: - """Stream via the OpenAI chat completions protocol.""" - return protocol.stream( - self.sdk_client, + """Stream via this provider's configured OpenAI-compatible protocol.""" + return super().stream( model, messages, tools=tools, output_type=output_type, params=params, - provider=self.name, + protocol=protocol, ) @classmethod @@ -150,6 +152,7 @@ def from_modelsdev_provider( headers: Mapping[str, str] | None = None, env: Mapping[str, str] | None = None, client: OpenAIClient | None = None, + protocol: base.ProviderProtocol[Any] | None = None, ) -> base.Provider[OpenAISDKClient]: resolved_base_url = base_url or base.provider_base_url( provider, @@ -172,16 +175,16 @@ def from_modelsdev_provider( headers=headers, env=env, client=client, + protocol=protocol, ) @property def tools(self) -> ModuleType: """The provider's built-in tool factories. - Convenience accessor: ``openai.tools.web_search(...)``. The - chat-completions protocol currently raises if a built-in tool is - passed; route via the AI Gateway provider until a Responses - protocol ships. + Convenience accessor: ``openai.tools.web_search(...)``. These tools + require a protocol that supports OpenAI provider-executed tools, such + as :class:`OpenAIResponsesProtocol`. """ return tools_module diff --git a/tests/conftest.py b/tests/conftest.py index aa23c5b..6f6eeb2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,13 +24,11 @@ def __init__( self, *, name: str = "mock", - adapter: str = "mock", base_url: str = "http://mock.test", api_key_env: str | None = "MOCK_API_KEY", ) -> None: super().__init__( name=name, - adapter=adapter, base_url=base_url, api_key_env=api_key_env, ) @@ -48,7 +46,18 @@ def stream( tools: Sequence[ai.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, params: Any = None, + protocol: models.ProviderProtocol[Any] | None = None, ) -> AsyncGenerator[events_.Event]: + if protocol is not None: + return protocol.stream( + None, + model, + messages, + tools=tools, + output_type=output_type, + params=params, + provider=self.name, + ) if self._stream_impl is None: raise RuntimeError("MockProvider: no stream implementation configured") return cast( @@ -67,7 +76,17 @@ async def generate( model: models.Model, messages: list[messages_.Message], params: Any, + *, + protocol: models.ProviderProtocol[Any] | None = None, ) -> messages_.Message: + if protocol is not None: + return await protocol.generate( + None, + model, + messages, + params, + provider=self.name, + ) if self._generate_impl is None: raise RuntimeError("MockProvider: no generate implementation configured") return cast( @@ -77,10 +96,9 @@ async def generate( MOCK_PROVIDER = MockProvider() -# A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. +# A fixed Model used in tests. MOCK_MODEL: models.Model = models.Model( id="mock-model", - adapter="mock", provider=MOCK_PROVIDER, ) diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index 91480c9..37fd72f 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -230,11 +230,45 @@ async def test_stream_requires_model_messages_or_context() -> None: pass -async def test_generate_dispatches_to_registered_adapter() -> None: - provider = MockProvider(adapter="mock-generate") +async def test_stream_accepts_protocol_kwarg() -> None: + class OverrideProtocol(models.ProviderProtocol[Any]): + def stream( + self, + client: Any, + model: models.Model, + messages: list[messages_.Message], + *, + tools: Sequence[ai.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: Any = None, + provider: str, + ) -> AsyncGenerator[events_.Event]: + _ = client, model, messages, tools, output_type, params, provider + + async def _stream() -> AsyncGenerator[events_.Event]: + yield events_.StreamStart() + yield events_.TextStart(block_id="text") + yield events_.TextDelta(block_id="text", chunk="override") + yield events_.TextEnd(block_id="text") + yield events_.StreamEnd() + + return _stream() + + async with models.stream( + MOCK_MODEL, + [ai.user_message("Hi")], + protocol=OverrideProtocol(), + ) as stream: + async for _ in stream: + pass + + assert stream.text == "override" + + +async def test_generate_dispatches_to_provider() -> None: + provider = MockProvider() model = models.Model( id="generate-model", - adapter="mock-generate", provider=provider, ) sentinel = messages_.Message( @@ -264,9 +298,38 @@ async def _generate( assert result is sentinel +async def test_generate_accepts_protocol_kwarg() -> None: + sentinel = messages_.Message( + role="assistant", + parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], + ) + + class OverrideProtocol(models.ProviderProtocol[Any]): + async def generate( + self, + client: Any, + model: models.Model, + messages: list[messages_.Message], + params: models.GenerateParams, + *, + provider: str, + ) -> messages_.Message: + _ = client, model, messages, params, provider + return sentinel + + result = await models.generate( + MOCK_MODEL, + [ai.user_message("A cat")], + models.ImageParams(n=1), + protocol=OverrideProtocol(), + ) + + assert result is sentinel + + class _CheckProvider(MockProvider): def __init__(self) -> None: - super().__init__(adapter="mock-check") + super().__init__() self.checked_model: models.Model | None = None async def probe(self, model: models.Model) -> None: diff --git a/tests/models/test_resolution.py b/tests/models/test_resolution.py index 76a3458..20ac643 100644 --- a/tests/models/test_resolution.py +++ b/tests/models/test_resolution.py @@ -2,30 +2,33 @@ import ai from ai import ConfigurationError, models +from ai.providers.ai_gateway import GatewayV3Protocol +from ai.providers.anthropic import AnthropicMessagesProtocol +from ai.providers.openai import OpenAIChatCompletionsProtocol, OpenAIResponsesProtocol def test_get_resolves_provider_qualified_model_id() -> None: model = ai.get_model("openai:gpt-5") assert model.id == "gpt-5" - assert model.adapter == "openai" assert model.provider.name == "openai" + assert isinstance(model.provider.protocol, OpenAIResponsesProtocol) def test_get_resolves_provider_qualified_anthropic_model_id() -> None: model = models.get_model("anthropic:claude-sonnet-4-5") assert model.id == "claude-sonnet-4-5" - assert model.adapter == "anthropic" assert model.provider.name == "anthropic" + assert isinstance(model.provider.protocol, AnthropicMessagesProtocol) def test_get_defaults_to_gateway_when_provider_is_omitted() -> None: model = models.get_model("anthropic/claude-sonnet-4") assert model.id == "anthropic/claude-sonnet-4" - assert model.adapter == "ai-gateway-v3" assert model.provider.name == "ai-gateway" + assert isinstance(model.provider.protocol, GatewayV3Protocol) def test_get_uses_default_model_env_when_model_id_is_omitted( @@ -36,8 +39,8 @@ def test_get_uses_default_model_env_when_model_id_is_omitted( model = models.get_model() assert model.id == "anthropic/claude-sonnet-4" - assert model.adapter == "ai-gateway-v3" assert model.provider.name == "ai-gateway" + assert isinstance(model.provider.protocol, GatewayV3Protocol) def test_get_rejects_missing_default_model_env( @@ -62,7 +65,7 @@ def test_provider_from_id_resolves_openai_compatible_provider() -> None: provider = ai.get_provider("deepseek") assert provider.name == "deepseek" - assert provider.adapter == "openai" + assert isinstance(provider.protocol, OpenAIChatCompletionsProtocol) assert provider.default_base_url == "https://api.deepseek.com" assert provider.api_key_env == "DEEPSEEK_API_KEY" assert provider.config_envs == () @@ -70,7 +73,7 @@ def test_provider_from_id_resolves_openai_compatible_provider() -> None: def test_provider_base_class_cannot_be_constructed_directly() -> None: with pytest.raises(TypeError, match="base class"): - ai.Provider(name="custom", adapter="mock", base_url="https://example.com") + ai.Provider(name="custom", base_url="https://example.com") def test_provider_from_id_uses_template_envs_for_base_url() -> None: @@ -106,21 +109,21 @@ def test_get_resolves_gateway_alias() -> None: model = models.get_model("ai-gateway:alibaba/qwen-3-14b") assert model.id == "alibaba/qwen-3-14b" - assert model.adapter == "ai-gateway-v3" assert model.provider.name == "ai-gateway" + assert isinstance(model.provider.protocol, GatewayV3Protocol) gateway_model = models.get_model("gateway:alibaba/qwen-3-14b") assert gateway_model.id == model.id - assert gateway_model.adapter == model.adapter assert gateway_model.provider.name == model.provider.name + assert isinstance(gateway_model.provider.protocol, GatewayV3Protocol) def test_get_uses_model_provider_config_for_anthropic_compatibility() -> None: model = models.get_model("azure:claude-sonnet-4-5") assert model.id == "claude-sonnet-4-5" - assert model.adapter == "anthropic" assert model.provider.name == "azure" + assert isinstance(model.provider.protocol, AnthropicMessagesProtocol) assert model.provider.default_base_url == ( "https://${AZURE_RESOURCE_NAME}.services.ai.azure.com/anthropic/v1" ) @@ -132,8 +135,8 @@ def test_get_uses_model_provider_config_for_openai_compatibility() -> None: model = models.get_model("azure:kimi-k2.5") assert model.id == "kimi-k2.5" - assert model.adapter == "openai" assert model.provider.name == "azure" + assert isinstance(model.provider.protocol, OpenAIChatCompletionsProtocol) assert model.provider.default_base_url == ( "https://${AZURE_RESOURCE_NAME}.services.ai.azure.com/models" ) @@ -161,3 +164,18 @@ def test_get_rejects_unsupported_provider_package() -> None: def test_get_rejects_empty_model_id() -> None: with pytest.raises(ConfigurationError, match="malformed model_id: ''"): models.get_model("") + + +def test_get_model_accepts_model_protocol_override() -> None: + protocol = OpenAIChatCompletionsProtocol() + model = models.get_model("openai:gpt-5", protocol=protocol) + + assert model.protocol is protocol + assert isinstance(model.provider.protocol, OpenAIResponsesProtocol) + + +def test_get_provider_accepts_provider_protocol_override() -> None: + protocol = OpenAIChatCompletionsProtocol() + provider = ai.get_provider("openai", protocol=protocol) + + assert provider.protocol is protocol diff --git a/tests/providers/anthropic/test_provider.py b/tests/providers/anthropic/test_provider.py index 01f9822..fe25bf5 100644 --- a/tests/providers/anthropic/test_provider.py +++ b/tests/providers/anthropic/test_provider.py @@ -7,7 +7,10 @@ import pytest import ai -from ai.providers.anthropic import AnthropicCompatibleProvider +from ai.providers.anthropic import ( + AnthropicCompatibleProvider, + AnthropicMessagesProtocol, +) async def test_list_models_gets_models_with_provider_headers_and_sorts_ids() -> None: @@ -176,13 +179,12 @@ def test_get_provider_accepts_base_url_and_api_key() -> None: model = ai.Model("custom-model", provider=provider) assert repr(provider) == "anthropic" - assert provider.adapter == "anthropic" + assert isinstance(provider.protocol, AnthropicMessagesProtocol) assert provider.base_url == "https://custom.example.com" assert provider.api_key == "sk-custom" assert provider.headers == {"X-Custom-Header": "example"} assert provider.is_configured() is True assert model.id == "custom-model" - assert model.adapter == "anthropic" assert model.provider is provider diff --git a/tests/providers/openai/test_adapter.py b/tests/providers/openai/test_adapter.py index e063979..6579922 100644 --- a/tests/providers/openai/test_adapter.py +++ b/tests/providers/openai/test_adapter.py @@ -16,7 +16,7 @@ import ai from ai.providers.openai import protocol from ai.providers.openai import tools as openai_tools -from ai.types import messages +from ai.types import events, messages, tools class _Answer(pydantic.BaseModel): @@ -31,6 +31,22 @@ async def __anext__(self) -> Any: raise StopAsyncIteration +class _ListStream: + def __init__(self, items: list[dict[str, Any]]) -> None: + self._items = items + self._idx = 0 + + def __aiter__(self) -> _ListStream: + return self + + async def __anext__(self) -> dict[str, Any]: + if self._idx >= len(self._items): + raise StopAsyncIteration + item = self._items[self._idx] + self._idx += 1 + return item + + class _FakeCompletions: def __init__(self, captured: dict[str, Any]) -> None: self._captured = captured @@ -54,6 +70,21 @@ async def close(self) -> None: self.closed = True +class _FakeResponses: + def __init__(self, captured: dict[str, Any], items: list[dict[str, Any]]) -> None: + self._captured = captured + self._items = items + + async def create(self, **kwargs: Any) -> _ListStream: + self._captured.update(kwargs) + return _ListStream(self._items) + + +class _FakeResponsesClient: + def __init__(self, captured: dict[str, Any], items: list[dict[str, Any]]) -> None: + self.responses = _FakeResponses(captured, items) + + class _RaisingCompletions: def __init__(self, exc: openai.OpenAIError) -> None: self._exc = exc @@ -88,11 +119,283 @@ def _patch( return cast(openai.AsyncOpenAI, fake), captured +def _patch_responses( + items: list[dict[str, Any]] | None = None, +) -> tuple[openai.AsyncOpenAI, dict[str, Any]]: + captured: dict[str, Any] = {} + fake = _FakeResponsesClient(captured, items or []) + return cast(openai.AsyncOpenAI, fake), captured + + async def _drain(stream: Any) -> None: async for _ in stream: pass +async def test_responses_request_uses_responses_input() -> None: + fake, captured = _patch_responses() + + await _drain( + protocol.OpenAIResponsesProtocol().stream( + fake, + _MODEL, + [ai.system_message("rules"), ai.user_message("Hi")], + provider="openai", + ) + ) + + assert captured["model"] == "gpt-5.4" + assert captured["stream"] is True + assert captured["input"] == [ + {"role": "system", "content": "rules"}, + {"role": "user", "content": [{"type": "input_text", "text": "Hi"}]}, + ] + assert "messages" not in captured + + +async def test_responses_raw_params_and_structured_output() -> None: + fake, captured = _patch_responses() + + await _drain( + protocol.OpenAIResponsesProtocol().stream( + fake, + _MODEL, + [ai.user_message("Hi")], + output_type=_Answer, + params={ + "reasoning": {"effort": "high"}, + "include": ["file_search_call.results"], + "text": {"verbosity": "low"}, + "extra_headers": {"x-openai-feature": "enabled"}, + }, + provider="openai", + ) + ) + + assert captured["reasoning"] == {"effort": "high"} + assert captured["include"] == ["file_search_call.results"] + assert captured["extra_headers"] == {"x-openai-feature": "enabled"} + assert captured["text"]["verbosity"] == "low" + assert captured["text"]["format"]["type"] == "json_schema" + assert captured["text"]["format"]["name"] == "_Answer" + assert captured["text"]["format"]["strict"] is True + + +async def test_responses_tools_convert_function_and_provider_tools() -> None: + fake, captured = _patch_responses() + + await _drain( + protocol.OpenAIResponsesProtocol().stream( + fake, + _MODEL, + [ai.user_message("Hi")], + tools=[ + tools.Tool( + kind="function", + name="weather", + args=tools.FunctionToolArgs( + description="Get weather", + params={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + ), + ), + openai_tools.web_search(search_context_size="low"), + openai_tools.code_interpreter(), + ], + provider="openai", + ) + ) + + assert captured["tools"] == [ + { + "type": "function", + "name": "weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + }, + {"type": "web_search", "search_context_size": "low"}, + {"type": "code_interpreter", "container": {"type": "auto"}}, + ] + + +async def test_responses_streams_text_and_usage() -> None: + fake, _ = _patch_responses( + [ + { + "type": "response.created", + "response": { + "id": "resp_1", + "model": "gpt-5.4", + "status": "in_progress", + }, + }, + { + "type": "response.output_item.added", + "output_index": 0, + "item": {"id": "msg_1", "type": "message", "role": "assistant"}, + }, + {"type": "response.output_text.delta", "item_id": "msg_1", "delta": "Hi"}, + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "msg_1", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Hi"}], + }, + }, + { + "type": "response.completed", + "response": { + "id": "resp_1", + "model": "gpt-5.4", + "status": "completed", + "usage": { + "input_tokens": 3, + "input_tokens_details": {"cached_tokens": 1}, + "output_tokens": 5, + "output_tokens_details": {"reasoning_tokens": 2}, + }, + }, + }, + ] + ) + + stream = ai.Stream( + protocol.OpenAIResponsesProtocol().stream( + fake, + _MODEL, + [ai.user_message("Hi")], + provider="openai", + ) + ) + async for _ in stream: + pass + + assert stream.text == "Hi" + assert stream.usage is not None + assert stream.usage.input_tokens == 3 + assert stream.usage.output_tokens == 5 + assert stream.usage.reasoning_tokens == 2 + assert stream.usage.cache_read_tokens == 1 + assert stream.message.provider_metadata == { + "openai": { + "response_id": "resp_1", + "model": "gpt-5.4", + "status": "completed", + } + } + + +async def test_responses_streams_function_tool_call() -> None: + fake, _ = _patch_responses( + [ + { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": "fc_1", + "type": "function_call", + "call_id": "call_1", + "name": "weather", + }, + }, + { + "type": "response.function_call_arguments.delta", + "item_id": "fc_1", + "output_index": 0, + "delta": '{"city"', + }, + { + "type": "response.function_call_arguments.delta", + "item_id": "fc_1", + "output_index": 0, + "delta": ':"SF"}', + }, + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "fc_1", + "type": "function_call", + "call_id": "call_1", + "name": "weather", + "arguments": '{"city":"SF"}', + }, + }, + ] + ) + + stream = ai.Stream( + protocol.OpenAIResponsesProtocol().stream( + fake, + _MODEL, + [ai.user_message("Hi")], + provider="openai", + ) + ) + async for _ in stream: + pass + + assert len(stream.tool_calls) == 1 + assert stream.tool_calls[0].tool_call_id == "call_1" + assert stream.tool_calls[0].tool_name == "weather" + assert stream.tool_calls[0].tool_args == '{"city":"SF"}' + + +async def test_responses_streams_builtin_tool_call_and_result() -> None: + fake, _ = _patch_responses( + [ + { + "type": "response.output_item.added", + "output_index": 0, + "item": { + "id": "ws_1", + "type": "web_search_call", + "status": "searching", + }, + }, + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "ws_1", + "type": "web_search_call", + "status": "completed", + "action": {"type": "search", "query": "weather"}, + }, + }, + ] + ) + + stream = ai.Stream( + protocol.OpenAIResponsesProtocol().stream( + fake, + _MODEL, + [ai.user_message("Hi")], + provider="openai", + ) + ) + seen: list[type[events.Event]] = [] + async for event in stream: + seen.append(type(event)) + + assert events.BuiltinToolStart in seen + assert events.BuiltinToolEnd in seen + assert len(stream.message.builtin_tool_calls) == 1 + assert stream.message.builtin_tool_calls[0].tool_name == "web_search" + assert len(stream.message.builtin_tool_returns) == 1 + assert stream.message.builtin_tool_returns[0].result == { + "action": {"type": "search", "query": "weather"} + } + + async def test_system_messages_use_openai_system_role( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/providers/openai/test_provider.py b/tests/providers/openai/test_provider.py index ca7ea94..ef3d26f 100644 --- a/tests/providers/openai/test_provider.py +++ b/tests/providers/openai/test_provider.py @@ -7,7 +7,7 @@ import pytest import ai -from ai.providers.openai import OpenAICompatibleProvider +from ai.providers.openai import OpenAICompatibleProvider, OpenAIResponsesProtocol async def test_list_models_gets_models_with_auth_header_and_sorts_ids() -> None: @@ -207,13 +207,12 @@ def test_get_provider_accepts_base_url_and_api_key() -> None: model = ai.Model("custom-model", provider=provider) assert repr(provider) == "openai" - assert provider.adapter == "openai" + assert isinstance(provider.protocol, OpenAIResponsesProtocol) assert provider.base_url == "https://custom.example.com/v1" assert provider.api_key == "sk-custom" assert provider.headers == {"X-Custom-Header": "example"} assert provider.is_configured() is True assert model.id == "custom-model" - assert model.adapter == "openai" assert model.provider is provider assert isinstance(provider, ai.Provider)