73 changes: 72 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -18,7 +18,11 @@
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
from .._run_context import RunContext
from .._thinking_part import split_content_into_text_and_thinking
from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
from .._utils import (
    guard_tool_call_id as _guard_tool_call_id,
    now_utc as _now_utc,
    number_to_datetime,
)
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
from ..exceptions import UserError
from ..messages import (
@@ -55,6 +59,7 @@
from . import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests, download_item, get_user_agent

try:
    import tiktoken
    from openai import NOT_GIVEN, APIConnectionError, APIStatusError, AsyncOpenAI, AsyncStream
    from openai.types import AllModels, chat, responses
    from openai.types.chat import (
@@ -1008,6 +1013,24 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch
)
return ChatCompletionContentPartTextParam(text=text, type='text')

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.RequestUsage:
        """Count the number of tokens in the given messages."""
        if self.system != 'openai':
            raise NotImplementedError('Token counting is only supported for OpenAI system.')

        model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
        openai_messages = await self._map_messages(messages, model_request_parameters)
        token_count = _num_tokens_from_messages(openai_messages, self.model_name)

        return usage.RequestUsage(
            input_tokens=token_count,
        )


@deprecated(
'`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
@@ -1804,6 +1827,26 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
assert_never(item)
return responses.EasyInputMessageParam(role='user', content=content)

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.RequestUsage:
        """Count the number of tokens in the given messages."""
        if self.system != 'openai':
            raise NotImplementedError('Token counting is only supported for OpenAI system.')

        model_settings, model_request_parameters = self.prepare_request(model_settings, model_request_parameters)
        _, openai_messages = await self._map_messages(
            messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
        )
        token_count = _num_tokens_from_messages(openai_messages, self.model_name)

        return usage.RequestUsage(
            input_tokens=token_count,
        )


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
@@ -2519,3 +2562,31 @@ def _map_mcp_call(
provider_name=provider_name,
),
)


def _num_tokens_from_messages(
    messages: list[chat.ChatCompletionMessageParam] | list[responses.ResponseInputItemParam],
    model: OpenAIModelName,
) -> int:
    """Return the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding('o200k_base')

    if 'gpt-5' in model:
        tokens_per_message = 3
        final_primer = 2  # "reverse engineered" based on test cases
    else:
        # Adapted from https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#6-counting-tokens-for-chat-completions-api-calls
Collaborator:
Looking at the cookbook again, I think we should also try to implement support for counting the tokens of tool definitions:

https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken#7-counting-tokens-for-chat-completions-with-tool-calls
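For illustration, here is a rough sketch of what counting tool-definition tokens could look like, loosely following the cookbook section linked above. The `_num_tokens_from_tools` name and the overhead constants are placeholders rather than verified values, and this is not part of the PR:

```python
from typing import Any

import tiktoken


def _num_tokens_from_tools(tools: list[dict[str, Any]], encoding: tiktoken.Encoding) -> int:
    """Estimate the tokens contributed by tool/function definitions (rough sketch)."""
    func_init = 7   # placeholder: per-function overhead
    prop_init = 3   # placeholder: overhead for a non-empty parameters object
    prop_key = 3    # placeholder: per-property overhead
    func_end = 12   # placeholder: overhead added once after all functions

    num_tokens = 0
    for tool in tools:
        # Chat Completions nests the definition under 'function'; Responses tools are flat.
        function = tool.get('function', tool)
        num_tokens += func_init
        num_tokens += len(encoding.encode(function.get('name', '')))
        num_tokens += len(encoding.encode(function.get('description', '')))
        properties = function.get('parameters', {}).get('properties', {})
        if properties:
            num_tokens += prop_init
            for key, spec in properties.items():
                num_tokens += prop_key
                num_tokens += len(encoding.encode(key))
                num_tokens += len(encoding.encode(str(spec.get('type', ''))))
                num_tokens += len(encoding.encode(spec.get('description', '')))
    if tools:
        num_tokens += func_end
    return num_tokens
```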

        tokens_per_message = 3
        final_primer = 3  # every reply is primed with <|start|>assistant<|message|>

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for value in message.values():
Collaborator:
This is a bit weird, as it assumes every string value in the message dict will be sent to the model. That may be the case for ChatCompletionMessageParam, but not for ResponseInputItemParam, which is a union that includes things like Message:

class Message(TypedDict, total=False):
    content: Required[ResponseInputMessageContentListParam]
    """
    A list of one or many input items to the model, containing different content
    types.
    """

    role: Required[Literal["user", "system", "developer"]]
    """The role of the message input. One of `user`, `system`, or `developer`."""

    status: Literal["in_progress", "completed", "incomplete"]
    """The status of item.

    One of `in_progress`, `completed`, or `incomplete`. Populated when items are
    returned via API.
    """

    type: Literal["message"]
    """The type of the message input. Always set to `message`."""

I don't think those status and type fields end up with the model. But it'd be worth verifying by comparing our calculation with real data from the API, as I suggested below in the tests.

            if isinstance(value, str):
Collaborator:
We also don't currently handle lists of strings properly, for example ChatCompletionMessageParam can be ChatCompletionUserMessageParam:

class ChatCompletionUserMessageParam(TypedDict, total=False):
    content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]]
    """The contents of the user message."""

    role: Required[Literal["user"]]
    """The role of the messages author, in this case `user`."""

    name: str
    """An optional name for the participant.

    Provides the model information to differentiate between participants of the same
    role.
    """

content may just be a str, but could also be a list of ChatCompletionContentPartTextParam:

class ChatCompletionContentPartTextParam(TypedDict, total=False):
    text: Required[str]
    """The text content."""

    type: Required[Literal["text"]]
    """The type of the content part."""

We shouldn't exclude that text from the count.

Same for ResponseInputItemParam, which can have text hidden inside lists.

Unfortunately OpenAI makes it very hard for us to calculate this stuff correctly, but I'd rather have no count_tokens method than one that only works in very specific unrealistic scenarios -- most users are going to have more complicated message histories than the ones we're currently accounting for. So I think we should either implement some smarter behavior (and verify in the tests that it works!), or "give up". Let me know if you're up for the challenge :)

Author:
Yes, I'm up for trying to count the tokens for more complicated histories. It seems there are quite a few possible inputs based on your comment.

Are there any test cases representing a more complicated structure that I can use as a starting point?

Collaborator:
@wirthual Not specifically, but if you look at the types in the ModelRequest.parts, UserPromptPart.content and ModelResponse.parts type unions, it's pretty easy to (have AI) build one of each

Collaborator:
@wirthual Just found https://github.com/pamelafox/openai-messages-token-helper which may be worth using or looking at for inspiration
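To illustrate the two concerns raised above (fields like status and type being counted, and text hidden inside lists of content parts), here is a minimal sketch of a structure-aware walk. The helper names and the set of keys inspected are assumptions, not the PR's implementation, and would need verifying against real API usage numbers:

```python
from collections.abc import Iterator
from typing import Any

import tiktoken


def _iter_texts(value: object) -> Iterator[str]:
    """Yield only the strings that plausibly reach the model, descending into nested content parts."""
    if isinstance(value, str):
        yield value
    elif isinstance(value, dict):
        # Assumed set of text-bearing keys; notably skips fields like `status` and `type`.
        for key in ('role', 'name', 'content', 'text', 'refusal', 'arguments', 'tool_calls', 'function'):
            if key in value:
                yield from _iter_texts(value[key])
    elif isinstance(value, (list, tuple)):
        for item in value:
            yield from _iter_texts(item)


def _num_tokens_from_message(message: dict[str, Any], encoding: tiktoken.Encoding) -> int:
    """Count tokens for one message param, including text nested in content-part lists."""
    return sum(len(encoding.encode(text)) for text in _iter_texts(message))
```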

                num_tokens += len(encoding.encode(value))
Collaborator:
Since this (or the get_encoding call further up?) could download a large file, and tiktoken is sync rather than async, we should wrap the call that may do a download in _utils.run_in_executor so it runs in a thread.

Author:
The methods which download the encoding file are tiktoken.encoding_for_model and tiktoken.get_encoding. So they would be wrapped with _utils.run_in_executor and then awaited? And _num_tokens_from_messages would become async?
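One way that could look, as a minimal sketch only: asyncio.to_thread is used here as a stand-in for the repo's _utils.run_in_executor (whose exact signature isn't shown in this diff). count_tokens would then await this helper and pass the resolved encoding into the counting code, or _num_tokens_from_messages itself would become async:

```python
import asyncio

import tiktoken


async def _get_encoding(model: str) -> tiktoken.Encoding:
    """Resolve the encoding off the event loop; the first lookup may download a BPE file."""

    def _lookup() -> tiktoken.Encoding:
        try:
            return tiktoken.encoding_for_model(model)
        except KeyError:
            return tiktoken.get_encoding('o200k_base')

    # Stand-in for _utils.run_in_executor: run the blocking lookup in a worker thread.
    return await asyncio.to_thread(_lookup)
```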

    num_tokens += final_primer
    return num_tokens
13 changes: 11 additions & 2 deletions pydantic_ai_slim/pydantic_ai/usage.py
@@ -267,8 +267,17 @@ class UsageLimits:
"""The maximum number of tokens allowed in requests and responses combined."""
count_tokens_before_request: bool = False
"""If True, perform a token counting pass before sending the request to the model,
to enforce `request_tokens_limit` ahead of time. This may incur additional overhead
(from calling the model's `count_tokens` API before making the actual request) and is disabled by default."""
to enforce `input_tokens_limit` ahead of time. This may incur additional overhead
(from calling the model's `count_tokens` method before making the actual request) and is disabled by default.

Supported by:

- [`OpenAIChatModel`][pydantic_ai.models.openai.OpenAIChatModel] and
[`OpenAIResponsesModel`][pydantic_ai.models.openai.OpenAIResponsesModel] (only for OpenAI models)
- [`AnthropicModel`][pydantic_ai.models.anthropic.AnthropicModel] (excluding Bedrock client)
- [`GoogleModel`][pydantic_ai.models.google.GoogleModel]
- [`BedrockModel`][pydantic_ai.models.bedrock.BedrockModel] (including Anthropic models)
"""

@property
@deprecated('`request_tokens_limit` is deprecated, use `input_tokens_limit` instead')
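For reference, a short sketch of how the new setting would be used from an agent run; the model string and result attribute follow current pydantic_ai conventions and are illustrative rather than taken from this PR:

```python
from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('openai:gpt-4o')

# With count_tokens_before_request=True, the model's count_tokens method runs first,
# so a prompt that already exceeds input_tokens_limit is rejected before any request is sent.
result = agent.run_sync(
    'Summarise the conversation so far.',
    usage_limits=UsageLimits(input_tokens_limit=4096, count_tokens_before_request=True),
)
print(result.output)
```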
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
@@ -67,7 +67,7 @@ dependencies = [
# WARNING if you add optional groups, please update docs/install.md
logfire = ["logfire[httpx]>=3.14.1"]
# Models
openai = ["openai>=1.107.2"]
openai = ["openai>=1.107.2","tiktoken>=0.12.0"]
cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
google = ["google-genai>=1.51.0"]