Implement OpenAI token counting using tiktoken (#3447)
base: main
Changes from 4 commits
```diff
@@ -14,6 +14,7 @@
 from types import GenericAlias
 from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeGuard, TypeVar, get_args, get_origin, overload

+import tiktoken
 from anyio.to_thread import run_sync
 from pydantic import BaseModel, TypeAdapter
 from pydantic.json_schema import JsonSchemaValue
```
```diff
@@ -32,10 +33,14 @@
 AbstractSpan = AbstractSpan

 if TYPE_CHECKING:
+    from openai.types.chat import ChatCompletionMessageParam
+    from openai.types.responses.response_input_item_param import ResponseInputItemParam
+
     from pydantic_ai.agent import AgentRun, AgentRunResult
     from pydantic_graph import GraphRun, GraphRunResult

     from . import messages as _messages
+    from .models.openai import OpenAIModelName
     from .tools import ObjectJsonSchema

 _P = ParamSpec('_P')
```
```diff
@@ -507,3 +512,45 @@ def get_event_loop():
         event_loop = asyncio.new_event_loop()
         asyncio.set_event_loop(event_loop)
     return event_loop
+
+
+def num_tokens_from_messages(
+    messages: list[ChatCompletionMessageParam] | list[ResponseInputItemParam],
+    model: OpenAIModelName = 'gpt-4o-mini-2024-07-18',
+) -> int:
+    """Return the number of tokens used by a list of messages."""
+    try:
+        encoding = tiktoken.encoding_for_model(model)
+    except KeyError:
+        print('Warning: model not found. Using o200k_base encoding.')  # TODO: How to handle warnings?
+        encoding = tiktoken.get_encoding('o200k_base')
+    if model in {
+        'gpt-3.5-turbo-0125',
+        'gpt-4-0314',
+        'gpt-4-32k-0314',
+        'gpt-4-0613',
+        'gpt-4-32k-0613',
+        'gpt-4o-mini-2024-07-18',
+        'gpt-4o-2024-08-06',
+    }:
+        tokens_per_message = 3
+    elif 'gpt-3.5-turbo' in model:
+        return num_tokens_from_messages(messages, model='gpt-3.5-turbo-0125')
+    elif 'gpt-4o-mini' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18')
+    elif 'gpt-4o' in model:
+        return num_tokens_from_messages(messages, model='gpt-4o-2024-08-06')
+    elif 'gpt-4' in model:
+        return num_tokens_from_messages(messages, model='gpt-4-0613')
+    else:
+        raise NotImplementedError(
+            f"""num_tokens_from_messages() is not implemented for model {model}."""
+        )  # TODO: How to handle other models?
+    num_tokens = 0
+    for message in messages:
+        num_tokens += tokens_per_message
+        for value in message.values():
+            if isinstance(value, str):
+                num_tokens += len(encoding.encode(value))
+    num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
+    return num_tokens
```
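For reviewers trying the branch locally, here is a minimal sketch of how the new helper could be exercised. The import path is illustrative only, since it depends on where the helper finally lands (see the review comment below), and it assumes `tiktoken` is installed:

```python
# Hypothetical import path; the function currently sits in a shared utils
# module, and the review below suggests moving it to models/openai.py.
from pydantic_ai._utils import num_tokens_from_messages

# Messages shaped like openai.types.chat.ChatCompletionMessageParam dicts.
messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'How many tokens is this conversation?'},
]

# The helper counts 3 tokens of per-message overhead, the encoded string
# values of each message, and 3 tokens priming the assistant reply.
print(num_tokens_from_messages(messages, model='gpt-4o-mini-2024-07-18'))
```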
Review comment: This is OpenAI-specific, so it should live in `models/openai.py`.
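If the function moves as suggested, call sites would presumably switch to importing it from the model module instead, along the lines of this hypothetical one-liner:

```python
# Hypothetical import after the suggested move; not part of this PR as written.
from pydantic_ai.models.openai import num_tokens_from_messages
```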