2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ Repository = "https://github.com/openai/openai-agents-python"
[project.optional-dependencies]
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
viz = ["graphviz>=0.17"]
litellm = ["litellm>=1.67.4.post1, <2"]
litellm = ["litellm>=1.80.8, <2"]
realtime = ["websockets>=15.0, <16"]
sqlalchemy = ["SQLAlchemy>=2.0", "asyncpg>=0.29.0"]
encrypt = ["cryptography>=45.0, <46"]
139 changes: 131 additions & 8 deletions src/agents/extensions/models/litellm_model.py
@@ -62,6 +62,15 @@ class InternalChatCompletionMessage(ChatCompletionMessage):
thinking_blocks: list[dict[str, Any]] | None = None


class InternalToolCall(ChatCompletionMessageFunctionToolCall):
"""
An internal subclass to carry provider-specific metadata (e.g., Gemini thought signatures)
without modifying the original model.
"""

extra_content: dict[str, Any] | None = None


class LitellmModel(Model):
"""This class enables using any model via LiteLLM. LiteLLM allows you to acess OpenAPI,
Anthropic, Gemini, Mistral, and many other models.
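
For context, a minimal sketch (not part of the diff) of how InternalToolCall piggybacks provider metadata on a standard OpenAI tool call; the id, arguments, and signature values here are invented:

from openai.types.chat import ChatCompletionMessageFunctionToolCall

# Hypothetical values; extra_content rides alongside the standard fields.
call = InternalToolCall(
    id="call_123",
    type="function",
    function={"name": "get_weather", "arguments": '{"city": "Paris"}'},
    extra_content={"google": {"thought_signature": "sig-abc"}},
)
assert isinstance(call, ChatCompletionMessageFunctionToolCall)
assert call.model_dump()["extra_content"] == {"google": {"thought_signature": "sig-abc"}}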
@@ -168,9 +177,15 @@ async def get_response(
"output_tokens": usage.output_tokens,
}

# Build provider_data for provider-specific fields
provider_data: dict[str, Any] = {"model": self.model}
if message is not None and hasattr(response, "id"):
provider_data["response_id"] = response.id

items = (
Converter.message_to_output_items(
LitellmConverter.convert_message_to_openai(message)
LitellmConverter.convert_message_to_openai(message, model=self.model),
provider_data=provider_data,
)
if message is not None
else []
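
A rough sketch of the payload built above, with an invented response id:

# provider_data carries the originating model; response_id is added only
# when the litellm response exposes an id.
provider_data = {"model": "gemini/gemini-2.5-pro"}
response_id = "resp_abc123"  # hypothetical
provider_data["response_id"] = response_id
assert provider_data == {"model": "gemini/gemini-2.5-pro", "response_id": "resp_abc123"}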
@@ -215,7 +230,9 @@ async def stream_response(
)

final_response: Response | None = None
async for chunk in ChatCmplStreamHandler.handle_stream(response, stream):
async for chunk in ChatCmplStreamHandler.handle_stream(
response, stream, model=self.model
):
yield chunk

if chunk.type == "response.completed":
@@ -280,13 +297,19 @@ async def _fetch_response(
)

converted_messages = Converter.items_to_messages(
input, preserve_thinking_blocks=preserve_thinking_blocks
input, model=self.model, preserve_thinking_blocks=preserve_thinking_blocks
)

# Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result # noqa: E501
if "anthropic" in self.model.lower() or "claude" in self.model.lower():
converted_messages = self._fix_tool_message_ordering(converted_messages)

# Convert Google's extra_content to litellm's provider_specific_fields format
if "gemini" in self.model.lower():
converted_messages = self._convert_gemini_extra_content_to_provider_specific_fields(
converted_messages
)

if system_instructions:
converted_messages.insert(
0,
@@ -436,6 +459,65 @@ async def _fetch_response(
)
return response, ret

def _convert_gemini_extra_content_to_provider_specific_fields(
self, messages: list[ChatCompletionMessageParam]
) -> list[ChatCompletionMessageParam]:
"""
Convert Gemini model's extra_content format to provider_specific_fields format for litellm.

Transforms tool calls from the internal format:
    extra_content={"google": {"thought_signature": "..."}}
to the litellm format:
    provider_specific_fields={"thought_signature": "..."}

Only processes tool_calls that appear after the last user message.
See: https://ai.google.dev/gemini-api/docs/thought-signatures
"""

# Find the index of the last user message
last_user_index = -1
for i in range(len(messages) - 1, -1, -1):
if isinstance(messages[i], dict) and messages[i].get("role") == "user":
last_user_index = i
break

for i, message in enumerate(messages):
if not isinstance(message, dict):
continue

# Only process assistant messages that come after the last user message
# If no user message found (last_user_index == -1), process all messages
if last_user_index != -1 and i <= last_user_index:
continue

# Check if this is an assistant message with tool calls
if message.get("role") == "assistant" and message.get("tool_calls"):
tool_calls = message.get("tool_calls", [])

for tool_call in tool_calls: # type: ignore[attr-defined]
if not isinstance(tool_call, dict):
continue

# Default to a sentinel that skips litellm's signature validator;
# overridden below if a valid thought signature exists
tool_call["provider_specific_fields"] = {
"thought_signature": "skip_thought_signature_validator"
}

# Override with actual thought signature if extra_content exists
if "extra_content" in tool_call:
extra_content = tool_call.pop("extra_content")
if isinstance(extra_content, dict):
# Extract google-specific fields
google_fields = extra_content.get("google")
if google_fields and isinstance(google_fields, dict):
thought_sig = google_fields.get("thought_signature")
if thought_sig:
tool_call["provider_specific_fields"] = {
"thought_signature": thought_sig
}

return messages
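
To make the transformation concrete, a hypothetical before/after; the ids and signature value are invented, and `model` stands for a LitellmModel instance:

messages = [
    {"role": "user", "content": "What's the weather in Paris?"},
    {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                "extra_content": {"google": {"thought_signature": "sig-abc"}},
            }
        ],
    },
]
converted = model._convert_gemini_extra_content_to_provider_specific_fields(messages)
tool_call = converted[1]["tool_calls"][0]
# extra_content is popped and re-expressed in litellm's flat format.
assert tool_call["provider_specific_fields"] == {"thought_signature": "sig-abc"}
assert "extra_content" not in tool_call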

def _fix_tool_message_ordering(
self, messages: list[ChatCompletionMessageParam]
) -> list[ChatCompletionMessageParam]:
@@ -563,15 +645,26 @@ def _merge_headers(self, model_settings: ModelSettings):
class LitellmConverter:
@classmethod
def convert_message_to_openai(
cls, message: litellm.types.utils.Message
cls, message: litellm.types.utils.Message, model: str | None = None
) -> ChatCompletionMessage:
"""
Convert a LiteLLM message to OpenAI ChatCompletionMessage format.

Args:
message: The LiteLLM message to convert
model: The target model to convert to. Used to handle provider-specific
transformations.
"""
if message.role != "assistant":
raise ModelBehaviorError(f"Unsupported role: {message.role}")

tool_calls: (
list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None
) = (
[LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
[
LitellmConverter.convert_tool_call_to_openai(tool, model=model)
for tool in message.tool_calls
]
if message.tool_calls
else None
)
@@ -641,13 +734,43 @@ def convert_annotations_to_openai(

@classmethod
def convert_tool_call_to_openai(
cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall
cls, tool_call: litellm.types.utils.ChatCompletionMessageToolCall, model: str | None = None
) -> ChatCompletionMessageFunctionToolCall:
return ChatCompletionMessageFunctionToolCall(
id=tool_call.id,
# Strip the "__thought__" suffix litellm appends to tool_call.id for
# Gemini models. See: https://github.com/BerriAI/litellm/pull/16895
# The suffix is redundant since the thought_signature is available from
# provider_specific_fields, and it causes validation errors when tool
# calls are passed to other models.
tool_call_id = tool_call.id
if model and "gemini" in model.lower() and "__thought__" in tool_call_id:
tool_call_id = tool_call_id.split("__thought__")[0]

# Convert litellm's tool call format to chat completion message format
base_tool_call = ChatCompletionMessageFunctionToolCall(
id=tool_call_id,
type="function",
function=Function(
name=tool_call.function.name or "",
arguments=tool_call.function.arguments,
),
)

# Preserve provider-specific fields if present (e.g., Gemini thought signatures)
if hasattr(tool_call, "provider_specific_fields") and tool_call.provider_specific_fields:
# Convert to nested extra_content structure
extra_content: dict[str, Any] = {}
provider_fields = tool_call.provider_specific_fields

# Check for thought_signature (Gemini-specific)
if model and "gemini" in model.lower():
if "thought_signature" in provider_fields:
extra_content["google"] = {
"thought_signature": provider_fields["thought_signature"]
}

return InternalToolCall(
**base_tool_call.model_dump(),
extra_content=extra_content if extra_content else None,
)

return base_tool_call
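
A duck-typed sketch of the Gemini round trip; the id suffix and signature value are invented, and SimpleNamespace stands in for litellm's tool-call object:

from types import SimpleNamespace

raw = SimpleNamespace(
    id="call_1__thought__sig-abc",
    function=SimpleNamespace(name="get_weather", arguments="{}"),
    provider_specific_fields={"thought_signature": "sig-abc"},
)
converted = LitellmConverter.convert_tool_call_to_openai(
    raw,  # type: ignore[arg-type]
    model="gemini/gemini-2.5-pro",
)
assert converted.id == "call_1"  # "__thought__" suffix stripped
assert converted.extra_content == {"google": {"thought_signature": "sig-abc"}}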
2 changes: 1 addition & 1 deletion src/agents/handoffs/history.py
@@ -144,7 +144,7 @@ def _format_transcript_item(item: TResponseInputItem) -> str:
return f"{prefix}: {content_str}" if content_str else prefix

item_type = item.get("type", "item")
rest = {k: v for k, v in item.items() if k != "type"}
rest = {k: v for k, v in item.items() if k not in ("type", "provider_data")}
try:
serialized = json.dumps(rest, ensure_ascii=False, default=str)
except TypeError:
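
A quick sketch of the effect; the item shape is illustrative:

# With the filter above, provider_data is omitted from the serialized
# transcript along with the type key.
item = {
    "type": "function_call",
    "name": "get_weather",
    "provider_data": {"model": "gemini/gemini-2.5-pro", "response_id": "resp_1"},
}
rest = {k: v for k, v in item.items() if k not in ("type", "provider_data")}
assert rest == {"name": "get_weather"}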