Skip to content

Enable non-strict output types #539

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions examples/basic/non_strict_output_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import asyncio
import json
from dataclasses import dataclass
from typing import Any

from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner

"""This example demonstrates how to use an output type that is not in strict mode. Strict mode
allows us to guarantee valid JSON output, but some schemas are not strict-compatible.

In this example, we define an output type that is not strict-compatible, and then we run the
agent with strict_json_schema=False.

We also demonstrate a custom output type.

To understand which schemas are strict-compatible, see:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
"""


@dataclass
class OutputType:
    jokes: dict[int, str]
    """Jokes indexed by joke number. Note: this type is deliberately not
    strict-JSON-schema compatible (see the module docstring above).
    """


class CustomOutputSchema(AgentOutputSchemaBase):
    """A demonstration of a custom output schema."""

    def is_plain_text(self) -> bool:
        # This schema describes a JSON object, not plain text.
        return False

    def name(self) -> str:
        return "CustomOutputSchema"

    def json_schema(self) -> dict[str, Any]:
        # Hand-written schema: an object holding a "jokes" object.
        jokes_schema: dict[str, Any] = {
            "type": "object",
            "properties": {"joke": {"type": "string"}},
        }
        return {
            "type": "object",
            "properties": {"jokes": jokes_schema},
        }

    def is_strict_json_schema(self) -> bool:
        # The schema above is not strict-mode compliant.
        return False

    def validate_json(self, json_str: str) -> Any:
        parsed = json.loads(json_str)
        # Just for demonstration, return the joke texts as a list.
        return [joke for joke in parsed["jokes"].values()]


async def main():
    """Demonstrate strict, non-strict, and fully custom output schemas.

    Runs the same prompt three ways:
      1. Plain ``OutputType`` (strict mode) -- expected to fail, since the
         type is not strict-compatible.
      2. ``AgentOutputSchema(OutputType, strict_json_schema=False)`` -- should work.
      3. A hand-written ``CustomOutputSchema``.
    """
    agent = Agent(
        name="Assistant",
        instructions="You are a helpful assistant.",
        output_type=OutputType,
    )

    # Named `prompt` rather than `input` to avoid shadowing the builtin.
    prompt = "Tell me 3 short jokes."

    # First, let's try with a strict output type. This should raise an exception.
    try:
        await Runner.run(agent, prompt)
    except Exception as e:
        # Broad catch is deliberate: the example only demonstrates that *some*
        # error is raised for a strict-incompatible schema.
        print(f"Error (expected): {e}")
    else:
        # The AssertionError lives in `else` so it is not swallowed by the
        # `except Exception` above if the run unexpectedly succeeds.
        raise AssertionError("Should have raised an exception")

    # Now let's try again with a non-strict output type. This should work.
    # In some cases, it will raise an error - the schema isn't strict, so the model may
    # produce an invalid JSON object.
    agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False)
    result = await Runner.run(agent, prompt)
    print(result.final_output)

    # Finally, let's try a custom output type.
    agent.output_type = CustomOutputSchema()
    result = await Runner.run(agent, prompt)
    print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main())
3 changes: 2 additions & 1 deletion src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from . import _config
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
AgentsException,
Expand Down Expand Up @@ -158,6 +158,7 @@ def enable_verbose_stdout_logging():
"OpenAIProvider",
"OpenAIResponsesModel",
"AgentOutputSchema",
"AgentOutputSchemaBase",
"Computer",
"AsyncComputer",
"Environment",
Expand Down
6 changes: 3 additions & 3 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from openai.types.responses.response_reasoning_item import ResponseReasoningItem

