src/strands/models/bedrock.py (18 changes: 17 additions & 1 deletion)
@@ -72,6 +72,8 @@ class BedrockConfig(TypedDict, total=False):
             additional_response_field_paths: Additional response field paths to extract
             cache_prompt: Cache point type for the system prompt
             cache_tools: Cache point type for tools
+            cache_messages: Cache point type for messages. If set to "default", adds a cache point at the end
+                of the last message.
             guardrail_id: ID of the guardrail to apply
             guardrail_trace: Guardrail trace mode. Defaults to enabled.
             guardrail_version: Version of the guardrail to apply
@@ -95,6 +97,7 @@ class BedrockConfig(TypedDict, total=False):
         additional_response_field_paths: Optional[list[str]]
         cache_prompt: Optional[str]
         cache_tools: Optional[str]
+        cache_messages: Optional[str]
         guardrail_id: Optional[str]
         guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]]
         guardrail_stream_processing_mode: Optional[Literal["sync", "async"]]
@@ -203,9 +206,22 @@ def format_request(
         Returns:
             A Bedrock converse stream request.
         """
+        # Handle the cache_messages configuration
+        processed_messages = messages
+        if self.config.get("cache_messages") == "default":
+            # Add a cache point to the end of the last message (copy to avoid modifying the original)
+            if messages:
+                # Create a shallow copy of the messages list
+                processed_messages = list(messages)
+                last_message = processed_messages[-1]
+                if "content" in last_message and isinstance(last_message["content"], list):
+                    # Create a new message dict with updated content
+                    new_content = [*last_message["content"], {"cachePoint": {"type": "default"}}]
+                    processed_messages[-1] = {"role": last_message["role"], "content": new_content}
+
         return {
             "modelId": self.config["model_id"],
-            "messages": self._format_bedrock_messages(messages),
+            "messages": self._format_bedrock_messages(processed_messages),
             "system": [
                 *([{"text": system_prompt}] if system_prompt else []),
                 *([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
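A minimal usage sketch of the new option. The `BedrockModel` class name, import path, and constructor arguments are assumptions for illustration; `update_config`, `format_request`, and the `cache_messages="default"` value are shown in the diff and tests.

```python
from strands.models.bedrock import BedrockModel  # class name and import path assumed

# Constructor arguments are assumptions for illustration
model = BedrockModel(model_id="us.anthropic.claude-3-7-sonnet-20250219-v1:0")
model.update_config(cache_messages="default")

messages = [{"role": "user", "content": [{"text": "Summarize the attached report."}]}]
request = model.format_request(messages)

# A cache point is appended to the last message in the request,
# while the caller's `messages` list is left unmodified.
assert request["messages"][-1]["content"][-1] == {"cachePoint": {"type": "default"}}
assert messages[-1]["content"] == [{"text": "Summarize the attached report."}]
```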
tests/strands/models/test_bedrock.py (102 changes: 102 additions & 0 deletions)
@@ -492,6 +492,108 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
     assert tru_request == exp_request


+def test_format_request_cache_messages(model, model_id, cache_type):
+    """Test that cache_messages preserves existing cache points and adds one at the end."""
+    # Messages with existing cache points that should be preserved
+    messages_with_cache = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "First message"},
+                {"cachePoint": {"type": "default"}},  # Should be preserved
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": [
+                {"text": "Response"},
+                {"cachePoint": {"type": "default"}},  # Should be preserved
+            ],
+        },
+        {
+            "role": "user",
+            "content": [{"text": "Second message"}],
+        },
+    ]
+
+    model.update_config(cache_messages=cache_type)
+    tru_request = model.format_request(messages_with_cache)
+
+    # Verify existing cache points are preserved and a new one is added at the end
+    messages = tru_request["messages"]
+
+    # Check the first message still has its cache point
+    assert messages[0]["content"] == [
+        {"text": "First message"},
+        {"cachePoint": {"type": "default"}},
+    ]
+
+    # Check the second message still has its cache point
+    assert messages[1]["content"] == [
+        {"text": "Response"},
+        {"cachePoint": {"type": "default"}},
+    ]
+
+    # Check the third (last) message has the new cache point at the end
+    assert messages[2]["content"] == [
+        {"text": "Second message"},
+        {"cachePoint": {"type": cache_type}},
+    ]
+
+    # Verify the full request structure
+    exp_request = {
+        "inferenceConfig": {},
+        "modelId": model_id,
+        "messages": messages,
+        "system": [],
+    }
+    assert tru_request == exp_request


+def test_format_request_cache_messages_does_not_modify_original(model, cache_type):
+    """Test that format_request does not modify the original messages when cache_messages is set."""
+    # Create original messages
+    original_messages = [
+        {
+            "role": "user",
+            "content": [
+                {"text": "First message"},
+                {"cachePoint": {"type": "default"}},
+            ],
+        },
+        {
+            "role": "assistant",
+            "content": [
+                {"text": "Response"},
+            ],
+        },
+        {
+            "role": "user",
+            "content": [{"text": "Second message"}],
+        },
+    ]
+
+    # Create a deep copy for comparison
+    import copy
+
+    expected_messages = copy.deepcopy(original_messages)
+
+    # Call format_request with cache_messages enabled
+    model.update_config(cache_messages=cache_type)
+    _ = model.format_request(original_messages)
+
+    # Verify the original messages are unchanged
+    assert original_messages == expected_messages
+
+    # Verify the content lists are unchanged
+    assert original_messages[0]["content"] == [
+        {"text": "First message"},
+        {"cachePoint": {"type": "default"}},
+    ]
+    assert original_messages[1]["content"] == [{"text": "Response"}]
+    assert original_messages[2]["content"] == [{"text": "Second message"}]
+
+
 @pytest.mark.asyncio
 async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist):
     error_message = "Rate exceeded"
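The non-mutation guarantee exercised by `test_format_request_cache_messages_does_not_modify_original` follows from the shallow-copy pattern in `format_request`: only the list and the final message dict are new objects, so untouched messages are shared with the caller and the last message is replaced rather than mutated. A plain-Python sketch of that invariant, independent of the library (example values hypothetical):

```python
original = [
    {"role": "user", "content": [{"text": "first"}]},
    {"role": "user", "content": [{"text": "second"}]},
]

processed = list(original)  # new list object, same two message dicts
last = processed[-1]
processed[-1] = {  # replace, never mutate, the last entry
    "role": last["role"],
    "content": [*last["content"], {"cachePoint": {"type": "default"}}],
}

assert processed[0] is original[0]  # untouched messages are shared, not copied
assert original[-1]["content"] == [{"text": "second"}]  # caller's data unchanged
assert processed[-1]["content"][-1] == {"cachePoint": {"type": "default"}}
```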