Skip to content

Commit 05a7de2

Browse files
author
Danilo Poccia
committed
Add logic to distinguish server-side tools (executed by Bedrock) from
client-side tools that need agent execution. This prevents infinite loops when models return end_turn stopReason for server_tool_use types like nova_grounding.
1 parent 72db0a2 commit 05a7de2

File tree

1 file changed

+80
-21
lines changed

1 file changed

+80
-21
lines changed

src/strands/models/bedrock.py

Lines changed: 80 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -500,20 +500,17 @@ def _format_request_message_content(self, content: ContentBlock) -> dict[str, An
500500
for citation in citations["citations"]:
501501
filtered_citation: dict[str, Any] = {}
502502
if "location" in citation:
503-
location = citation["location"]
503+
location: dict[str, Any] = cast(dict[str, Any], citation["location"])
504504
filtered_location: dict[str, Any] = {}
505505
# Filter location fields to only include Bedrock-supported ones
506506
# Handle web-based citations
507507
if "web" in location:
508508
web_data = location["web"]
509509
filtered_location["web"] = {k: v for k, v in web_data.items() if k in ["url", "domain"]}
510510
# Handle document-based citations
511-
if "documentIndex" in location:
512-
filtered_location["documentIndex"] = location["documentIndex"]
513-
if "start" in location:
514-
filtered_location["start"] = location["start"]
515-
if "end" in location:
516-
filtered_location["end"] = location["end"]
511+
for field in ["documentIndex", "start", "end"]:
512+
if field in location:
513+
filtered_location[field] = location[field]
517514
if filtered_location:
518515
filtered_citation["location"] = filtered_location
519516
if "sourceContent" in citation:
@@ -687,8 +684,12 @@ def _stream(
687684
logger.debug("got response from model")
688685
if streaming:
689686
response = self.client.converse_stream(**request)
690-
# Track tool use events to fix stopReason for streaming responses
691-
has_tool_use = False
687+
# Track tool use/result events to fix stopReason for streaming responses
688+
# We need to distinguish server-side tools (already executed) from client-side tools
689+
tool_use_info: dict[str, str] = {} # toolUseId -> type (e.g., "server_tool_use")
690+
tool_result_ids: set[str] = set() # IDs of tools with results
691+
has_client_tools = False
692+
692693
for chunk in response["stream"]:
693694
if (
694695
"metadata" in chunk
@@ -700,22 +701,41 @@ def _stream(
700701
for event in self._generate_redaction_events():
701702
callback(event)
702703

703-
# Track if we see tool use events
704-
if "contentBlockStart" in chunk and chunk["contentBlockStart"].get("start", {}).get("toolUse"):
705-
has_tool_use = True
704+
# Track tool use events with their types
705+
if "contentBlockStart" in chunk:
706+
tool_use_start = chunk["contentBlockStart"].get("start", {}).get("toolUse")
707+
if tool_use_start:
708+
tool_use_id = tool_use_start.get("toolUseId", "")
709+
tool_type = tool_use_start.get("type", "")
710+
tool_use_info[tool_use_id] = tool_type
711+
# Check if it's a client-side tool (not server_tool_use)
712+
if tool_type != "server_tool_use":
713+
has_client_tools = True
714+
715+
# Track tool result events (for server-side tools that were already executed)
716+
if "contentBlockStart" in chunk:
717+
tool_result_start = chunk["contentBlockStart"].get("start", {}).get("toolResult")
718+
if tool_result_start:
719+
tool_result_ids.add(tool_result_start.get("toolUseId", ""))
706720

707721
# Fix stopReason for streaming responses that contain tool use
722+
# BUT: Only override if there are client-side tools without results
708723
if (
709-
has_tool_use
710-
and "messageStop" in chunk
724+
"messageStop" in chunk
711725
and (message_stop := chunk["messageStop"]).get("stopReason") == "end_turn"
712726
):
713-
# Create corrected chunk with tool_use stopReason
714-
modified_chunk = chunk.copy()
715-
modified_chunk["messageStop"] = message_stop.copy()
716-
modified_chunk["messageStop"]["stopReason"] = "tool_use"
717-
logger.warning("Override stop reason from end_turn to tool_use")
718-
callback(modified_chunk)
727+
# Check if we have client-side tools that need execution
728+
needs_execution = has_client_tools and not set(tool_use_info.keys()).issubset(tool_result_ids)
729+
730+
if needs_execution:
731+
# Create corrected chunk with tool_use stopReason
732+
modified_chunk = chunk.copy()
733+
modified_chunk["messageStop"] = message_stop.copy()
734+
modified_chunk["messageStop"]["stopReason"] = "tool_use"
735+
logger.warning("Override stop reason from end_turn to tool_use")
736+
callback(modified_chunk)
737+
else:
738+
callback(chunk)
719739
else:
720740
callback(chunk)
721741

@@ -777,6 +797,43 @@ def _stream(
777797
callback()
778798
logger.debug("finished streaming response from model")
779799

800+
def _has_client_side_tools_to_execute(self, message_content: list[dict[str, Any]]) -> bool:
801+
"""Check if message contains client-side tools that need execution.
802+
803+
Server-side tools (like nova_grounding) are executed by Bedrock and include
804+
toolResult blocks in the response. We should NOT override stopReason to
805+
"tool_use" for these tools.
806+
807+
Args:
808+
message_content: The content array from Bedrock response
809+
810+
Returns:
811+
True if there are client-side tools without results, False otherwise
812+
"""
813+
tool_use_ids = set()
814+
tool_result_ids = set()
815+
has_client_tools = False
816+
817+
for content in message_content:
818+
if "toolUse" in content:
819+
tool_use = content["toolUse"]
820+
tool_use_ids.add(tool_use["toolUseId"])
821+
822+
# Check if it's a server-side tool (Bedrock executes these)
823+
if tool_use.get("type") != "server_tool_use":
824+
has_client_tools = True
825+
826+
elif "toolResult" in content:
827+
# Track which tools already have results
828+
tool_result_ids.add(content["toolResult"]["toolUseId"])
829+
830+
# Only return True if there are client-side tools without results
831+
if not has_client_tools:
832+
return False
833+
834+
# Check if all tool uses have corresponding results
835+
return not tool_use_ids.issubset(tool_result_ids)
836+
780837
def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Iterable[StreamEvent]:
781838
"""Convert a non-streaming response to the streaming format.
782839
@@ -858,10 +915,12 @@ def _convert_non_streaming_to_streaming(self, response: dict[str, Any]) -> Itera
858915

859916
# Yield messageStop event
860917
# Fix stopReason for models that return end_turn when they should return tool_use on non-streaming side
918+
# BUT: Don't override for server-side tools (like nova_grounding) that are already executed
861919
current_stop_reason = response["stopReason"]
862920
if current_stop_reason == "end_turn":
863921
message_content = response["output"]["message"]["content"]
864-
if any("toolUse" in content for content in message_content):
922+
# Only override if there are client-side tools that need execution
923+
if self._has_client_side_tools_to_execute(message_content):
865924
current_stop_reason = "tool_use"
866925
logger.warning("Override stop reason from end_turn to tool_use")
867926

0 commit comments

Comments
 (0)