from .agent import Agent, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .agent_output import AgentOutputSchemaBase
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
Expand Down Expand Up @@ -195,7 +195,7 @@ async def execute_tools_and_side_effects(
pre_step_items: list[RunItem],
new_response: ModelResponse,
processed_response: ProcessedResponse,
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
Expand Down Expand Up @@ -335,7 +335,7 @@ def process_model_response(
agent: Agent[Any],
all_tools: list[Tool],
response: ModelResponse,
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
) -> ProcessedResponse:
items: list[RunItem] = []
Expand Down
11 changes: 9 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from typing_extensions import NotRequired, TypeAlias, TypedDict

from .agent_output import AgentOutputSchemaBase
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
from .items import ItemHelpers
Expand Down Expand Up @@ -141,8 +142,14 @@ class Agent(Generic[TContext]):
Runs only if the agent produces a final output.
"""

output_type: type[Any] | None = None
"""The type of the output object. If not provided, the output will be `str`."""
output_type: type[Any] | AgentOutputSchemaBase | None = None
"""The type of the output object. If not provided, the output will be `str`. In most cases,
you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc).
You can customize this in two ways:
1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`.
2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema
creation), subclass and pass an `AgentOutputSchemaBase` subclass.
"""

hooks: AgentHooks[TContext] | None = None
"""A class that receives callbacks on various lifecycle events for this agent.
Expand Down
66 changes: 58 additions & 8 deletions src/agents/agent_output.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
from dataclasses import dataclass
from typing import Any

Expand All @@ -12,8 +13,46 @@
_WRAPPER_DICT_KEY = "response"


class AgentOutputSchemaBase(abc.ABC):
    """An object that captures the JSON schema of the output, as well as validating/parsing JSON
    produced by the LLM into the output type.
    """

    @abc.abstractmethod
    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        pass

    @abc.abstractmethod
    def name(self) -> str:
        """The name of the output type."""
        pass

    @abc.abstractmethod
    def json_schema(self) -> dict[str, Any]:
        """Returns the JSON schema of the output. Will only be called if the output type is not
        plain text.
        """
        pass

    @abc.abstractmethod
    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema
        features, but guarantees valid JSON. See here for details:
        https://platform.openai.com/docs/guides/structured-outputs#supported-schemas
        """
        pass

    @abc.abstractmethod
    def validate_json(self, json_str: str) -> Any:
        """Validate a JSON string against the output type. You must return the validated object,
        or raise a `ModelBehaviorError` if the JSON is invalid.
        """
        pass


@dataclass(init=False)
class AgentOutputSchema:
class AgentOutputSchema(AgentOutputSchemaBase):
"""An object that captures the JSON schema of the output, as well as validating/parsing JSON
produced by the LLM into the output type.
"""
Expand All @@ -32,7 +71,7 @@ class AgentOutputSchema:
_output_schema: dict[str, Any]
"""The JSON schema of the output."""

strict_json_schema: bool
_strict_json_schema: bool
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input.
"""
Expand All @@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
setting this to True, as it increases the likelihood of correct JSON input.
"""
self.output_type = output_type
self.strict_json_schema = strict_json_schema
self._strict_json_schema = strict_json_schema

if output_type is None or output_type is str:
self._is_wrapped = False
Expand All @@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True):
self._type_adapter = TypeAdapter(output_type)
self._output_schema = self._type_adapter.json_schema()

if self.strict_json_schema:
self._output_schema = ensure_strict_json_schema(self._output_schema)
if self._strict_json_schema:
try:
self._output_schema = ensure_strict_json_schema(self._output_schema)
except UserError as e:
raise UserError(
"Strict JSON schema is enabled, but the output type is not valid. "
"Either make the output type strict, or pass output_schema_strict=False to "
"your Agent()"
) from e

def is_plain_text(self) -> bool:
    """Whether the output type is plain text (versus a JSON object).

    True when no structured output was requested (output_type is None or str).
    """
    return self.output_type is None or self.output_type is str

def is_strict_json_schema(self) -> bool:
    """Whether the JSON schema is in strict mode (set at construction time)."""
    return self._strict_json_schema

def json_schema(self) -> dict[str, Any]:
    """The JSON schema of the output type.

    Raises:
        UserError: If the output type is plain text, so no JSON schema exists.
    """
    if self.is_plain_text():
        raise UserError("Output type is plain text, so no JSON schema is available")
    return self._output_schema

def validate_json(self, json_str: str, partial: bool = False) -> Any:
def validate_json(self, json_str: str) -> Any:
"""Validate a JSON string against the output type. Returns the validated object, or raises
a `ModelBehaviorError` if the JSON is invalid.
"""
validated = _json.validate_json(json_str, self._type_adapter, partial)
validated = _json.validate_json(json_str, self._type_adapter, partial=False)
if self._is_wrapped:
if not isinstance(validated, dict):
_error_tracing.attach_error_to_current_span(
Expand All @@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any:
return validated[_WRAPPER_DICT_KEY]
return validated

def output_type_name(self) -> str:
def name(self) -> str:
"""The name of the output type."""
return _type_to_str(self.output_type)

Expand Down
12 changes: 6 additions & 6 deletions src/agents/extensions/models/litellm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from openai.types.responses import Response

from ... import _debug
from ...agent_output import AgentOutputSchema
from ...agent_output import AgentOutputSchemaBase
from ...handoffs import Handoff
from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
from ...logger import logger
Expand Down Expand Up @@ -68,7 +68,7 @@ async def get_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
previous_response_id: str | None,
Expand Down Expand Up @@ -139,7 +139,7 @@ async def stream_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
Expand Down Expand Up @@ -186,7 +186,7 @@ async def _fetch_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
Expand All @@ -200,7 +200,7 @@ async def _fetch_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
Expand All @@ -213,7 +213,7 @@ async def _fetch_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
span: Span[GenerationSpanData],
tracing: ModelTracing,
Expand Down
6 changes: 3 additions & 3 deletions src/agents/models/chatcmpl_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
)
from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message

from ..agent_output import AgentOutputSchema
from ..agent_output import AgentOutputSchemaBase
from ..exceptions import AgentsException, UserError
from ..handoffs import Handoff
from ..items import TResponseInputItem, TResponseOutputItem
Expand Down Expand Up @@ -67,7 +67,7 @@ def convert_tool_choice(

@classmethod
def convert_response_format(
cls, final_output_schema: AgentOutputSchema | None
cls, final_output_schema: AgentOutputSchemaBase | None
) -> ResponseFormat | NotGiven:
if not final_output_schema or final_output_schema.is_plain_text():
return NOT_GIVEN
Expand All @@ -76,7 +76,7 @@ def convert_response_format(
"type": "json_schema",
"json_schema": {
"name": "final_output",
"strict": final_output_schema.strict_json_schema,
"strict": final_output_schema.is_strict_json_schema(),
"schema": final_output_schema.json_schema(),
},
}
Expand Down
6 changes: 3 additions & 3 deletions src/agents/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

from ..agent_output import AgentOutputSchema
from ..agent_output import AgentOutputSchemaBase
from ..handoffs import Handoff
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
from ..tool import Tool
Expand Down Expand Up @@ -41,7 +41,7 @@ async def get_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
Expand Down Expand Up @@ -72,7 +72,7 @@ def stream_response(
input: str | list[TResponseInputItem],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: AgentOutputSchema | None,
output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff],
tracing: ModelTracing,
*,
Expand Down
Loading