diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py
index c6a500597..d0525d5e7 100644
--- a/src/strands/models/bedrock.py
+++ b/src/strands/models/bedrock.py
@@ -72,6 +72,8 @@ class BedrockConfig(TypedDict, total=False):
             additional_response_field_paths: Additional response field paths to extract
             cache_prompt: Cache point type for the system prompt
             cache_tools: Cache point type for tools
+            cache_messages: Cache point type for messages. When set, a cache point of this type is appended
+                to the end of the last message.
             guardrail_id: ID of the guardrail to apply
             guardrail_trace: Guardrail trace mode. Defaults to enabled.
             guardrail_version: Version of the guardrail to apply
@@ -95,6 +97,7 @@ class BedrockConfig(TypedDict, total=False):
         additional_response_field_paths: Optional[list[str]]
         cache_prompt: Optional[str]
         cache_tools: Optional[str]
+        cache_messages: Optional[str]
         guardrail_id: Optional[str]
         guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]]
         guardrail_stream_processing_mode: Optional[Literal["sync", "async"]]
@@ -203,9 +206,18 @@ def format_request(
         Returns:
             A Bedrock converse stream request.
         """
+        # Append a cache point to the last message when cache_messages is configured. Only the list and
+        # that one message are rebuilt, so the caller's messages are never mutated.
+        processed_messages = messages
+        cache_messages = self.config.get("cache_messages")
+        if cache_messages and messages and isinstance(messages[-1].get("content"), list):
+            last_message = messages[-1]
+            new_content = [*last_message["content"], {"cachePoint": {"type": cache_messages}}]
+            processed_messages = [*messages[:-1], {"role": last_message["role"], "content": new_content}]
+
         return {
             "modelId": self.config["model_id"],
-            "messages": self._format_bedrock_messages(messages),
+            "messages": self._format_bedrock_messages(processed_messages),
             "system": [
                 *([{"text": system_prompt}] if system_prompt else []),
                 *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py
index 96fee67fa..a5bd58ef6 100644
--- a/tests/strands/models/test_bedrock.py
+++ b/tests/strands/models/test_bedrock.py
@@ -492,6 +492,88 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
     assert tru_request == exp_request
 
 
+def test_format_request_cache_messages(model, model_id, cache_type):
+    """Test that cache_messages preserves existing cache points and adds one at the end."""
+    # Messages with existing cache points that should be preserved
+    messages_with_cache = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "First message"},
+                {"cachePoint": {"type": "default"}},  # Should be preserved
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": [
+                {"text": "Response"},
+                {"cachePoint": {"type": "default"}},  # Should be preserved
+            ],
+        },
+        {
+            "role": "user",
+            "content": [{"text": "Second message"}],
+        },
+    ]
+
+    model.update_config(cache_messages=cache_type)
+    tru_request = model.format_request(messages_with_cache)
+
+    # Existing cache points stay where they were; only the last message gains a trailing one
+    exp_contents = [
+        [{"text": "First message"}, {"cachePoint": {"type": "default"}}],
+        [{"text": "Response"}, {"cachePoint": {"type": "default"}}],
+        [{"text": "Second message"}, {"cachePoint": {"type": cache_type}}],
+    ]
+    assert [message["content"] for message in tru_request["messages"]] == exp_contents
+
+    # Check the remaining request fields on their own: an expected request built from
+    # tru_request["messages"] would make the messages comparison trivially true
+    tru_other_fields = {key: value for key, value in tru_request.items() if key != "messages"}
+    exp_other_fields = {
+        "inferenceConfig": {},
+        "modelId": model_id,
+        "system": [],
+    }
+    assert tru_other_fields == exp_other_fields
+
+
+def test_format_request_cache_messages_does_not_modify_original(model, cache_type):
+    """Test that format_request does not modify the original messages when cache_messages is set."""
+    import copy
+
+    original_messages = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "First message"},
+                {"cachePoint": {"type": "default"}},
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": [
+                {"text": "Response"},
+            ],
+        },
+        {
+            "role": "user",
+            "content": [{"text": "Second message"}],
+        },
+    ]
+
+    # Snapshot taken before the call so any in-place mutation shows up as a difference
+    expected_messages = copy.deepcopy(original_messages)
+
+    model.update_config(cache_messages=cache_type)
+    _ = model.format_request(original_messages)
+
+    assert original_messages == expected_messages
+
+    # The last message is the one the cache point is appended to, so pin it literally as well
+    assert original_messages[2]["content"] == [{"text": "Second message"}]
+
+
 @pytest.mark.asyncio
 async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist):
     error_message = "Rate exceeded"