diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8607a2601..dcef86d17 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -29,7 +29,7 @@ ) from opentelemetry import trace as trace_api -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError from .. import _identifier from ..event_loop.event_loop import event_loop_cycle @@ -445,7 +445,7 @@ async def invoke_async( return cast(AgentResult, event["result"]) - def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> T: + def structured_output(self, output_model: Type[T], prompt: AgentInput = None, max_retries: int = 0) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -462,19 +462,23 @@ def structured_output(self, output_model: Type[T], prompt: AgentInput = None) -> - list[ContentBlock]: Multi-modal content blocks - list[Message]: Complete messages with roles - None: Use existing conversation history + max_retries: Maximum number of self-healing retry attempts (additional LLM calls) + if validation fails (default: 0). Raises: ValueError: If no conversation history or prompt is provided. """ def execute() -> T: - return asyncio.run(self.structured_output_async(output_model, prompt)) + return asyncio.run(self.structured_output_async(output_model, prompt, max_retries)) with ThreadPoolExecutor() as executor: future = executor.submit(execute) return future.result() - async def structured_output_async(self, output_model: Type[T], prompt: AgentInput = None) -> T: + async def structured_output_async( + self, output_model: Type[T], prompt: AgentInput = None, max_retries: int = 0 + ) -> T: """This method allows you to get structured output from the agent. If you pass in a prompt, it will be used temporarily without adding it to the conversation history. @@ -487,6 +491,8 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu output_model: The output model (a JSON schema written as a Pydantic BaseModel) that the agent will use when responding. prompt: The prompt to use for the agent (will not be added to conversation history). + max_retries: Maximum number of self-healing retry attempts (additional LLM calls) + if validation fails (default: 0). Raises: ValueError: If no conversation history or prompt is provided. @@ -507,6 +513,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu "gen_ai.agent.name": self.name, "gen_ai.agent.id": self.agent_id, "gen_ai.operation.name": "execute_structured_output", + "gen_ai.structured_output.max_retries": max_retries, } ) if self.system_prompt: @@ -519,17 +526,51 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu f"gen_ai.{message['role']}.message", attributes={"role": message["role"], "content": serialize(message["content"])}, ) - events = self.model.structured_output(output_model, temp_messages, system_prompt=self.system_prompt) - async for event in events: - if isinstance(event, TypedEvent): - event.prepare(invocation_state={}) - if event.is_callback_event: - self.callback_handler(**event.as_dict()) - structured_output_span.add_event( - "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} - ) - return event["output"] + last_exception = None + for attempt in range(max_retries + 1): + try: + if attempt > 0: + structured_output_span.add_event( + "gen_ai.structured_output.retry", + attributes={"attempt": attempt, "error": str(last_exception)}, + ) + + events = self.model.structured_output( + output_model, temp_messages, system_prompt=self.system_prompt + ) + async for event in events: + if isinstance(event, TypedEvent): + event.prepare(invocation_state={}) + if event.is_callback_event: + self.callback_handler(**event.as_dict()) + + structured_output_span.add_event( + "gen_ai.choice", attributes={"message": serialize(event["output"].model_dump())} + ) + return event["output"] + + except (ValidationError, ValueError) as e: + last_exception = e + if attempt < max_retries: + temp_messages = temp_messages + [ + { + "role": "user", + "content": [ + { + "text": ( + "Try again to generate a structured output. " + f"Your previous attempt failed with this exception: {e}" + ) + } + ], + } + ] + else: + raise + + # Should never reach here, but satisfy type checker + raise RuntimeError("Structured output failed after all retry attempts") finally: self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 200584115..c69431d2a 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1014,6 +1014,7 @@ def test_agent_structured_output(agent, system_prompt, user, agenerator): "gen_ai.agent.name": "Strands Agents", "gen_ai.agent.id": "default", "gen_ai.operation.name": "execute_structured_output", + "gen_ai.structured_output.max_retries": 0, } ) @@ -1143,6 +1144,76 @@ async def test_agent_structured_output_async(agent, system_prompt, user, agenera ) +def test_agent_structured_output_with_retry_on_validation_error(agent, system_prompt, user, agenerator): + """Test that structured_output retries on ValidationError.""" + from pydantic import ValidationError + + # First call raises ValidationError, second call succeeds + call_count = 0 + + async def mock_structured_output(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValidationError.from_exception_data("test", []) + else: + async for event in agenerator([{"output": user}]): + yield event + + agent.model.structured_output = mock_structured_output + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + # Call with max_retries=1 + tru_result = agent.structured_output(type(user), prompt, max_retries=1) + exp_result = user + assert tru_result == exp_result + assert call_count == 2 # Should have been called twice + + +def test_agent_structured_output_with_retry_on_value_error(agent, system_prompt, user, agenerator): + """Test that structured_output retries on ValueError.""" + # First call raises ValueError, second call succeeds + call_count = 0 + + async def mock_structured_output(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise ValueError("No valid tool use found") + else: + async for event in agenerator([{"output": user}]): + yield event + + agent.model.structured_output = mock_structured_output + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + # Call with max_retries=1 + tru_result = agent.structured_output(type(user), prompt, max_retries=1) + exp_result = user + assert tru_result == exp_result + assert call_count == 2 # Should have been called twice + + +def test_agent_structured_output_retry_exhausted(agent, system_prompt, user): + """Test that structured_output raises exception after exhausting retries.""" + from pydantic import ValidationError + + # Always raise ValidationError + async def mock_structured_output(*args, **kwargs): + raise ValidationError.from_exception_data("test", []) + yield # Make it a generator + + agent.model.structured_output = mock_structured_output + + prompt = "Jane Doe is 30 years old and her email is jane@doe.com" + + # Should raise after max_retries attempts + with pytest.raises(ValidationError): + agent.structured_output(type(user), prompt, max_retries=2) + + @pytest.mark.asyncio async def test_stream_async_returns_all_events(mock_event_loop_cycle, alist): agent = Agent()