Skip to content

Commit 5aa283a

Browse files
kibergus, copybara-github
authored and committed
feat: Introduce a ContentStream interface: a syntax sugar for reducing a stream of Parts to the needed modality.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#1433 from googleapis:release-please--branches--main 171b659 PiperOrigin-RevId: 810874176
1 parent 5c4d7ee commit 5aa283a

File tree

3 files changed

+391
-0
lines changed

3 files changed

+391
-0
lines changed

google/genai/content_stream.py

Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for working with streams of genai.Part content."""
16+
17+
from collections.abc import AsyncIterable, AsyncIterator
18+
from typing import Any, Callable, Generator, Generic, Optional, TypeVar, Union
19+
20+
from . import _transformers
21+
from . import types
22+
23+
24+
class ContentStream:
  """Syntactic-sugar mixin for streams of Content / Parts.

  Adapts whatever the producer has to whatever the consumer needs. The producer
  initializes ContentStream with an AsyncIterable or a static value of Content,
  Part or ContentListUnionDict. The consumer can iterate over the parts of the
  stream or use an accessor such as `.text()` to reduce it to a given modality.

  Models and agents need to work with multimodal streaming content. Consuming
  such streams without ContentStream may look like:

    text = ''
    async for response in client.generate_content_stream(...):
      for content in response.candidates[0]:
        for part in content.parts:
          if not part.text:
            raise ValueError('Non text part received')
          text += part.text

  Many consumers would benefit from more constrained interfaces, such as "just
  return a string". But producers have to provide a generic interface to
  satisfy all clients. It works the other way too: if a consumer can deal with
  streaming multimodal content, it should be able to ingest unary text-only
  inputs.

  Whether a ContentStream can be iterated over only once (like a generator) or
  multiple times (like a list) depends on the specific implementation.
  It is strongly advised to allow reading content multiple times, as this lets
  consumers retry failures and tee the response to multiple consumers. If the
  ContentStream is backed by a one-shot iterator, an attempt to get the content
  again (through any method) raises a RuntimeError.

  Producers may choose to subclass ContentStream and provide additional
  methods.
  """

  def __init__(
      self,
      *,
      content: Union[
          types.ContentListUnionDict,
          AsyncIterable[types.ContentListUnionDict],
          None,
      ] = None,
      content_iterator: Optional[
          AsyncIterator[types.ContentListUnionDict]
      ] = None,
      parts: Optional[AsyncIterable[types.Part]] = None,
      parts_iterator: Optional[AsyncIterator[types.Part]] = None,
  ):
    """Initializes the stream.

    At most one of `content`, `content_iterator`, `parts` and `parts_iterator`
    can be set. If none is set, `parts()` keeps its default implementation,
    which raises NotImplementedError unless a subclass overrides it.

    Args:
      content: Constructs the stream from a static content object or an
        AsyncIterable of objects convertible to Content. It must allow
        iterating over it multiple times.
      content_iterator: Same as `content`, but can be iterated only once.
        ContentStream will raise a RuntimeError on consecutive attempts. Use it
        if the underlying iterable discards the content as soon as it is
        consumed.
      parts: An optimized version for the case when the producer already
        yields Part objects.
      parts_iterator: Same as `parts`, but can be iterated only once.
        ContentStream will raise a RuntimeError on consecutive attempts. Use it
        if the underlying iterable discards the content as soon as it is
        consumed.

    Raises:
      ValueError: If more than one source argument is provided.
    """
    sources = (content, content_iterator, parts, parts_iterator)
    if sum(source is not None for source in sources) > 1:
      raise ValueError(
          'At most one of content, content_iterator, parts, parts_iterator can '
          'be provided.'
      )

    # Dispatch on `is not None`, exactly like the validation above. Testing
    # truthiness instead would silently drop a provided-but-falsy source (an
    # empty list or string, or an iterable defining __len__/__bool__) and
    # leave parts() unimplemented.
    # NOTE(review): an empty static `content` is now forwarded to
    # _StreamContent; confirm that its conversion helper treats empty input
    # the way callers expect.
    if content_iterator is not None:
      content = _StreamOnce(content_iterator)
    if content is not None:
      if isinstance(content, AsyncIterable):
        parts = _StreamContentIterable(content)
      else:
        # We have a static content object, use optimized implementation for it.
        parts = _StreamContent(content)

    if parts_iterator is not None:
      parts = _StreamOnce(parts_iterator)

    if parts is not None:
      # Shadow the parts() method on this instance with the source's
      # __aiter__, so every parts() call asks the source for an iterator.
      self.parts: Callable[[], AsyncIterator[types.Part]] = parts.__aiter__  # type: ignore[method-assign]

  def parts(self) -> AsyncIterator[types.Part]:
    """Returns an iterator that yields all genai.Parts from the stream.

    Consecutive calls to this method return independent iterators that start
    from the beginning of the stream. If the stream can only be iterated once,
    a RuntimeError will be raised on the second attempt.
    """
    # This method is overridden in __init__ depending on the source type and
    # is defined here to provide a good docstring.

    # Subclasses of ContentStream may also override this method directly.
    # Subclasses may also provide methods that return views of the original
    # ContentStream e.g. `.last_turn(self) -> ContentStream`
    raise NotImplementedError('ContentStream.parts is not implemented.')

  async def text(self) -> str:
    """Returns the stream contents as a single string.

    Text parts contribute their `text`; inline-data parts whose MIME type
    starts with `text/` contribute their payload decoded as UTF-8.

    Returns:
      The concatenation, in stream order, of the text of all parts.

    Raises:
      ValueError: If the stream contains a part that is neither text nor
        `text/*` inline data, or if such inline data has no payload. A payload
        that is not valid UTF-8 surfaces as UnicodeDecodeError, which is a
        ValueError subclass.
    """
    text_parts = []
    async for part in self.parts():
      if part.text is not None:
        text_parts.append(part.text)
      elif (
          part.inline_data is not None
          and part.inline_data.mime_type is not None
          and part.inline_data.mime_type.startswith('text/')
      ):
        if part.inline_data.data is None:
          raise ValueError('Invalid inline_data with None data encountered.')
        text_parts.append(part.inline_data.data.decode('utf-8'))
      else:
        raise ValueError(f'Part is not text: {part}')
    return ''.join(text_parts)

  async def content(self) -> list[types.Content]:
    """Returns all the contents of the stream as a list.

    Not implemented yet; calling it always raises NotImplementedError.

    Intended behavior: any consecutive Content objects with matching roles
    will be merged into one Content object. This way even if the producer
    streams its output (which it has to do in separate Content objects), the
    consumer can rely on "each item is a turn". Though note that in live
    bidirectional setups the notion of turn may be fuzzy or not defined.
    """
    # TODO(kibergus): To implement this we need part.part_metadata change to
    # reach production to represent roles in parts.
    raise NotImplementedError('ContentStream.content is not implemented yet.')

  def __await__(self) -> Generator[Any, None, None]:
    """Awaits until the stream is finished.

    Useful if we are not interested in the content itself, but in the side
    effect of the code that produces it.

    Returns:
      An awaitable that completes when the stream is finished.
    """

    async def await_parts() -> None:
      # Drain the stream, discarding every part.
      async for _ in self.parts():
        pass

    return await_parts().__await__()

  # More methods will be added here on as-needed basis. Candidates are:
  # async def get_dataclass(self, json_dataclass: type[T]) -> T:
  #   Interprets contents of the stream as JSON from which the json_dataclass
  #   can be instantiated. Works with models constrained with
  #   `response_schema=json_dataclass`. Also can be used to pass structured data
  #   between agents.
  #
  # async def pil_image(self) -> PIL.Image.Image:
  #   For gen-media models. Asserts that the output is a single image and
  #   returns it as PIL image.
195+
196+
197+
class _StreamContent(AsyncIterable[types.Part]):
  """Re-iterable async view over the parts of a static content value."""

  def __init__(self, content: types.ContentListUnionDict):
    # Convert once up front; every iteration then walks this stored list.
    self._content: list[types.Content] = _transformers.t_contents(content)

  async def _iter_parts(self) -> AsyncIterator[types.Part]:
    """Yields each part of each stored Content, skipping part-less ones."""
    for item in self._content:
      for part in item.parts or ():
        yield part

  def __aiter__(self) -> AsyncIterator[types.Part]:
    # A fresh generator per call, so the stream can be read repeatedly.
    return self._iter_parts()
211+
212+
213+
class _StreamContentIterable(AsyncIterable[types.Part]):
  """Async view that flattens a stream of content values into parts."""

  def __init__(self, content: AsyncIterable[types.ContentListUnionDict]):
    self._content = content

  async def _iter_parts(self) -> AsyncIterator[types.Part]:
    """Converts each streamed item to Content objects and yields their parts."""
    async for chunk in self._content:
      for item in _transformers.t_contents(chunk):
        for part in item.parts or ():
          yield part

  def __aiter__(self) -> AsyncIterator[types.Part]:
    # Each call starts a new pass over the wrapped iterable.
    return self._iter_parts()
229+
230+
231+
T = TypeVar('T')


class _StreamOnce(Generic[T]):
  """Wraps an async iterator so that it can be handed out only once.

  The first `__aiter__` call returns the wrapped iterator itself; every later
  call raises RuntimeError, because the items consumed from a one-shot iterator
  cannot be replayed.
  """

  def __init__(self, stream: AsyncIterator[T]):
    self._stream = stream
    # Despite the name, this flips as soon as the iterator has been handed
    # out, regardless of how much of it was actually consumed.
    self._exhausted = False

  def __aiter__(self) -> AsyncIterator[T]:
    """Returns the wrapped iterator on the first call only.

    Raises:
      RuntimeError: If the iterator has already been handed out.
    """
    if self._exhausted:
      raise RuntimeError(
          'This ContentStream is backed by a generator and can only be'
          ' iterated over once.'
      )
    self._exhausted = True
    return self._stream
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
16+
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for content_stream.py."""
16+
17+
from collections.abc import AsyncIterator, Iterable
18+
from typing import TypeVar
19+
20+
import pytest
21+
22+
from ... import content_stream
23+
from ... import types
24+
25+
T = TypeVar('T')


