Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions src/strands/models/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def _format_request_message_contents(self, role: str, content: ContentBlock) ->
if "image" in content:
return [{"role": role, "images": [content["image"]["source"]["bytes"]]}]

if "reasoningContent" in content:
return []

if "toolUse" in content:
return [
{
Expand Down Expand Up @@ -237,13 +240,16 @@ def format_chunk(self, event: dict[str, Any]) -> StreamEvent:
return {"messageStart": {"role": "assistant"}}

case "content_start":
if event["data_type"] == "text":
if event["data_type"] == "text" or event["data_type"] == "reasoning_content":
return {"contentBlockStart": {"start": {}}}

tool_name = event["data"].function.name
return {"contentBlockStart": {"start": {"toolUse": {"name": tool_name, "toolUseId": tool_name}}}}

case "content_delta":
if event["data_type"] == "reasoning_content":
return {"contentBlockDelta": {"delta": {"reasoningContent": {"text": event["data"]}}}}

if event["data_type"] == "text":
return {"contentBlockDelta": {"delta": {"text": event["data"]}}}

Expand Down Expand Up @@ -320,14 +326,29 @@ async def stream(
yield self.format_chunk({"chunk_type": "message_start"})
yield self.format_chunk({"chunk_type": "content_start", "data_type": "text"})

is_thinking = False
async for event in response:
if event.message.thinking:
if not is_thinking:
is_thinking = True
yield self.format_chunk({"chunk_type": "content_start", "data_type": "reasoning_content"})
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "reasoning_content", "data": event.message.thinking}
)
elif is_thinking:
is_thinking = False
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "reasoning_content"})

for tool_call in event.message.tool_calls or []:
yield self.format_chunk({"chunk_type": "content_start", "data_type": "tool", "data": tool_call})
yield self.format_chunk({"chunk_type": "content_delta", "data_type": "tool", "data": tool_call})
yield self.format_chunk({"chunk_type": "content_stop", "data_type": "tool", "data": tool_call})
tool_requested = True

yield self.format_chunk({"chunk_type": "content_delta", "data_type": "text", "data": event.message.content})
if event.message.content:
yield self.format_chunk(
{"chunk_type": "content_delta", "data_type": "text", "data": event.message.content}
)

yield self.format_chunk({"chunk_type": "content_stop", "data_type": "text"})
yield self.format_chunk(
Expand Down
77 changes: 77 additions & 0 deletions tests/strands/models/test_ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings
mock_event = unittest.mock.Mock()
mock_event.message.tool_calls = None
mock_event.message.content = "Hello"
mock_event.message.thinking = None
mock_event.done_reason = "stop"
mock_event.eval_count = 10
mock_event.prompt_eval_count = 5
Expand Down Expand Up @@ -457,6 +458,63 @@ async def test_stream(ollama_client, model, agenerator, alist, captured_warnings
assert len(captured_warnings) == 0


@pytest.mark.asyncio
async def test_stream_thinking(ollama_client, model, agenerator, alist, captured_warnings):
    """Streaming a thinking chunk followed by a text chunk emits a reasoning
    content block that is closed before the plain-text delta arrives."""

    def _make_event(content, thinking):
        # Mock one Ollama chat-stream event with the usual usage statistics.
        event = unittest.mock.Mock()
        event.message.tool_calls = None
        event.message.content = content
        event.message.thinking = thinking
        event.done_reason = "stop"
        event.eval_count = 10
        event.prompt_eval_count = 5
        event.total_duration = 1000000  # 1ms in nanoseconds
        return event

    reasoning_event = _make_event(None, "t1")
    answer_event = _make_event("Hello", None)

    ollama_client.chat = unittest.mock.AsyncMock(return_value=agenerator([reasoning_event, answer_event]))

    messages = [{"role": "user", "content": [{"text": "Hello"}]}]
    tru_events = await alist(model.stream(messages))

    exp_events = [
        {"messageStart": {"role": "assistant"}},
        # One start for the implicit text block, one for the reasoning block.
        {"contentBlockStart": {"start": {}}},
        {"contentBlockStart": {"start": {}}},
        {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "t1"}}}},
        {"contentBlockStop": {}},
        {"contentBlockDelta": {"delta": {"text": "Hello"}}},
        {"contentBlockStop": {}},
        {"messageStop": {"stopReason": "end_turn"}},
        {
            "metadata": {
                "usage": {"inputTokens": 10, "outputTokens": 5, "totalTokens": 15},
                "metrics": {"latencyMs": 1.0},
            }
        },
    ]
    assert tru_events == exp_events

    ollama_client.chat.assert_called_once_with(
        model="m1",
        messages=[{"role": "user", "content": "Hello"}],
        options={},
        stream=True,
        tools=[],
    )

    # Ensure no warnings emitted
    assert len(captured_warnings) == 0


@pytest.mark.asyncio
async def test_tool_choice_not_supported_warns(ollama_client, model, agenerator, alist, captured_warnings):
"""Test that non-None toolChoice emits warning for unsupported providers."""
Expand All @@ -465,6 +523,7 @@ async def test_tool_choice_not_supported_warns(ollama_client, model, agenerator,
mock_event = unittest.mock.Mock()
mock_event.message.tool_calls = None
mock_event.message.content = "Hello"
mock_event.message.thinking = None
mock_event.done_reason = "stop"
mock_event.eval_count = 10
mock_event.prompt_eval_count = 5
Expand All @@ -487,6 +546,7 @@ async def test_stream_with_tool_calls(ollama_client, model, agenerator, alist):
mock_tool_call.function.arguments = {"expression": "2+2"}
mock_event.message.tool_calls = [mock_tool_call]
mock_event.message.content = "I'll calculate that for you"
mock_event.message.thinking = None
mock_event.done_reason = "stop"
mock_event.eval_count = 15
mock_event.prompt_eval_count = 8
Expand Down Expand Up @@ -559,3 +619,20 @@ def test_update_config_validation_warns_on_unknown_keys(model, captured_warnings
assert len(captured_warnings) == 1
assert "Invalid configuration parameters" in str(captured_warnings[0].message)
assert "wrong_param" in str(captured_warnings[0].message)


def test_format_chunk_content_block_delta_thinking_delta(model):
    """A reasoning_content delta event is formatted as a reasoningContent block delta."""
    chunk = model.format_chunk(
        {"chunk_type": "content_delta", "data_type": "reasoning_content", "data": "t1"}
    )

    assert chunk == {"contentBlockDelta": {"delta": {"reasoningContent": {"text": "t1"}}}}