Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,6 +687,8 @@ def function_tool(
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None,
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
Expand All @@ -702,6 +704,8 @@ def function_tool(
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None,
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
Expand All @@ -717,6 +721,8 @@ def function_tool(
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True,
tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None,
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
Expand Down Expand Up @@ -748,6 +754,8 @@ def function_tool(
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the tool is enabled. Disabled tools are hidden
from the LLM at runtime.
tool_input_guardrails: Optional list of guardrails to run before invoking the tool.
tool_output_guardrails: Optional list of guardrails to run after the tool returns.
"""

def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
Expand Down Expand Up @@ -845,6 +853,8 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
is_enabled=is_enabled,
tool_input_guardrails=tool_input_guardrails,
tool_output_guardrails=tool_output_guardrails,
)

# If func is actually a callable, we were used as @function_tool with no parentheses
Expand Down
43 changes: 43 additions & 0 deletions tests/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
FunctionTool,
ModelBehaviorError,
RunContextWrapper,
ToolGuardrailFunctionOutput,
ToolInputGuardrailData,
ToolOutputGuardrailData,
function_tool,
tool_input_guardrail,
tool_output_guardrail,
)
from agents.tool import default_tool_error_function
from agents.tool_context import ToolContext
Expand Down Expand Up @@ -96,6 +101,21 @@ def complex_args_function(foo: Foo, bar: Bar, baz: str = "hello"):
return f"{foo.a + foo.b} {bar['x']}{bar['y']} {baz}"


@tool_input_guardrail
def reject_args_guardrail(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOutput:
"""Reject tool calls for test purposes."""
return ToolGuardrailFunctionOutput.reject_content(
message="blocked",
output_info={"tool": data.context.tool_name},
)


@tool_output_guardrail
def allow_output_guardrail(data: ToolOutputGuardrailData) -> ToolGuardrailFunctionOutput:
"""Allow tool outputs for test purposes."""
return ToolGuardrailFunctionOutput.allow(output_info={"echo": data.output})


@pytest.mark.asyncio
async def test_complex_args_function():
tool = function_tool(complex_args_function, failure_error_function=None)
Expand Down Expand Up @@ -359,3 +379,26 @@ def boom() -> None:
ctx = ToolContext(None, tool_name=boom.name, tool_call_id="boom", tool_arguments="{}")
result = await boom.on_invoke_tool(ctx, "{}")
assert result.startswith("handled:")


def test_function_tool_accepts_guardrail_arguments():
tool = function_tool(
simple_function,
tool_input_guardrails=[reject_args_guardrail],
tool_output_guardrails=[allow_output_guardrail],
)

assert tool.tool_input_guardrails == [reject_args_guardrail]
assert tool.tool_output_guardrails == [allow_output_guardrail]


def test_function_tool_decorator_accepts_guardrail_arguments():
@function_tool(
tool_input_guardrails=[reject_args_guardrail],
tool_output_guardrails=[allow_output_guardrail],
)
def guarded(a: int) -> int:
return a

assert guarded.tool_input_guardrails == [reject_args_guardrail]
assert guarded.tool_output_guardrails == [allow_output_guardrail]