Skip to content
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6e145e6
Refactor handle_text_delta() to generator pattern with split tag buff…
dsfaccini Oct 22, 2025
11b5f1f
fix test suite for generator pattern and ensure coverage
dsfaccini Oct 23, 2025
0f876de
Merge main into fix-split-thinking-tags-v2
dsfaccini Oct 23, 2025
b5c0910
Merge remote-tracking branch 'origin/main' into fix-split-thinking-ta…
dsfaccini Oct 23, 2025
3439159
rename _tag_buffer to _thinking_tag_buffer
dsfaccini Oct 23, 2025
876ebb2
remove pragmas
dsfaccini Oct 23, 2025
adc51e6
adds a finalize method to prevent lost content from buffered chunks t…
dsfaccini Oct 23, 2025
0818191
fix: handle thinking tags with trailing content and vendor_part_id=…
dsfaccini Oct 24, 2025
f50d4b4
fix coverage
dsfaccini Oct 24, 2025
551d035
remove pragmas
dsfaccini Oct 24, 2025
0998a63
Merge main to stay updated with latest changes
dsfaccini Oct 30, 2025
9b598dd
models
dsfaccini Nov 2, 2025
41a38e2
- include incomplete closing tags in thinking part
dsfaccini Nov 2, 2025
dcac211
wip: improve coverage
dsfaccini Nov 3, 2025
b9bdd78
- reduce complexity in parts manager
dsfaccini Nov 3, 2025
4b7f0c1
Merge branch 'main' into handle-streamed-thinking-over-multiple-chunks
dsfaccini Nov 4, 2025
ac03e38
add tests for coverage
dsfaccini Nov 4, 2025
5fae762
- fix coverage
dsfaccini Nov 5, 2025
28578bf
- fix case multiple_thinking_parts_with_text_between
dsfaccini Nov 5, 2025
0838109
test more cases without vendor id
dsfaccini Nov 6, 2025
2bc1304
refactor parts manager and add parametrized cases
dsfaccini Nov 8, 2025
3c74ee4
delay emission of empty thinking parts
dsfaccini Nov 9, 2025
2674084
update the groq test
dsfaccini Nov 9, 2025
593e02f
Merge branch 'main' into handle-streamed-thinking-over-multiple-chunks
dsfaccini Nov 9, 2025
0214933
fix coverage
dsfaccini Nov 9, 2025
7c44cd9
add more tests and fix coverage
dsfaccini Nov 9, 2025
06c74c6
fix coverage?
dsfaccini Nov 9, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
441 changes: 386 additions & 55 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py

Large diffs are not rendered by default.

21 changes: 19 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from __future__ import annotations as _annotations

import base64
import copy
import warnings
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Iterator
Expand Down Expand Up @@ -521,7 +522,7 @@ class StreamedResponse(ABC):
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)

def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.

This proxies the `_event_iterator()` and emits all events, while also checking for matches
Expand Down Expand Up @@ -580,6 +581,16 @@ def part_end_event(next_part: ModelResponsePart | None = None) -> PartEndEvent |

yield event

# Flush any buffered content and stream finalize events
for finalize_event in self._parts_manager.finalize():
if isinstance(finalize_event, PartStartEvent):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there are only two cases when finalize will have an effect: when we're buffering

  1. a split starting tag, i.e. <th → emits a PartStartEvent
  2. a split ending tag, i.e. </th → emits a PartDeltaEvent
    coverage is complaining that there's no test running through the PartDeltaEvent branch of this, so I need to figure out how to test it

if last_start_event:
end_event = part_end_event(finalize_event.part)
if end_event:
yield end_event
last_start_event = finalize_event
yield finalize_event
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set finalize_event.previous_part_kind like we do above? Could we reuse that same logic instead of duplicating it?


end_event = part_end_event()
if end_event:
yield end_event
Expand All @@ -602,8 +613,14 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

def get(self) -> ModelResponse:
"""Build a [`ModelResponse`][pydantic_ai.messages.ModelResponse] from the data received from the stream so far."""
# Flush any buffered content before building response
# clone parts manager to avoid modifying the ongoing stream state
cloned_manager = copy.deepcopy(self._parts_manager)
for _ in cloned_manager.finalize():
pass

