1 change: 1 addition & 0 deletions examples/run-examples.py
@@ -43,6 +43,7 @@ class Sample:
Sample("agent_custom_loop.py"),
Sample("agent_nested.py"),
Sample("streaming_tool.py"),
Sample("openai_chat_completions.py"),
Sample("explicit_client.py"),
Sample("multimodal_input.py"),
Sample("check_connection.py"),
39 changes: 39 additions & 0 deletions examples/samples/openai_chat_completions.py
@@ -0,0 +1,39 @@
"""OpenAI Chat Completions protocol — stream text from GPT-5.5."""

import asyncio

import ai
from ai.providers.openai import OpenAIChatCompletionsProtocol

messages = [
ai.system_message("Be concise."),
ai.user_message(
"Explain what the OpenAI Chat Completions API is in two sentences."
),
]


async def main() -> None:
provider = ai.get_provider("openai")
if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
return

model = ai.Model("gpt-5.5", provider=provider)

try:
async with ai.stream(
model,
messages,
protocol=OpenAIChatCompletionsProtocol(),
) as stream:
async for event in stream:
if isinstance(event, ai.events.TextDelta):
print(event.chunk, end="", flush=True)
print()
finally:
await provider.aclose()


if __name__ == "__main__":
asyncio.run(main())
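The ``protocol=`` argument above is per call. Judging from the ``get_model`` changes later in this diff, the same override can also be bound to the model itself. A minimal sketch, assuming ``OpenAIChatCompletionsProtocol`` subclasses ``ProviderProtocol`` and that providers fall back to ``model.protocol`` when no per-call protocol is passed:

# Hypothetical alternative: bind the protocol to the model, not to each call.
import ai
from ai.providers.openai import OpenAIChatCompletionsProtocol

model = ai.get_model("openai:gpt-5.5", protocol=OpenAIChatCompletionsProtocol())
# ai.stream(model, messages) would then use the Chat Completions wire protocol
# at every call site (assuming the provider honors model.protocol).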
2 changes: 2 additions & 0 deletions src/ai/__init__.py
@@ -51,6 +51,7 @@
ImageParams,
Model,
Provider,
ProviderProtocol,
Stream,
VideoParams,
generate,
@@ -109,6 +110,7 @@
"UnsupportedProviderError",
"Model",
"Provider",
"ProviderProtocol",
"ImageParams",
"VideoParams",
"Stream",
3 changes: 2 additions & 1 deletion src/ai/models/__init__.py
@@ -32,7 +32,7 @@
ids = await ai.get_provider("openai").list_models()
"""

from ..providers.base import Provider
from ..providers.base import Provider, ProviderProtocol
from .core.api import (
Executor,
GenerateExecutor,
@@ -56,6 +56,7 @@
"ImageParams",
"Model",
"Provider",
"ProviderProtocol",
"Stream",
"StreamExecutor",
"StreamRequest",
38 changes: 29 additions & 9 deletions src/ai/models/core/api.py
@@ -1,8 +1,19 @@
from __future__ import annotations

import contextlib
import dataclasses
from collections.abc import AsyncGenerator, AsyncIterator, Sequence
from contextlib import AbstractAsyncContextManager
from typing import Any, Generic, Protocol, Self, cast, overload, runtime_checkable
from typing import (
TYPE_CHECKING,
Any,
Generic,
Protocol,
Self,
cast,
overload,
runtime_checkable,
)

import pydantic

@@ -15,6 +26,9 @@
from . import model as model_
from . import params as params_

if TYPE_CHECKING:
from ...providers import base as provider_base

# Stream output type. Defaults to ``str``: when the stream was opened
# without an ``output_type``, ``Stream.output`` returns the concatenated
# message text.
@@ -28,13 +42,15 @@ class StreamRequest:
tools: Sequence[types.tools.Tool] | None = None
output_type: type[pydantic.BaseModel] | None = None
params: Any = None
protocol: provider_base.ProviderProtocol[Any] | None = None


@dataclasses.dataclass(frozen=True)
class GenerateRequest:
model: model_.Model
messages: list[types.messages.Message]
params: params_.GenerateParams
protocol: provider_base.ProviderProtocol[Any] | None = None


@runtime_checkable
@@ -65,6 +81,7 @@ async def _do_stream(
tools=request.tools,
output_type=request.output_type,
params=request.params,
protocol=request.protocol,
):
yield ev

@@ -73,6 +90,7 @@ async def _do_generate(self, request: GenerateRequest) -> types.messages.Message
request.model,
request.messages,
request.params,
protocol=request.protocol,
)


@@ -347,6 +365,7 @@ def stream(
*,
context: StreamContext,
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[str]]: ...
@overload
@@ -355,6 +374,7 @@ def stream[T: pydantic.BaseModel](
context: StreamContext,
output_type: type[T],
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[T]]: ...
@overload
@@ -364,6 +384,7 @@ def stream(
*,
tools: Sequence[types.tools.Tool] | None = None,
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[str]]: ...
@overload
@@ -374,6 +395,7 @@ def stream[T: pydantic.BaseModel](
tools: Sequence[types.tools.Tool] | None = None,
output_type: type[T],
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[T]]: ...
def stream(
@@ -384,6 +406,7 @@ def stream(
tools: Sequence[types.tools.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: StreamExecutor = _default_executor,
) -> AbstractAsyncContextManager[Stream[Any]]:
"""Stream an LLM response.
@@ -420,6 +443,7 @@ def stream(
tools=tools,
output_type=output_type,
params=params,
protocol=protocol,
executor=executor,
)

@@ -432,6 +456,7 @@ async def _stream(
tools: Sequence[types.tools.Tool] | None,
output_type: type[pydantic.BaseModel] | None,
params: Any,
protocol: provider_base.ProviderProtocol[Any] | None,
executor: StreamExecutor,
) -> AsyncIterator[Stream[Any]]:
if messages and messages[-1].replay:
@@ -443,13 +468,7 @@
)
else:
prepared = integrity.prepare_messages(messages)
request = StreamRequest(
model,
prepared,
tools,
output_type,
params,
)
request = StreamRequest(model, prepared, tools, output_type, params, protocol)
s = Stream(executor._do_stream(request), output_type=output_type)
try:
yield s
@@ -462,11 +481,12 @@ async def generate(
messages: list[types.messages.Message],
params: params_.GenerateParams,
*,
protocol: provider_base.ProviderProtocol[Any] | None = None,
executor: GenerateExecutor = _default_executor,
) -> types.messages.Message:
"""Generate a non-streaming response (images, video, etc.)."""
messages = integrity.prepare_messages(messages)
request = GenerateRequest(model, messages, params)
request = GenerateRequest(model, messages, params, protocol)
return await executor._do_generate(request)


27 changes: 16 additions & 11 deletions src/ai/models/core/model.py
@@ -1,6 +1,7 @@
"""Model metadata types."""

import os
from typing import Any

from ... import _modelsdev
from ...errors import ConfigurationError
@@ -13,40 +14,41 @@ class Model:
"""Lightweight reference to a model on a specific provider.

* ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``).
* ``adapter`` — wire protocol key (e.g. ``"ai-gateway-v3"``, ``"anthropic"``).
* ``provider`` — :class:`Provider` that owns this model.
* ``protocol`` — optional wire-protocol override for this model.
"""

def __init__(
self,
id: str,
*,
provider: base.Provider,
adapter: str | None = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> None:
self.id = id
self.provider = provider
self.adapter = adapter or provider.adapter
self.protocol = protocol

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Model)
and self.id == other.id
and self.adapter == other.adapter
and self.provider is other.provider
and self.protocol is other.protocol
)

def __repr__(self) -> str:
return (
f"Model(id={self.id!r}, adapter={self.adapter!r}, "
f"provider={self.provider!r})"
)
return f"Model(id={self.id!r}, provider={self.provider!r})"

def __hash__(self) -> int:
return hash((self.id, self.adapter, id(self.provider)))
return hash((self.id, id(self.provider), id(self.protocol)))


def get_model(model_id: str | None = None) -> Model:
def get_model(
model_id: str | None = None,
*,
protocol: base.ProviderProtocol[Any] | None = None,
) -> Model:
"""Resolve a model ID into a :class:`Model`.

Args:
@@ -56,6 +58,9 @@ def get_model(model_id: str | None = None) -> Model:
Vercel AI Gateway. Examples: ``"openai:gpt-5"`` or
``"anthropic/claude-sonnet-4"``. When omitted, reads
``AI_SDK_DEFAULT_MODEL`` from the environment.
protocol:
Optional wire-protocol override for this model. When omitted,
the provider chooses its default protocol.
Raises:
:class:`ai.ConfigurationError`: if ``model_id`` is empty (and
``AI_SDK_DEFAULT_MODEL`` is unset) or the model ID is malformed.
@@ -89,4 +94,4 @@ def get_model(
model_provider_config=model_provider_config,
)

return Model(provider_model_id, provider=provider)
return Model(provider_model_id, provider=provider, protocol=protocol)
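One behavioral detail of the ``__eq__``/``__hash__`` changes above: ``protocol`` is compared by identity, so models built with separate protocol instances are unequal even when everything else matches. A short sketch (``GatewayV3Protocol`` is the export added elsewhere in this diff; pairing it with the ``openai`` provider here is purely illustrative):

import ai
from ai.providers.ai_gateway import GatewayV3Protocol

provider = ai.get_provider("openai")

a = ai.Model("gpt-5.5", provider=provider, protocol=GatewayV3Protocol())
b = ai.Model("gpt-5.5", provider=provider, protocol=GatewayV3Protocol())
assert a != b  # distinct protocol instances: identity comparison fails

shared = GatewayV3Protocol()
c = ai.Model("gpt-5.5", provider=provider, protocol=shared)
d = ai.Model("gpt-5.5", provider=provider, protocol=shared)
assert c == d and hash(c) == hash(d)  # same instance: equal, hashes agree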
3 changes: 2 additions & 1 deletion src/ai/providers/__init__.py
@@ -2,13 +2,14 @@

from .ai_gateway import GatewayProvider
from .anthropic import AnthropicCompatibleProvider
from .base import Provider, get_provider
from .base import Provider, ProviderProtocol, get_provider
from .openai import OpenAICompatibleProvider

__all__ = [
"AnthropicCompatibleProvider",
"GatewayProvider",
"OpenAICompatibleProvider",
"Provider",
"ProviderProtocol",
"get_provider",
]
5 changes: 2 additions & 3 deletions src/ai/providers/ai_gateway/__init__.py
@@ -27,16 +27,15 @@
) as s:
...

The heavy ``.protocol`` module is loaded lazily by provider methods so that
``import ai`` does not pull in extra I/O code at import time. This matters
for sandboxed runtimes (e.g. Temporal workflow workers).
"""

from . import errors, tools
from .protocol import GatewayV3Protocol
from .provider import GatewayProvider

__all__ = [
"GatewayProvider",
"GatewayV3Protocol",
"errors",
"tools",
]
38 changes: 38 additions & 0 deletions src/ai/providers/ai_gateway/protocol.py
@@ -13,6 +13,7 @@

from ... import types
from ...models import core
from .. import base
from ..anthropic import tools as anthropic_tools
from ..openai import tools as openai_tools
from . import client as gateway_client
@@ -635,3 +636,40 @@
return await _generate_image(gateway, model, messages, params)
except client_errors.GatewayError as exc:
raise errors.map_error(exc) from exc


class GatewayV3Protocol(base.ProviderProtocol[gateway_client.GatewayClient]):
"""AI Gateway v3 wire protocol."""

def stream(
self,
client: gateway_client.GatewayClient,
model: core.model.Model,
messages: list[types.messages.Message],
*,
tools: Sequence[types.tools.Tool] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
params: Any = None,
provider: str,
) -> AsyncGenerator[types.events.Event]:
_ = provider
return stream(
client,
model,
messages,
tools=tools,
output_type=output_type,
params=params,
)

async def generate(
self,
client: gateway_client.GatewayClient,
model: core.model.Model,
messages: list[types.messages.Message],
params: core.GenerateParams,
*,
provider: str,
) -> types.messages.Message:
_ = provider
return await generate(client, model, messages, params)
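``GatewayV3Protocol`` also works as a template for custom protocols. A minimal delegating sketch, assuming ``ProviderProtocol`` requires exactly the ``stream``/``generate`` pair with the signatures shown above (the logging wrapper is hypothetical and reuses this module's imports):

class LoggingGatewayProtocol(base.ProviderProtocol[gateway_client.GatewayClient]):
    """Hypothetical wrapper: log each request, then delegate to the v3 protocol."""

    def __init__(self) -> None:
        self._inner = GatewayV3Protocol()

    def stream(
        self,
        client: gateway_client.GatewayClient,
        model: core.model.Model,
        messages: list[types.messages.Message],
        *,
        tools: Sequence[types.tools.Tool] | None = None,
        output_type: type[pydantic.BaseModel] | None = None,
        params: Any = None,
        provider: str,
    ) -> AsyncGenerator[types.events.Event]:
        print(f"[stream] {model.id} via {provider}")
        return self._inner.stream(
            client,
            model,
            messages,
            tools=tools,
            output_type=output_type,
            params=params,
            provider=provider,
        )

    async def generate(
        self,
        client: gateway_client.GatewayClient,
        model: core.model.Model,
        messages: list[types.messages.Message],
        params: core.GenerateParams,
        *,
        provider: str,
    ) -> types.messages.Message:
        print(f"[generate] {model.id} via {provider}")
        return await self._inner.generate(
            client, model, messages, params, provider=provider
        )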