47 changes: 47 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_utils.py
@@ -14,6 +14,7 @@
from types import GenericAlias
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload

import tiktoken
from anyio.to_thread import run_sync
from pydantic import BaseModel, TypeAdapter
from pydantic.json_schema import JsonSchemaValue
Expand All @@ -32,10 +33,14 @@
AbstractSpan = AbstractSpan

if TYPE_CHECKING:
    from openai.types.chat import ChatCompletionMessageParam
    from openai.types.responses.response_input_item_param import ResponseInputItemParam

    from pydantic_ai.agent import AgentRun, AgentRunResult
    from pydantic_graph import GraphRun, GraphRunResult

    from . import messages as _messages
    from .models.openai import OpenAIModelName
    from .tools import ObjectJsonSchema

_P = ParamSpec('_P')
@@ -507,3 +512,45 @@ def get_event_loop():
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
    return event_loop


def num_tokens_from_messages(
Review comment (Collaborator): This is OpenAI specific so it should live in models/openai.py
    messages: list[ChatCompletionMessageParam] | list[ResponseInputItemParam],
    model: OpenAIModelName = 'gpt-4o-mini-2024-07-18',
Review comment (Collaborator): We don't need a default value
) -> int:
    """Return the number of tokens used by a list of messages."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        print('Warning: model not found. Using o200k_base encoding.')  # TODO: How to handle warnings?
Review comment (Collaborator): No warnings please, let's just make a best effort
        encoding = tiktoken.get_encoding('o200k_base')
    if model in {
        'gpt-3.5-turbo-0125',
        'gpt-4-0314',
        'gpt-4-32k-0314',
        'gpt-4-0613',
        'gpt-4-32k-0613',
        'gpt-4o-mini-2024-07-18',
        'gpt-4o-2024-08-06',
    }:
        tokens_per_message = 3
    elif 'gpt-3.5-turbo' in model:
        return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
    elif 'gpt-4o-mini' in model:
        return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
    elif 'gpt-4o' in model:
        return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
    elif 'gpt-4' in model:
        return num_tokens_from_messages(messages, model='gpt-4-0613')
    else:
        raise NotImplementedError(
            f'num_tokens_from_messages() is not implemented for model {model}.'
        )  # TODO: How to handle other models?
Review comment (Collaborator): Are you able to reverse engineer the right formula for gpt-5?

As long as we document that this is a best-effort calculation and may not be accurate down to the exact token, we can have one branch of logic for "everything before gpt-5" and one for everything newer. If future models have different rules, we can update the logic then.

Reply (Author): I think that with a decreased primer, the calculation for gpt-5 is more accurate.

Should the method from the cookbook be the default for all other models?

    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for value in message.values():
            if isinstance(value, str):
                num_tokens += len(encoding.encode(value))
    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
    return num_tokens
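
A hedged sketch of the two-branch logic discussed in the review thread above: the constants for pre-gpt-5 models are the OpenAI cookbook values this function already uses, while the smaller reply primer for gpt-5 reflects the author's "decreased primer" observation and is an assumption, not a verified constant.

def _message_token_constants(model: str) -> tuple[int, int]:
    """Return (tokens_per_message, reply_primer) for a model family.

    Hypothetical helper, not part of this PR; the gpt-5 values are assumed.
    """
    if 'gpt-5' in model:
        return 3, 1  # assumption: gpt-5-era models appear to use a smaller reply primer
    return 3, 3  # cookbook constants for gpt-4o and earlier chat models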
37 changes: 36 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -17,7 +17,12 @@
from .._output import DEFAULT_OUTPUT_TOOL_NAME, OutputObjectDefinition
from .._run_context import RunContext
from .._thinking_part import split_content_into_text_and_thinking
-from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
+from .._utils import (
+    guard_tool_call_id as _guard_tool_call_id,
+    now_utc as _now_utc,
+    num_tokens_from_messages,
+    number_to_datetime,
+)
from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
from ..exceptions import UserError
from ..messages import (
@@ -907,6 +912,20 @@ def _inline_text_file_part(text: str, *, media_type: str, identifier: str) -> Ch
        )
        return ChatCompletionContentPartTextParam(text=text, type='text')

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.RequestUsage:
        """Count the tokens for a request locally with tiktoken, without calling the API."""
        openai_messages = await self._map_messages(messages, model_request_parameters)
        token_count = num_tokens_from_messages(openai_messages, self.model_name)

        return usage.RequestUsage(
            input_tokens=token_count,
        )
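
A minimal usage sketch for the new override, assuming `ModelRequestParameters()` can be constructed with its defaults and that an OpenAI API key is configured for provider construction (no request is actually sent); the model name and prompt are illustrative only.

import asyncio

from pydantic_ai.messages import ModelRequest
from pydantic_ai.models import ModelRequestParameters
from pydantic_ai.models.openai import OpenAIChatModel


async def main() -> None:
    model = OpenAIChatModel('gpt-4o-mini-2024-07-18')
    messages = [ModelRequest.user_text_prompt('Hello, world!')]
    # Counts tokens locally via tiktoken; no tokens are billed.
    request_usage = await model.count_tokens(messages, None, ModelRequestParameters())
    print(request_usage.input_tokens)


asyncio.run(main())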


@deprecated(
    '`OpenAIModel` was renamed to `OpenAIChatModel` to clearly distinguish it from `OpenAIResponsesModel` which '
@@ -1701,6 +1720,22 @@ async def _map_user_prompt(part: UserPromptPart) -> responses.EasyInputMessagePa
                assert_never(item)
        return responses.EasyInputMessageParam(role='user', content=content)

    async def count_tokens(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> usage.RequestUsage:
        """Count the tokens for a request locally with tiktoken, without calling the API."""
        _, openai_messages = await self._map_messages(
            messages, cast(OpenAIResponsesModelSettings, model_settings or {}), model_request_parameters
        )
        token_count = num_tokens_from_messages(openai_messages, self.model_name)

        return usage.RequestUsage(
            input_tokens=token_count,
        )


@dataclass
class OpenAIStreamedResponse(StreamedResponse):
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
@@ -67,7 +67,7 @@ dependencies = [
# WARNING if you add optional groups, please update docs/install.md
logfire = ["logfire[httpx]>=3.14.1"]
# Models
-openai = ["openai>=1.107.2"]
+openai = ["openai>=1.107.2","tiktoken>=0.12.0"]
cohere = ["cohere>=5.18.0; platform_system != 'Emscripten'"]
vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"]
google = ["google-genai>=1.50.1"]
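
Since tiktoken now ships with the openai extra, a quick sanity check that the new dependency resolves; the encoding name matches the fallback used by num_tokens_from_messages above.

# After `pip install 'pydantic-ai-slim[openai]'`, tiktoken should import cleanly.
import tiktoken

encoding = tiktoken.get_encoding('o200k_base')  # same fallback encoding as above
print(len(encoding.encode('Hello, world!')))  # prints a small token count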