async def _to_async_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
  """Exposes a plain iterable as an async iterator, preserving item order."""
  for element in iterable:
    yield element
31+
32+
33+
async def _parts_to_list(
    parts_iter: AsyncIterator[types.Part],
) -> list[types.Part]:
  """Drains an async iterator of parts into a list, in iteration order."""
  collected = []
  async for part in parts_iter:
    collected.append(part)
  return collected
37+
38+
39+
@pytest.mark.asyncio
async def test_init_with_static_content_obj():
  """A single static Content object can be read back as text."""
  user_turn = types.UserContent('hello')
  stream = content_stream.ContentStream(content=user_turn)
  result = await stream.text()
  assert result == 'hello'
43+
44+
45+
@pytest.mark.asyncio
async def test_init_with_static_content_list():
  """A static list of strings is concatenated by text()."""
  pieces = ['hello', ' world']
  stream = content_stream.ContentStream(content=pieces)
  result = await stream.text()
  assert result == 'hello world'
49+
50+
51+
@pytest.mark.asyncio
async def test_init_with_content_iterable():
  """An async source mixing Content objects and strings reads as text."""
  items = [types.UserContent('hello'), ' world']
  source = _to_async_iter(items)
  stream = content_stream.ContentStream(content=source)
  result = await stream.text()
  assert result == 'hello world'