return ModelResponse(
parts=self._parts_manager.get_parts(),
parts=cloned_manager.get_parts(),
model_name=self.model_name,
timestamp=self.timestamp,
usage=self.usage(),
Expand Down
34 changes: 18 additions & 16 deletions pydantic_ai_slim/pydantic_ai/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,25 +729,26 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
elif isinstance(event, BetaRawContentBlockStartEvent):
current_block = event.content_block
if isinstance(current_block, BetaTextBlock) and current_block.text:
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=current_block.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(current_block, BetaThinkingBlock):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
content=current_block.thinking,
signature=current_block.signature,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(current_block, BetaRedactedThinkingBlock):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
id='redacted_thinking',
signature=current_block.data,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(current_block, BetaToolUseBlock):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=event.index,
Expand Down Expand Up @@ -803,23 +804,24 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

elif isinstance(event, BetaRawContentBlockDeltaEvent):
if isinstance(event.delta, BetaTextDelta):
maybe_event = self._parts_manager.handle_text_delta(
for event_item in self._parts_manager.handle_text_delta(
vendor_part_id=event.index, content=event.delta.text
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event_item
elif isinstance(event.delta, BetaThinkingDelta):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
content=event.delta.thinking,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(event.delta, BetaSignatureDelta):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=event.index,
signature=event.delta.signature,
provider_name=self.provider_name,
)
):
yield e
elif isinstance(event.delta, BetaInputJSONDelta):
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=event.index,
Expand Down
15 changes: 8 additions & 7 deletions pydantic_ai_slim/pydantic_ai/models/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,24 +687,25 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
delta = content_block_delta['delta']
if 'reasoningContent' in delta:
if redacted_content := delta['reasoningContent'].get('redactedContent'):
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=index,
id='redacted_content',
signature=redacted_content.decode('utf-8'),
provider_name=self.provider_name,
)
):
yield e
else:
signature = delta['reasoningContent'].get('signature')
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=index,
content=delta['reasoningContent'].get('text'),
signature=signature,
provider_name=self.provider_name if signature else None,
)
):
yield e
if text := delta.get('text'):
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id=index, content=text):
yield event
if 'toolUse' in delta:
tool_use = delta['toolUse']
maybe_event = self._parts_manager.handle_tool_call_delta(
Expand Down
12 changes: 6 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,26 +284,26 @@ class FunctionStreamedResponse(StreamedResponse):
def __post_init__(self):
self._usage += _estimate_usage([])

async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901
async for item in self._iter:
if isinstance(item, str):
response_tokens = _estimate_string_tokens(item)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=item):
yield event
elif isinstance(item, dict) and item:
for dtc_index, delta in item.items():
if isinstance(delta, DeltaThinkingPart):
if delta.content: # pragma: no branch
response_tokens = _estimate_string_tokens(delta.content)
self._usage += usage.RequestUsage(output_tokens=response_tokens)
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=dtc_index,
content=delta.content,
signature=delta.signature,
provider_name='function' if delta.signature else None,
)
):
yield e
elif isinstance(delta, DeltaToolCall):
if delta.json_args:
response_tokens = _estimate_string_tokens(delta.json_args)
Expand Down
7 changes: 3 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
if 'text' in gemini_part:
# Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
# amongst the tool call deltas
maybe_event = self._parts_manager.handle_text_delta(
for event in self._parts_manager.handle_text_delta(
vendor_part_id=None, content=gemini_part['text']
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event

elif 'function_call' in gemini_part:
# Here, we assume all function_call parts are complete and don't have deltas.
Expand Down
15 changes: 9 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/google.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,19 +661,22 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
for part in parts:
if part.thought_signature:
signature = base64.b64encode(part.thought_signature).decode('utf-8')
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='thinking',
signature=signature,
provider_name=self.provider_name,
)
):
yield e

if part.text is not None:
if part.thought:
yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text)
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id='thinking', content=part.text
):
yield e
else:
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text):
yield event
elif part.function_call:
maybe_event = self._parts_manager.handle_tool_call_delta(
vendor_part_id=uuid4(),
Expand Down
12 changes: 6 additions & 6 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
reasoning = True

# NOTE: The `reasoning` field is only present if `groq_reasoning_format` is set to `parsed`.
yield self._parts_manager.handle_thinking_delta(
for e in self._parts_manager.handle_thinking_delta(
vendor_part_id=f'reasoning-{reasoning_index}', content=choice.delta.reasoning
)
):
yield e
else:
reasoning = False

Expand All @@ -572,14 +573,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
maybe_event = self._parts_manager.handle_text_delta(
for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event

# Handle the tool calls
for dtc in choice.delta.tool_calls or []:
Expand Down
7 changes: 3 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,13 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
maybe_event = self._parts_manager.handle_text_delta(
for event in self._parts_manager.handle_text_delta(
vendor_part_id='content',
content=content,
thinking_tags=self._model_profile.thinking_tags,
ignore_leading_whitespace=self._model_profile.ignore_streamed_leading_whitespace,
)
if maybe_event is not None: # pragma: no branch
yield maybe_event
):
yield event

for dtc in choice.delta.tool_calls or []:
maybe_event = self._parts_manager.handle_tool_call_delta(
Expand Down
8 changes: 4 additions & 4 deletions pydantic_ai_slim/pydantic_ai/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
content = choice.delta.content
text, thinking = _map_content(content)
for thought in thinking:
self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought)
for event in self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=thought):
yield event
if text:
# Attempt to produce an output tool call from the received text
output_tools = {c.name: c for c in self.model_request_parameters.output_tools}
Expand All @@ -653,9 +654,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
tool_call_id=maybe_tool_call_part.tool_call_id,
)
else:
maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text)
if maybe_event is not None: # pragma: no branch
yield maybe_event
for event in self._parts_manager.handle_text_delta(vendor_part_id='content', content=text):
yield event

# Handle the explicit tool calls
for index, dtc in enumerate(choice.delta.tool_calls or []):
Expand Down
Loading
Loading