From de878a26cb5cf5f444f42bcb8dcf7a13db3e7968 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 23 Dec 2025 23:16:02 +0900 Subject: [PATCH] feat: Add tool guardrails to function_tool decorator args (ref #2218) --- src/agents/tool.py | 10 +++++++++ tests/test_function_tool.py | 43 +++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/src/agents/tool.py b/src/agents/tool.py index 8c8d3e9880..86c9a6bc9e 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -687,6 +687,8 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None, + tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -702,6 +704,8 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None, + tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -717,6 +721,8 @@ def function_tool( failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None, + tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -748,6 +754,8 @@ def function_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + tool_input_guardrails: Optional list of guardrails to run before invoking the tool. + tool_output_guardrails: Optional list of guardrails to run after the tool returns. """ def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -845,6 +853,8 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any: on_invoke_tool=_on_invoke_tool, strict_json_schema=strict_mode, is_enabled=is_enabled, + tool_input_guardrails=tool_input_guardrails, + tool_output_guardrails=tool_output_guardrails, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/test_function_tool.py b/tests/test_function_tool.py index 18107773d8..3597f48c38 100644 --- a/tests/test_function_tool.py +++ b/tests/test_function_tool.py @@ -11,7 +11,12 @@ FunctionTool, ModelBehaviorError, RunContextWrapper, + ToolGuardrailFunctionOutput, + ToolInputGuardrailData, + ToolOutputGuardrailData, function_tool, + tool_input_guardrail, + tool_output_guardrail, ) from agents.tool import default_tool_error_function from agents.tool_context import ToolContext @@ -96,6 +101,21 @@ def complex_args_function(foo: Foo, bar: Bar, baz: str = "hello"): return f"{foo.a + foo.b} {bar['x']}{bar['y']} {baz}" +@tool_input_guardrail +def reject_args_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput: + """Reject tool calls for test purposes.""" + return ToolGuardrailFunctionOutput.reject_content( + message="blocked", + output_info={"tool": data.context.tool_name}, + ) + + +@tool_output_guardrail +def allow_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput: + """Allow tool outputs for test purposes.""" + return ToolGuardrailFunctionOutput.allow(output_info={"echo": data.output}) + + @pytest.mark.asyncio async def test_complex_args_function(): tool = function_tool(complex_args_function, failure_error_function=None) @@ -359,3 +379,26 @@ def boom() -> None: ctx = ToolContext(None, tool_name=boom.name, tool_call_id="boom", tool_arguments="{}") result = await boom.on_invoke_tool(ctx, "{}") assert result.startswith("handled:") + + +def test_function_tool_accepts_guardrail_arguments(): + tool = function_tool( + simple_function, + tool_input_guardrails=[reject_args_guardrail], + tool_output_guardrails=[allow_output_guardrail], + ) + + assert tool.tool_input_guardrails == [reject_args_guardrail] + assert tool.tool_output_guardrails == [allow_output_guardrail] + + +def test_function_tool_decorator_accepts_guardrail_arguments(): + @function_tool( + tool_input_guardrails=[reject_args_guardrail], + tool_output_guardrails=[allow_output_guardrail], + ) + def guarded(a: int) -> int: + return a + + assert guarded.tool_input_guardrails == [reject_args_guardrail] + assert guarded.tool_output_guardrails == [allow_output_guardrail]