diff --git a/examples/basic/non_strict_output_type.py b/examples/basic/non_strict_output_type.py new file mode 100644 index 000000000..49fcc4e2c --- /dev/null +++ b/examples/basic/non_strict_output_type.py @@ -0,0 +1,81 @@ +import asyncio +import json +from dataclasses import dataclass +from typing import Any + +from agents import Agent, AgentOutputSchema, AgentOutputSchemaBase, Runner + +"""This example demonstrates how to use an output type that is not in strict mode. Strict mode +allows us to guarantee valid JSON output, but some schemas are not strict-compatible. + +In this example, we define an output type that is not strict-compatible, and then we run the +agent with strict_json_schema=False. + +We also demonstrate a custom output type. + +To understand which schemas are strict-compatible, see: +https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas +""" + + +@dataclass +class OutputType: + jokes: dict[int, str] + """A list of jokes, indexed by joke number.""" + + +class CustomOutputSchema(AgentOutputSchemaBase): + """A demonstration of a custom output schema.""" + + def is_plain_text(self) -> bool: + return False + + def name(self) -> str: + return "CustomOutputSchema" + + def json_schema(self) -> dict[str, Any]: + return { + "type": "object", + "properties": {"jokes": {"type": "object", "properties": {"joke": {"type": "string"}}}}, + } + + def is_strict_json_schema(self) -> bool: + return False + + def validate_json(self, json_str: str) -> Any: + json_obj = json.loads(json_str) + # Just for demonstration, we'll return a list. + return list(json_obj["jokes"].values()) + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You are a helpful assistant.", + output_type=OutputType, + ) + + input = "Tell me 3 short jokes." + + # First, let's try with a strict output type. This should raise an exception. + try: + result = await Runner.run(agent, input) + raise AssertionError("Should have raised an exception") + except Exception as e: + print(f"Error (expected): {e}") + + # Now let's try again with a non-strict output type. This should work. + # In some cases, it will raise an error - the schema isn't strict, so the model may + # produce an invalid JSON object. + agent.output_type = AgentOutputSchema(OutputType, strict_json_schema=False) + result = await Runner.run(agent, input) + print(result.final_output) + + # Finally, let's try a custom output type. + agent.output_type = CustomOutputSchema() + result = await Runner.run(agent, input) + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index db7d312e4..6d7c90b4f 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -6,7 +6,7 @@ from . 
import _config from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .computer import AsyncComputer, Button, Computer, Environment from .exceptions import ( AgentsException, @@ -158,6 +158,7 @@ def enable_verbose_stdout_logging(): "OpenAIProvider", "OpenAIResponsesModel", "AgentOutputSchema", + "AgentOutputSchemaBase", "Computer", "AsyncComputer", "Environment", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 94c181b7f..b5a83685c 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -29,7 +29,7 @@ from openai.types.responses.response_reasoning_item import ResponseReasoningItem from .agent import Agent, ToolsToFinalOutputResult -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchemaBase from .computer import AsyncComputer, Computer from .exceptions import AgentsException, ModelBehaviorError, UserError from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult @@ -195,7 +195,7 @@ async def execute_tools_and_side_effects( pre_step_items: list[RunItem], new_response: ModelResponse, processed_response: ProcessedResponse, - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], run_config: RunConfig, @@ -335,7 +335,7 @@ def process_model_response( agent: Agent[Any], all_tools: list[Tool], response: ModelResponse, - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], ) -> ProcessedResponse: items: list[RunItem] = [] diff --git a/src/agents/agent.py b/src/agents/agent.py index a24456b06..e22f579fa 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -8,6 +8,7 @@ from typing_extensions import NotRequired, TypeAlias, TypedDict +from .agent_output import AgentOutputSchemaBase from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers @@ -141,8 +142,14 @@ class Agent(Generic[TContext]): Runs only if the agent produces a final output. """ - output_type: type[Any] | None = None - """The type of the output object. If not provided, the output will be `str`.""" + output_type: type[Any] | AgentOutputSchemaBase | None = None + """The type of the output object. If not provided, the output will be `str`. In most cases, + you should pass a regular Python type (e.g. a dataclass, Pydantic model, TypedDict, etc). + You can customize this in two ways: + 1. If you want non-strict schemas, pass `AgentOutputSchema(MyClass, strict_json_schema=False)`. + 2. If you want to use a custom JSON schema (i.e. without using the SDK's automatic schema + creation), subclass `AgentOutputSchemaBase` and pass an instance. + """ hooks: AgentHooks[TContext] | None = None """A class that receives callbacks on various lifecycle events for this agent. diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 3262c57d6..066bbd835 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -1,3 +1,4 @@ +import abc from dataclasses import dataclass from typing import Any @@ -12,8 +13,46 @@ _WRAPPER_DICT_KEY = "response" +class AgentOutputSchemaBase(abc.ABC): + """An object that captures the JSON schema of the output, as well as validating/parsing JSON + produced by the LLM into the output type. 
+ """ + + @abc.abstractmethod + def is_plain_text(self) -> bool: + """Whether the output type is plain text (versus a JSON object).""" + pass + + @abc.abstractmethod + def name(self) -> str: + """The name of the output type.""" + pass + + @abc.abstractmethod + def json_schema(self) -> dict[str, Any]: + """Returns the JSON schema of the output. Will only be called if the output type is not + plain text. + """ + pass + + @abc.abstractmethod + def is_strict_json_schema(self) -> bool: + """Whether the JSON schema is in strict mode. Strict mode constrains the JSON schema + features, but guarantees valis JSON. See here for details: + https://platform.openai.com/docs/guides/structured-outputs#supported-schemas + """ + pass + + @abc.abstractmethod + def validate_json(self, json_str: str) -> Any: + """Validate a JSON string against the output type. You must return the validated object, + or raise a `ModelBehaviorError` if the JSON is invalid. + """ + pass + + @dataclass(init=False) -class AgentOutputSchema: +class AgentOutputSchema(AgentOutputSchemaBase): """An object that captures the JSON schema of the output, as well as validating/parsing JSON produced by the LLM into the output type. """ @@ -32,7 +71,7 @@ class AgentOutputSchema: _output_schema: dict[str, Any] """The JSON schema of the output.""" - strict_json_schema: bool + _strict_json_schema: bool """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, as it increases the likelihood of correct JSON input. """ @@ -45,7 +84,7 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True): setting this to True, as it increases the likelihood of correct JSON input. """ self.output_type = output_type - self.strict_json_schema = strict_json_schema + self._strict_json_schema = strict_json_schema if output_type is None or output_type is str: self._is_wrapped = False @@ -70,24 +109,35 @@ def __init__(self, output_type: type[Any], strict_json_schema: bool = True): self._type_adapter = TypeAdapter(output_type) self._output_schema = self._type_adapter.json_schema() - if self.strict_json_schema: - self._output_schema = ensure_strict_json_schema(self._output_schema) + if self._strict_json_schema: + try: + self._output_schema = ensure_strict_json_schema(self._output_schema) + except UserError as e: + raise UserError( + "Strict JSON schema is enabled, but the output type is not valid. " + "Either make the output type strict, or pass output_schema_strict=False to " + "your Agent()" + ) from e def is_plain_text(self) -> bool: """Whether the output type is plain text (versus a JSON object).""" return self.output_type is None or self.output_type is str + def is_strict_json_schema(self) -> bool: + """Whether the JSON schema is in strict mode.""" + return self._strict_json_schema + def json_schema(self) -> dict[str, Any]: """The JSON schema of the output type.""" if self.is_plain_text(): raise UserError("Output type is plain text, so no JSON schema is available") return self._output_schema - def validate_json(self, json_str: str, partial: bool = False) -> Any: + def validate_json(self, json_str: str) -> Any: """Validate a JSON string against the output type. Returns the validated object, or raises a `ModelBehaviorError` if the JSON is invalid. 
""" - validated = _json.validate_json(json_str, self._type_adapter, partial) + validated = _json.validate_json(json_str, self._type_adapter, partial=False) if self._is_wrapped: if not isinstance(validated, dict): _error_tracing.attach_error_to_current_span( @@ -113,7 +163,7 @@ def validate_json(self, json_str: str, partial: bool = False) -> Any: return validated[_WRAPPER_DICT_KEY] return validated - def output_type_name(self) -> str: + def name(self) -> str: """The name of the output type.""" return _type_to_str(self.output_type) diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py index 0fc277c35..e939ee8da 100644 --- a/src/agents/extensions/models/litellm_model.py +++ b/src/agents/extensions/models/litellm_model.py @@ -29,7 +29,7 @@ from openai.types.responses import Response from ... import _debug -from ...agent_output import AgentOutputSchema +from ...agent_output import AgentOutputSchemaBase from ...handoffs import Handoff from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ...logger import logger @@ -68,7 +68,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, @@ -139,7 +139,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -186,7 +186,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -200,7 +200,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -213,7 +213,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, diff --git a/src/agents/models/chatcmpl_converter.py b/src/agents/models/chatcmpl_converter.py index 00175a16a..613a37453 100644 --- a/src/agents/models/chatcmpl_converter.py +++ b/src/agents/models/chatcmpl_converter.py @@ -36,7 +36,7 @@ ) from openai.types.responses.response_input_param import FunctionCallOutput, ItemReference, Message -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..exceptions import AgentsException, UserError from ..handoffs import Handoff from ..items import TResponseInputItem, TResponseOutputItem @@ -67,7 +67,7 @@ def convert_tool_choice( @classmethod def convert_response_format( - cls, final_output_schema: AgentOutputSchema | None + cls, final_output_schema: AgentOutputSchemaBase | None ) -> ResponseFormat | NotGiven: if not final_output_schema or final_output_schema.is_plain_text(): return NOT_GIVEN @@ -76,7 +76,7 @@ def convert_response_format( "type": 
"json_schema", "json_schema": { "name": "final_output", - "strict": final_output_schema.strict_json_schema, + "strict": final_output_schema.is_strict_json_schema(), "schema": final_output_schema.json_schema(), }, } diff --git a/src/agents/models/interface.py b/src/agents/models/interface.py index bcf2c1a65..3a79e5640 100644 --- a/src/agents/models/interface.py +++ b/src/agents/models/interface.py @@ -5,7 +5,7 @@ from collections.abc import AsyncIterator from typing import TYPE_CHECKING -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..handoffs import Handoff from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ..tool import Tool @@ -41,7 +41,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -72,7 +72,7 @@ def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 9989c1ee0..9fd102690 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -12,7 +12,7 @@ from openai.types.responses import Response from .. import _debug -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..handoffs import Handoff from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ..logger import logger @@ -49,7 +49,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, @@ -110,7 +110,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -160,7 +160,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -174,7 +174,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, @@ -187,7 +187,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], span: Span[GenerationSpanData], tracing: ModelTracing, diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py index ab4617d46..b751663da 100644 --- a/src/agents/models/openai_responses.py +++ b/src/agents/models/openai_responses.py @@ -18,7 +18,7 @@ ) from .. 
import _debug -from ..agent_output import AgentOutputSchema +from ..agent_output import AgentOutputSchemaBase from ..exceptions import UserError from ..handoffs import Handoff from ..items import ItemHelpers, ModelResponse, TResponseInputItem @@ -66,7 +66,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, @@ -131,7 +131,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, previous_response_id: str | None, @@ -182,7 +182,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, stream: Literal[True], @@ -195,7 +195,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, stream: Literal[False], @@ -207,7 +207,7 @@ async def _fetch_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], previous_response_id: str | None, stream: Literal[True] | Literal[False] = False, @@ -307,7 +307,7 @@ def convert_tool_choice( @classmethod def get_response_format( - cls, output_schema: AgentOutputSchema | None + cls, output_schema: AgentOutputSchemaBase | None ) -> ResponseTextConfigParam | NotGiven: if output_schema is None or output_schema.is_plain_text(): return NOT_GIVEN @@ -317,7 +317,7 @@ def get_response_format( "type": "json_schema", "name": "final_output", "schema": output_schema.json_schema(), - "strict": output_schema.strict_json_schema, + "strict": output_schema.is_strict_json_schema(), } } diff --git a/src/agents/result.py b/src/agents/result.py index a2a6cc4a5..2996eaf91 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -10,7 +10,7 @@ from ._run_impl import QueueCompleteSentinel from .agent import Agent -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchemaBase from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded from .guardrail import InputGuardrailResult, OutputGuardrailResult from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem @@ -124,7 +124,7 @@ class RunResultStreaming(RunResultBase): final_output: Any """The final output of the agent. 
This is None until the agent has finished running.""" - _current_agent_output_schema: AgentOutputSchema | None = field(repr=False) + _current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False) _trace: Trace | None = field(repr=False) diff --git a/src/agents/run.py b/src/agents/run.py index e2b0dbceb..1fc7b52f1 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -19,7 +19,7 @@ get_model_tracing_impl, ) from .agent import Agent -from .agent_output import AgentOutputSchema +from .agent_output import AgentOutputSchema, AgentOutputSchemaBase from .exceptions import ( AgentsException, InputGuardrailTripwireTriggered, @@ -185,7 +185,7 @@ async def run( if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.output_type_name() + output_type_name = output_schema.name() else: output_type_name = "str" @@ -517,7 +517,7 @@ async def _run_streamed_impl( if current_span is None: handoff_names = [h.agent_name for h in cls._get_handoffs(current_agent)] if output_schema := cls._get_output_schema(current_agent): - output_type_name = output_schema.output_type_name() + output_type_name = output_schema.name() else: output_type_name = "str" @@ -789,7 +789,7 @@ async def _get_single_step_result_from_response( original_input: str | list[TResponseInputItem], pre_step_items: list[RunItem], new_response: ModelResponse, - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], hooks: RunHooks[TContext], context_wrapper: RunContextWrapper[TContext], @@ -900,7 +900,7 @@ async def _get_new_response( agent: Agent[TContext], system_prompt: str | None, input: list[TResponseInputItem], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, all_tools: list[Tool], handoffs: list[Handoff], context_wrapper: RunContextWrapper[TContext], @@ -930,9 +930,11 @@ async def _get_new_response( return new_response @classmethod - def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchema | None: + def _get_output_schema(cls, agent: Agent[Any]) -> AgentOutputSchemaBase | None: if agent.output_type is None or agent.output_type is str: return None + elif isinstance(agent.output_type, AgentOutputSchemaBase): + return agent.output_type return AgentOutputSchema(agent.output_type) diff --git a/tests/fake_model.py b/tests/fake_model.py index 52d3a3b2b..c6b3ba924 100644 --- a/tests/fake_model.py +++ b/tests/fake_model.py @@ -5,7 +5,7 @@ from openai.types.responses import Response, ResponseCompletedEvent -from agents.agent_output import AgentOutputSchema +from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff from agents.items import ( ModelResponse, @@ -51,7 +51,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -93,7 +93,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py index 44339dad3..f79c0cf8a 100644 --- a/tests/test_agent_config.py +++ b/tests/test_agent_config.py @@ 
-1,7 +1,7 @@ import pytest from pydantic import BaseModel -from agents import Agent, Handoff, RunContextWrapper, Runner, handoff +from agents import Agent, AgentOutputSchema, Handoff, RunContextWrapper, Runner, handoff @pytest.mark.asyncio @@ -160,8 +160,9 @@ async def test_agent_final_output(): ) schema = Runner._get_output_schema(agent) + assert isinstance(schema, AgentOutputSchema) assert schema is not None assert schema.output_type == Foo - assert schema.strict_json_schema is True + assert schema.is_strict_json_schema() is True assert schema.json_schema() is not None assert not schema.is_plain_text() diff --git a/tests/test_openai_chatcompletions_converter.py b/tests/test_openai_chatcompletions_converter.py index e3a18b255..bcfca5495 100644 --- a/tests/test_openai_chatcompletions_converter.py +++ b/tests/test_openai_chatcompletions_converter.py @@ -232,7 +232,7 @@ def test_convert_response_format_returns_not_given_for_plain_text_and_dict_for_s assert resp_format["type"] == "json_schema" assert resp_format["json_schema"]["name"] == "final_output" assert "strict" in resp_format["json_schema"] - assert resp_format["json_schema"]["strict"] == schema.strict_json_schema + assert resp_format["json_schema"]["strict"] == schema.is_strict_json_schema() assert "schema" in resp_format["json_schema"] assert resp_format["json_schema"]["schema"] == schema.json_schema() diff --git a/tests/test_openai_responses_converter.py b/tests/test_openai_responses_converter.py index 34cbac5c5..8e4866656 100644 --- a/tests/test_openai_responses_converter.py +++ b/tests/test_openai_responses_converter.py @@ -92,7 +92,7 @@ class OutModel(BaseModel): assert inner.get("name") == "final_output" assert isinstance(inner.get("schema"), dict) # Should include a strict flag matching the schema's strictness setting. 
- assert inner.get("strict") == out_schema.strict_json_schema + assert inner.get("strict") == out_schema.is_strict_json_schema() def test_convert_tools_basic_types_and_includes(): diff --git a/tests/test_output_tool.py b/tests/test_output_tool.py index 86c4b3b5a..37c1b1b67 100644 --- a/tests/test_output_tool.py +++ b/tests/test_output_tool.py @@ -1,10 +1,18 @@ import json +from typing import Any import pytest from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError +from agents import ( + Agent, + AgentOutputSchema, + AgentOutputSchemaBase, + ModelBehaviorError, + Runner, + UserError, +) from agents.agent_output import _WRAPPER_DICT_KEY from agents.util import _json @@ -27,6 +35,7 @@ def test_structured_output_pydantic(): output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, AgentOutputSchema) assert output_schema.output_type == Foo, "Should have the correct output type" assert not output_schema._is_wrapped, "Pydantic objects should not be wrapped" for key, value in Foo.model_json_schema().items(): @@ -45,6 +54,7 @@ def test_structured_output_typed_dict(): agent = Agent(name="test", output_type=Bar) output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, AgentOutputSchema) assert output_schema.output_type == Bar, "Should have the correct output type" assert not output_schema._is_wrapped, "TypedDicts should not be wrapped" @@ -57,6 +67,7 @@ def test_structured_output_list(): agent = Agent(name="test", output_type=list[str]) output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, AgentOutputSchema) assert output_schema.output_type == list[str], "Should have the correct output type" assert output_schema._is_wrapped, "Lists should be wrapped" @@ -98,7 +109,7 @@ def test_plain_text_obj_doesnt_produce_schema(): def test_structured_output_is_strict(): output_wrapper = AgentOutputSchema(output_type=Foo) - assert output_wrapper.strict_json_schema + assert output_wrapper.is_strict_json_schema() for key, value in Foo.model_json_schema().items(): assert output_wrapper.json_schema()[key] == value @@ -110,6 +121,48 @@ def test_structured_output_is_strict(): def test_setting_strict_false_works(): output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False) - assert not output_wrapper.strict_json_schema + assert not output_wrapper.is_strict_json_schema() assert output_wrapper.json_schema() == Foo.model_json_schema() assert output_wrapper.json_schema() == Foo.model_json_schema() + + +_CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA = { + "type": "object", + "properties": { + "foo": {"type": "string"}, + }, + "required": ["foo"], +} + + +class CustomOutputSchema(AgentOutputSchemaBase): + def is_plain_text(self) -> bool: + return False + + def name(self) -> str: + return "FooBarBaz" + + def json_schema(self) -> dict[str, Any]: + return _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA + + def is_strict_json_schema(self) -> bool: + return False + + def validate_json(self, json_str: str) -> Any: + return ["some", "output"] + + +def test_custom_output_schema(): + custom_output_schema = CustomOutputSchema() + agent = Agent(name="test", output_type=custom_output_schema) + output_schema = 
Runner._get_output_schema(agent) + + assert output_schema, "Should have an output tool config with a structured output type" + assert isinstance(output_schema, CustomOutputSchema) + assert output_schema.json_schema() == _CUSTOM_OUTPUT_SCHEMA_JSON_SCHEMA + assert not output_schema.is_strict_json_schema() + assert not output_schema.is_plain_text() + + json_str = json.dumps({"foo": "bar"}) + validated = output_schema.validate_json(json_str) + assert validated == ["some", "output"] diff --git a/tests/voice/test_workflow.py b/tests/voice/test_workflow.py index 72a3370d3..2bdf2a657 100644 --- a/tests/voice/test_workflow.py +++ b/tests/voice/test_workflow.py @@ -9,7 +9,7 @@ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent from agents import Agent, Model, ModelSettings, ModelTracing, Tool -from agents.agent_output import AgentOutputSchema +from agents.agent_output import AgentOutputSchemaBase from agents.handoffs import Handoff from agents.items import ( ModelResponse, @@ -48,7 +48,7 @@ async def get_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *, @@ -62,7 +62,7 @@ async def stream_response( input: str | list[TResponseInputItem], model_settings: ModelSettings, tools: list[Tool], - output_schema: AgentOutputSchema | None, + output_schema: AgentOutputSchemaBase | None, handoffs: list[Handoff], tracing: ModelTracing, *,
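
A small, self-contained sketch of how the new `AgentOutputSchemaBase` interface is consumed on the provider side, since every `Model.get_response`/`stream_response` signature above now accepts `AgentOutputSchemaBase | None`. This is illustrative only and not part of the diff: `build_response_format` is a hypothetical helper, and the dict it returns simply mirrors the shape built by `Converter.convert_response_format` and `Converter.get_response_format` in the diff.

from __future__ import annotations

from typing import Any

from agents import AgentOutputSchemaBase  # exported from the package root by this diff


def build_response_format(output_schema: AgentOutputSchemaBase | None) -> dict[str, Any] | None:
    """Hypothetical helper: turn an output schema into a JSON-schema response format dict."""
    # Plain-text output (output_type of None or str) needs no response format at all.
    if output_schema is None or output_schema.is_plain_text():
        return None
    # Everything else goes through the abstract methods, so the built-in AgentOutputSchema
    # and user-defined AgentOutputSchemaBase subclasses share a single code path.
    return {
        "type": "json_schema",
        "json_schema": {
            "name": "final_output",
            "strict": output_schema.is_strict_json_schema(),
            "schema": output_schema.json_schema(),
        },
    }

Note that callers are expected to use the `is_strict_json_schema()` accessor rather than the old `strict_json_schema` attribute, which this diff renames to the private `_strict_json_schema`; the converter and test updates above reflect the same change.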