Skip to content

Commit e4673e7

Browse files
kiberguscopybara-github
authored andcommitted
feat: Introduce a ContentStream interface: a syntax sugar for reducing a stream of Parts to the needed modality.
FUTURE_COPYBARA_INTEGRATE_REVIEW=#1433 from googleapis:release-please--branches--main 171b659 PiperOrigin-RevId: 810874176
1 parent 5c4d7ee commit e4673e7

File tree

2 files changed

+358
-0
lines changed

2 files changed

+358
-0
lines changed

google/genai/content_stream.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for working with streams of genai.Part content."""
16+
17+
from collections.abc import AsyncIterable, AsyncIterator
18+
from typing import Any, Callable, Generator, Generic, TypeVar
19+
20+
from . import _transformers
21+
from . import types
22+
23+
24+
class ContentStream:
25+
"""A syntax sugar mixin for streams of Content / Parts.
26+
27+
Adapts whatever producer has to whatever consumer needs. Producer initializes
28+
ContentStream with an AsyncIterable or Collection of Content, Part or
29+
ContentListUnionDict. The consumer can iterate over the content in the stream
30+
or use accessor like .text() and reduce it to the given modality.
31+
32+
Models and agents need to work with multimodal streaming content. Consuming
33+
such streams without ContentStream may look like:
34+
35+
text = ''
36+
async for response in client.generate_content_stream(...)
37+
for content in response.candidates[0]:
38+
for part in content.parts:
39+
if not part.text:
40+
raise ValueError('Non text part received')
41+
text += part.text
42+
43+
Many consumers would benefit from more constrained interfaces, such as "just
44+
return a string". But producers have to provide a generic interface to satisfy
45+
all clients. It works another way too: if consumer can deal with streaming
46+
multimodal content, it should be able to ingest unary text-only inputs.
47+
48+
Whether ContentStream can be iterated over only once (like generator) or
49+
multiple times (like a list) depends on the specific implementation.
50+
It is strongly advised to allow reading content multiple times as this allows
51+
consumers to retry failures and tee response to multiple consumers. If the
52+
ContentStream is backed by a generator attempt to get content again (through
53+
any method) will raise a RuntimeError.
54+
55+
Producers may chose to subclass ContentStream and provide additional methods.
56+
"""
57+
58+
def __init__(
59+
self,
60+
*,
61+
content: (
62+
types.ContentListUnionDict
63+
| AsyncIterable[types.ContentListUnionDict]
64+
| None
65+
) = None,
66+
content_iterator: AsyncIterator[types.ContentListUnionDict] | None = None,
67+
parts: AsyncIterable[types.Part] | None = None,
68+
parts_iterator: AsyncIterator[types.Part] | None = None,
69+
):
70+
"""Initializes the stream.
71+
72+
Only one of content_stream, parts_stream or content can be set.
73+
74+
Args:
75+
content: Constructs the stream from a static content object or
76+
AsyncIterable of objects convertible to Content. It must allow iterating
77+
over it multiple times.
78+
content_iterator: Same as `content`, but can be iterated only once.
79+
ContentStream will raise a RuntimeError on consecutive attempts. Use it
80+
if underlying iterable discards the content as soon as it is consumed.
81+
parts: An optimized version for the case when producer already yields Part
82+
objects.
83+
parts_iterator: Same as `parts`, but can be iterated only once.
84+
ContentStream will raise a RuntimeError on consecutive attempts. Use it
85+
if underlying iterable discards the content as soon as it is consumed.
86+
"""
87+
if (
88+
sum(
89+
x is not None
90+
for x in [content, content_iterator, parts, parts_iterator]
91+
)
92+
> 1
93+
):
94+
raise ValueError(
95+
'At most one of content, content_iterator, parts, parts_iterator can '
96+
'be provided.'
97+
)
98+
99+
if content_iterator:
100+
content = _StreamOnce(content_iterator)
101+
if content:
102+
if isinstance(content, AsyncIterable):
103+
parts = _StreamContentIterable(content)
104+
else:
105+
# We have a static content object, use optimized implementation for it.
106+
parts = _StreamContent(content)
107+
108+
if parts_iterator:
109+
parts = _StreamOnce(parts_iterator)
110+
111+
if parts:
112+
self.parts: Callable[[], AsyncIterator[types.Part]] = parts.__aiter__
113+
114+
def parts(self) -> AsyncIterator[types.Part]:
115+
"""Returns an iterator that yields all genai.Parts from the stream.
116+
117+
Consecutive calls to this method return independent iterators that start
118+
from the beginning of the stream. If the stream can only be iterated once,
119+
a RuntimeError will be risen on the second attempt.
120+
"""
121+
# This method is overriden in the __init__ depending on the source type and
122+
# is defined here to provide a good docstring.
123+
124+
# Subclasses of ContentStream may also override this method directly.
125+
# Subclasses may also provide methods that return views of the original
126+
# ContentStream e.g. `.last_turn(self) -> ContentStream`
127+
raise NotImplementedError('ContentStream.parts is not implemented.')
128+
129+
async def text(self) -> str:
130+
"""Returns the stream contents as string.
131+
132+
Returns:
133+
The text of the part.
134+
135+
Raises:
136+
ValueError the underlying content contans non-text parts.
137+
"""
138+
text_parts = []
139+
async for part in self.parts():
140+
if part.text is not None:
141+
text_parts.append(part.text)
142+
elif (
143+
part.inline_data is not None
144+
and part.inline_data.mime_type.startswith('text/')
145+
):
146+
text_parts.append(part.inline_data.data.decode('utf-8'))
147+
else:
148+
raise ValueError(f'Part is not text: {part}')
149+
return ''.join(text_parts)
150+
151+
async def content(self) -> list[types.Content]:
152+
"""Returns all the contents of the stream as a list.
153+
154+
Any consecutive Content objects with matching roles will be merged in-to one
155+
Content object. This way even if the producer streams its output (which it
156+
has to do in separate Content objects), the consumer can rely on "each item
157+
is a turn". Though note that in live bidirectional setups the notion of turn
158+
may be fuzzy or not defined.
159+
"""
160+
# TODO(kibergus): To implement this we need part.part_metadata change to
161+
# reach production to represent roles in parts.
162+
raise NotImplementedError('CotentStream.content is not implemented yet.')
163+
164+
def __await__(self) -> Generator[Any, None, None]:
165+
"""Awaits until the stream is finished.
166+
167+
Useful if we are not interested in the content itself, but in the side
168+
effect of the code that produces it.
169+
170+
Returns:
171+
An awaitable that completes when the stream is finished.
172+
"""
173+
174+
async def await_parts():
175+
async for _ in self.parts():
176+
pass
177+
178+
return await_parts().__await__()
179+
180+
# More methods will be added here on as-needed basis. Candidates are:
181+
# async def get_dataclass(self, json_dataclass: type[T]) -> T:
182+
# Interprets contents of the stream as JSON from which the json_dataclass
183+
# can be instantiated. Works with models constrained with
184+
# `response_schema=json_dataclass`. Also can be used to pass structured data
185+
# between agents.
186+
#
187+
# async def pil_image(self) -> PIL.Image.Image:
188+
# For gen-media models. Asserts that the output is a single image and
189+
# returns it as PIL image.
190+
191+
192+
class _StreamContent(AsyncIterable[types.Part]):
193+
194+
def __init__(self, content: types.ContentListUnionDict):
195+
self._content = _transformers.t_content(content)
196+
197+
def __aiter__(self) -> AsyncIterator[types.Part]:
198+
async def yield_content():
199+
for part in self._content.parts:
200+
yield part
201+
202+
return yield_content()
203+
204+
205+
class _StreamContentIterable(AsyncIterable[types.Part]):
206+
207+
def __init__(self, content: AsyncIterable[types.ContentListUnionDict]):
208+
self._content = content
209+
210+
def __aiter__(self) -> AsyncIterator[types.Part]:
211+
async def yield_content():
212+
async for content in self._content:
213+
for part in _transformers.t_content(content).parts:
214+
yield part
215+
216+
return yield_content()
217+
218+
219+
T = TypeVar('T')
220+
221+
222+
class _StreamOnce(Generic[T]):
223+
"""An AsyncIterable that can be iterated over only once."""
224+
225+
def __init__(self, stream: AsyncIterator[T]):
226+
self._stream = stream
227+
self._exhausted = False
228+
229+
def __aiter__(self) -> AsyncIterator[T]:
230+
if self._exhausted:
231+
raise RuntimeError(
232+
'This ContentStream is backed by an generator and can only be'
233+
' iterated over once.'
234+
)
235+
self._exhausted = True
236+
return self._stream
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for content_stream.py."""
16+
17+
from collections.abc import AsyncIterator, Iterable
18+
from typing import TypeVar
19+
20+
from google.genai import types
21+
from google.genai.content_stream import ContentStream
22+
import pytest
23+
24+
T = TypeVar('T')
25+
26+
27+
async def _to_async_iter(iterable: Iterable[T]) -> AsyncIterator[T]:
28+
for item in iterable:
29+
yield item
30+
31+
32+
async def _parts_to_list(
33+
parts_iter: AsyncIterator[types.Part],
34+
) -> list[types.Part]:
35+
return [part async for part in parts_iter]
36+
37+
38+
@pytest.mark.asyncio
39+
async def test_init_with_static_content_obj():
40+
stream = ContentStream(content=types.UserContent('hello'))
41+
assert await stream.text() == 'hello'
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_init_with_static_content_list():
46+
stream = ContentStream(content=['hello', ' world'])
47+
assert await stream.text() == 'hello world'
48+
49+
50+
@pytest.mark.asyncio
51+
async def test_init_with_content_iterable():
52+
content = [types.UserContent('hello'), ' world']
53+
stream = ContentStream(content=_to_async_iter(content))
54+
assert await stream.text() == 'hello world'
55+
56+
57+
@pytest.mark.asyncio
58+
async def test_init_with_content_iterator():
59+
content = [types.UserContent('hello'), ' world']
60+
stream = ContentStream(content_iterator=_to_async_iter(content))
61+
assert await stream.text() == 'hello world'
62+
63+
# Attempting to read the content a second time should fail.
64+
with pytest.raises(RuntimeError):
65+
await stream.text()
66+
67+
68+
@pytest.mark.asyncio
69+
async def test_init_with_parts_iterable():
70+
parts_list = [types.Part(text='hello'), types.Part(text=' world')]
71+
stream = ContentStream(parts=_to_async_iter(parts_list))
72+
assert await stream.text() == 'hello world'
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_init_with_parts_iterator():
77+
parts_list = [types.Part(text='hello'), types.Part(text=' world')]
78+
stream = ContentStream(parts_iterator=_to_async_iter(parts_list))
79+
assert await stream.text() == 'hello world'
80+
81+
# Attempting to read the content a second time should fail.
82+
with pytest.raises(RuntimeError):
83+
await stream.text()
84+
85+
86+
def test_init_with_multiple_fail():
87+
with pytest.raises(ValueError):
88+
ContentStream(content=[], parts=[])
89+
with pytest.raises(ValueError):
90+
ContentStream(
91+
content_iterator=_to_async_iter([]), parts_iterator=_to_async_iter([])
92+
)
93+
94+
95+
@pytest.mark.asyncio
96+
async def test_text_with_inline_data():
97+
stream = ContentStream(
98+
content=types.Part.from_bytes(mime_type='text/plain', data=b'hello')
99+
)
100+
assert await stream.text() == 'hello'
101+
102+
103+
@pytest.mark.asyncio
104+
async def test_text_with_non_text_part_fail():
105+
stream = ContentStream(
106+
content=types.Part.from_bytes(mime_type='image/png', data=b'123')
107+
)
108+
with pytest.raises(ValueError):
109+
await stream.text()
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_await():
114+
parts = []
115+
116+
async def parts_generator():
117+
for i in range(3):
118+
parts.append(i)
119+
yield types.Part(text=str(i))
120+
121+
await ContentStream(parts_iterator=parts_generator())
122+
assert parts == [0, 1, 2]

0 commit comments

Comments
 (0)