Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ class RunResultStreaming(RunResultBase):
# Store the asyncio tasks that we're waiting on
run_loop_task: asyncio.Task[Any] | None = field(default=None, repr=False)
_input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
_triggered_input_guardrail_result: InputGuardrailResult | None = field(default=None, repr=False)
_output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False)
_stored_exception: Exception | None = field(default=None, repr=False)
_cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False)
Expand Down
1 change: 1 addition & 0 deletions src/agents/run_internal/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ async def run_input_guardrails_with_queue(
for done in asyncio.as_completed(guardrail_tasks):
result = await done
if result.output.tripwire_triggered:
streamed_result._triggered_input_guardrail_result = result
for t in guardrail_tasks:
t.cancel()
await asyncio.gather(*guardrail_tasks, return_exceptions=True)
Expand Down
23 changes: 22 additions & 1 deletion src/agents/run_internal/run_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,9 @@ async def _save_stream_items_without_count(
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
break
except Exception as e:
if current_span and not isinstance(e, ModelBehaviorError):
if current_span and not isinstance(
e, (ModelBehaviorError, InputGuardrailTripwireTriggered)
):
_error_tracing.attach_error_to_span(
current_span,
SpanError(
Expand Down Expand Up @@ -1100,6 +1102,24 @@ async def run_single_turn_streamed(
reasoning_item_id_policy: ReasoningItemIdPolicy | None = None,
) -> SingleStepResult:
"""Run a single streamed turn and emit events as results arrive."""

async def raise_if_input_guardrail_tripwire_known() -> None:
tripwire_result = streamed_result._triggered_input_guardrail_result
if tripwire_result is not None:
raise InputGuardrailTripwireTriggered(tripwire_result)
Comment on lines +1107 to +1109

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

**[P1]** Preserve guardrail results before raising streamed tripwire

`raise_if_input_guardrail_tripwire_known` raises immediately when `_triggered_input_guardrail_result` is set. In the slow-cancel sibling path, `run_input_guardrails_with_queue` has not yet appended that result to `streamed_result.input_guardrail_results`, so the surfaced `InputGuardrailTripwireTriggered.run_data.input_guardrail_results` can be empty. As a result, callers handling the exception lose the context of the guardrail that triggered it.

Useful? React with 👍 / 👎.


task = streamed_result._input_guardrails_task
if task is None or not task.done():
return

guardrail_exception = task.exception()
if guardrail_exception is not None:
raise guardrail_exception

tripwire_result = streamed_result._triggered_input_guardrail_result
if tripwire_result is not None:
raise InputGuardrailTripwireTriggered(tripwire_result)

emitted_tool_call_ids: set[str] = set()
emitted_reasoning_item_ids: set[str] = set()
emitted_tool_search_fingerprints: set[str] = set()
Expand Down Expand Up @@ -1433,6 +1453,7 @@ async def rewind_model_request() -> None:
run_config=run_config,
tool_use_tracker=tool_use_tracker,
event_queue=streamed_result._event_queue,
before_side_effects=raise_if_input_guardrail_tripwire_known,
)

items_to_filter = session_items_for_turn(single_step_result)
Expand Down
4 changes: 4 additions & 0 deletions src/agents/run_internal/turn_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,6 +1696,7 @@ async def get_single_step_result_from_response(
run_config: RunConfig,
tool_use_tracker,
event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None,
before_side_effects: Callable[[], Awaitable[None]] | None = None,
) -> SingleStepResult:
processed_response = process_model_response(
agent=agent,
Expand All @@ -1706,6 +1707,9 @@ async def get_single_step_result_from_response(
existing_items=pre_step_items,
)

if before_side_effects is not None:
await before_side_effects()

tool_use_tracker.record_processed_response(agent, processed_response)

if event_queue is not None and processed_response.new_items:
Expand Down
132 changes: 132 additions & 0 deletions tests/test_guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,138 @@ async def slow_parallel_check(
assert model.first_turn_args is not None, "Model should have been called in parallel mode"


@pytest.mark.asyncio
async def test_parallel_guardrail_trip_before_tool_execution_stops_streaming_turn():
    """A parallel input guardrail tripping mid-stream must abort the turn before any tool runs.

    The guardrail and the fake model are sequenced with events: the guardrail only
    trips once the model call has started, and the model only streams its tool call
    after the guardrail has tripped. The streamed run is expected to raise
    ``InputGuardrailTripwireTriggered`` without ever invoking ``dangerous_tool``.
    """
    model_started = asyncio.Event()
    guardrail_tripped = asyncio.Event()
    tool_was_executed = False

    @function_tool
    def dangerous_tool() -> str:
        nonlocal tool_was_executed
        tool_was_executed = True
        return "tool_executed"

    @input_guardrail(run_in_parallel=True)
    async def tripwire_before_tool_execution(
        ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        # Do not trip until the model request is already in flight.
        await asyncio.wait_for(model_started.wait(), timeout=1)
        guardrail_tripped.set()
        return GuardrailFunctionOutput(
            output_info="parallel_trip_before_tool_execution",
            tripwire_triggered=True,
        )

    model = FakeModel()
    # Grab the real bound method before it gets patched below.
    real_stream_response = model.stream_response

    async def delayed_stream_response(*call_args, **call_kwargs):
        # Announce the model call, then hold the response back until the guardrail
        # has tripped (plus a short grace delay) before streaming the tool call.
        model_started.set()
        await asyncio.wait_for(guardrail_tripped.wait(), timeout=1)
        await asyncio.sleep(SHORT_DELAY)
        async for chunk in real_stream_response(*call_args, **call_kwargs):
            yield chunk

    agent = Agent(
        name="streaming_guardrail_hardening_agent",
        instructions="Call the dangerous_tool immediately",
        tools=[dangerous_tool],
        input_guardrails=[tripwire_before_tool_execution],
        model=model,
    )
    model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
    model.set_next_output([get_text_message("done")])

    with patch.object(model, "stream_response", side_effect=delayed_stream_response):
        streamed = Runner.run_streamed(agent, "trigger guardrail")

        with pytest.raises(InputGuardrailTripwireTriggered):
            async for _ in streamed.stream_events():
                pass

    assert model_started.is_set()
    assert guardrail_tripped.is_set()
    assert not tool_was_executed
    assert model.first_turn_args is not None, "Model should have been called in parallel mode"


@pytest.mark.asyncio
async def test_parallel_guardrail_trip_with_slow_cancel_sibling_stops_streaming_turn():
    """A tripped parallel guardrail must stop the turn even if a sibling is slow to cancel.

    Two parallel input guardrails run: one trips once the model call has started, the
    other blocks forever and, when cancelled, takes ``SHORT_DELAY`` to finish its
    cleanup. The fake model streams its tool call only after the sibling's
    cancellation has begun. The streamed run is expected to raise
    ``InputGuardrailTripwireTriggered`` without ever invoking ``dangerous_tool``, and
    the sibling's cancellation handler is expected to have run to completion.
    """
    model_started = asyncio.Event()
    guardrail_tripped = asyncio.Event()
    slow_cancel_started = asyncio.Event()
    slow_cancel_finished = asyncio.Event()
    tool_was_executed = False

    @function_tool
    def dangerous_tool() -> str:
        nonlocal tool_was_executed
        tool_was_executed = True
        return "tool_executed"

    @input_guardrail(run_in_parallel=True)
    async def tripwire_before_tool_execution(
        ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        # Do not trip until the model request is already in flight.
        await asyncio.wait_for(model_started.wait(), timeout=1)
        guardrail_tripped.set()
        return GuardrailFunctionOutput(
            output_info="parallel_trip_before_tool_execution_with_slow_cancel",
            tripwire_triggered=True,
        )

    @input_guardrail(run_in_parallel=True)
    async def slow_to_cancel_guardrail(
        ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ) -> GuardrailFunctionOutput:
        try:
            # A fresh event that nobody sets: this wait only ends via cancellation.
            await asyncio.Event().wait()
            return GuardrailFunctionOutput(
                output_info="slow_to_cancel_guardrail_completed",
                tripwire_triggered=False,
            )
        except asyncio.CancelledError:
            # Simulate cleanup that outlives the cancel request before re-raising.
            slow_cancel_started.set()
            await asyncio.sleep(SHORT_DELAY)
            slow_cancel_finished.set()
            raise

    model = FakeModel()
    # Grab the real bound method before it gets patched below.
    real_stream_response = model.stream_response

    async def delayed_stream_response(*call_args, **call_kwargs):
        # Announce the model call, then stream the tool call only after the tripwire
        # has fired and the sibling guardrail has begun handling its cancellation.
        model_started.set()
        await asyncio.wait_for(guardrail_tripped.wait(), timeout=1)
        await asyncio.wait_for(slow_cancel_started.wait(), timeout=1)
        async for chunk in real_stream_response(*call_args, **call_kwargs):
            yield chunk

    agent = Agent(
        name="streaming_guardrail_slow_cancel_agent",
        instructions="Call the dangerous_tool immediately",
        tools=[dangerous_tool],
        input_guardrails=[tripwire_before_tool_execution, slow_to_cancel_guardrail],
        model=model,
    )
    model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")])
    model.set_next_output([get_text_message("done")])

    with patch.object(model, "stream_response", side_effect=delayed_stream_response):
        streamed = Runner.run_streamed(agent, "trigger guardrail")

        with pytest.raises(InputGuardrailTripwireTriggered):
            async for _ in streamed.stream_events():
                pass

    assert model_started.is_set()
    assert guardrail_tripped.is_set()
    assert slow_cancel_started.is_set()
    assert slow_cancel_finished.is_set()
    assert not tool_was_executed
    assert model.first_turn_args is not None, "Model should have been called in parallel mode"


@pytest.mark.asyncio
async def test_blocking_guardrail_prevents_tool_execution():
tool_was_executed = False
Expand Down
Loading