56+
57+
58+
@pytest.mark.asyncio
async def test_init_with_content_iterator():
  """A one-shot content iterator reads once, then refuses a second read."""
  items = [types.UserContent('hello'), ' world']
  source = _to_async_iter(items)
  stream = content_stream.ContentStream(content_iterator=source)
  first_read = await stream.text()
  assert first_read == 'hello world'

  # The source was handed out above, so reading again must fail.
  with pytest.raises(RuntimeError):
    await stream.text()
69+
70+
71+
@pytest.mark.asyncio
async def test_init_with_parts_iterable():
  """An async source of Part objects passed as `parts` reads as text."""
  parts_list = [types.Part(text='hello'), types.Part(text=' world')]
  source = _to_async_iter(parts_list)
  stream = content_stream.ContentStream(parts=source)
  result = await stream.text()
  assert result == 'hello world'
76+
77+
78+
@pytest.mark.asyncio
async def test_init_with_parts_iterator():
  """A one-shot parts iterator reads once, then refuses a second read."""
  parts_list = [types.Part(text='hello'), types.Part(text=' world')]
  source = _to_async_iter(parts_list)
  stream = content_stream.ContentStream(parts_iterator=source)
  first_read = await stream.text()
  assert first_read == 'hello world'

  # The source was handed out above, so reading again must fail.
  with pytest.raises(RuntimeError):
    await stream.text()
89+
90+
91+
def test_init_with_multiple_fail():
  """Passing more than one source argument is rejected with ValueError."""
  # Two static / iterable sources at once.
  with pytest.raises(ValueError):
    content_stream.ContentStream(content=[], parts=[])

  # Two one-shot iterator sources at once.
  content_source = _to_async_iter([])
  parts_source = _to_async_iter([])
  with pytest.raises(ValueError):
    content_stream.ContentStream(
        content_iterator=content_source, parts_iterator=parts_source
    )
98+
99+
100+
@pytest.mark.asyncio
async def test_text_with_inline_data():
  """Inline data with a text MIME type is decoded by text()."""
  text_blob = types.Part.from_bytes(mime_type='text/plain', data=b'hello')
  stream = content_stream.ContentStream(content=text_blob)
  result = await stream.text()
  assert result == 'hello'
106+
107+
108+
@pytest.mark.asyncio
async def test_text_with_non_text_part_fail():
  """text() rejects inline data whose MIME type is not textual."""
  image_blob = types.Part.from_bytes(mime_type='image/png', data=b'123')
  stream = content_stream.ContentStream(content=image_blob)
  with pytest.raises(ValueError):
    await stream.text()
115+
116+
117+
@pytest.mark.asyncio
async def test_await():
  """Awaiting a stream drives its producer to completion."""
  produced = []

  async def parts_generator():
    # Record each index as a side effect so the test can observe the drain.
    for index in range(3):
      produced.append(index)
      yield types.Part(text=str(index))

  await content_stream.ContentStream(parts_iterator=parts_generator())
  assert produced == [0, 1, 2]

0 commit comments

Comments
 (0)