35 changes: 34 additions & 1 deletion src/strands/models/bedrock.py
@@ -72,6 +72,8 @@ class BedrockConfig(TypedDict, total=False):
additional_response_field_paths: Additional response field paths to extract
cache_prompt: Cache point type for the system prompt
cache_tools: Cache point type for tools
cache_messages: Cache point type for messages. If set to "default", removes all existing cache points
from messages and adds a cache point at the end of the last message.
guardrail_id: ID of the guardrail to apply
guardrail_trace: Guardrail trace mode. Defaults to enabled.
guardrail_version: Version of the guardrail to apply
@@ -95,6 +97,7 @@ class BedrockConfig(TypedDict, total=False):
additional_response_field_paths: Optional[list[str]]
cache_prompt: Optional[str]
cache_tools: Optional[str]
cache_messages: Optional[str]
guardrail_id: Optional[str]
guardrail_trace: Optional[Literal["enabled", "disabled", "enabled_full"]]
guardrail_stream_processing_mode: Optional[Literal["sync", "async"]]
@@ -185,6 +188,24 @@ def get_config(self) -> BedrockConfig:
"""
return self.config

def _remove_cache_points_from_messages(self, messages: Messages) -> Messages:
"""Remove all cache points from messages.

Args:
messages: List of messages to process.

Returns:
Messages with cache points removed.
"""
cleaned_messages: Messages = []
for message in messages:
if "content" in message and isinstance(message["content"], list):
cleaned_content = [item for item in message["content"] if "cachePoint" not in item]
cleaned_messages.append({"role": message["role"], "content": cleaned_content})
else:
cleaned_messages.append(message)
return cleaned_messages
Member:
Can you explain the purpose of clearing cachePoints from previously processed messages?

Author:
Thanks for pointing that out. Originally, the implementation would add a cache point to the last message on each call. To avoid accumulating too many cache points across multiple calls, I decided to remove all existing cache points before adding the new one.

However, I realized that this approach was problematic because it mutated the original messages passed to the function. I have since changed the code to create a copy of the messages and append the cache point to the copy, so clearing existing cachePoints is no longer necessary.

I have also updated the PR description to reflect this change.
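To make the mutation concern concrete, here is a minimal sketch (illustrative only, not part of the PR) contrasting in-place filtering, which alters the dicts the caller still holds, with rebuilding new message dicts as _remove_cache_points_from_messages does:

# Minimal sketch: why rebuilding message dicts avoids mutating the caller's input.

# In-place filtering rebinds "content" on the dict the caller still holds:
messages = [{"role": "user", "content": [{"text": "hi"}, {"cachePoint": {"type": "default"}}]}]
messages[0]["content"] = [b for b in messages[0]["content"] if "cachePoint" not in b]
# The caller's original message has now lost its cachePoint.

# Building new dicts (the approach taken by _remove_cache_points_from_messages)
# leaves the input untouched:
original = [{"role": "user", "content": [{"text": "hi"}, {"cachePoint": {"type": "default"}}]}]
cleaned = [
    {"role": m["role"], "content": [b for b in m["content"] if "cachePoint" not in b]}
    for m in original
]
assert any("cachePoint" in b for b in original[0]["content"])     # input preserved
assert all("cachePoint" not in b for b in cleaned[0]["content"])  # copy is clean

Note that the copy is shallow: the cleaned messages share the surviving content-block dicts with the input. That is safe here because format_request only rebinds last_message["content"] to a new list rather than mutating a shared block.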


def format_request(
self,
messages: Messages,
@@ -203,9 +224,21 @@ def format_request(
Returns:
A Bedrock converse stream request.
"""
# Handle cache_messages configuration
processed_messages = messages
if self.config.get("cache_messages") == "default":
# Remove all existing cache points from messages
processed_messages = self._remove_cache_points_from_messages(messages)
# Add cache point to the end of the last message
if processed_messages:
last_message = processed_messages[-1]
if "content" in last_message and isinstance(last_message["content"], list):
# Create a new list with the cache point appended
last_message["content"] = [*last_message["content"], {"cachePoint": {"type": "default"}}]

return {
"modelId": self.config["model_id"],
"messages": self._format_bedrock_messages(messages),
"messages": self._format_bedrock_messages(processed_messages),
"system": [
*([{"text": system_prompt}] if system_prompt else []),
*([{"cachePoint": {"type": self.config["cache_prompt"]}}] if self.config.get("cache_prompt") else []),
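For reference, a hypothetical usage sketch of the new option, assuming BedrockModel is the model class configured by BedrockConfig in this module and that the model id shown is purely illustrative:

from strands.models.bedrock import BedrockModel

# Hypothetical model id; any Converse-capable Bedrock model id would do.
model = BedrockModel(
    model_id="anthropic.claude-3-5-sonnet-20240620-v1:0",
    cache_messages="default",
)

messages = [
    {"role": "user", "content": [{"text": "Long shared context..."}]},
]

request = model.format_request(messages)
# Per the change above, the last message's content now ends with a cache point,
# and any cache points carried over from earlier calls have been stripped:
# request["messages"][-1]["content"][-1] == {"cachePoint": {"type": "default"}}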
52 changes: 52 additions & 0 deletions tests/strands/models/test_bedrock.py
@@ -492,6 +492,58 @@ def test_format_request_cache(model, messages, model_id, tool_spec, cache_type):
assert tru_request == exp_request


def test_format_request_cache_messages(model, model_id, cache_type):
"""Test that cache_messages removes existing cache points and adds one at the end."""
# Messages with existing cache points that should be removed
messages_with_cache = [
{
"role": "user",
"content": [
{"text": "First message"},
{"cachePoint": {"type": "default"}}, # Should be removed
],
},
{
"role": "assistant",
"content": [
{"text": "Response"},
{"cachePoint": {"type": "default"}}, # Should be removed
],
},
{
"role": "user",
"content": [{"text": "Second message"}],
},
]

model.update_config(cache_messages=cache_type)
tru_request = model.format_request(messages_with_cache)

# Verify all old cache points are removed and new one is at the end
messages = tru_request["messages"]

# Check first message has no cache point
assert messages[0]["content"] == [{"text": "First message"}]

# Check second message has no cache point
assert messages[1]["content"] == [{"text": "Response"}]

# Check last message has cache point at the end
assert messages[2]["content"] == [
{"text": "Second message"},
{"cachePoint": {"type": cache_type}},
]

# Verify the full request structure
exp_request = {
"inferenceConfig": {},
"modelId": model_id,
"messages": messages,
"system": [],
}
assert tru_request == exp_request


@pytest.mark.asyncio
async def test_stream_throttling_exception_from_event_stream_error(bedrock_client, model, messages, alist):
error_message = "Rate exceeded"