Skip to content

Commit 6f618c3

Browse files
yyyu-google authored and copybara-github committed
fix: automatic function calling for generate_content_stream and its async version
PiperOrigin-RevId: 829172999
1 parent 4b855e6 commit 6f618c3

File tree

4 files changed

+144
-154
lines changed

4 files changed

+144
-154
lines changed

google/genai/chats.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -152,18 +152,26 @@ def record_history(
152152
considered valid.
153153
"""
154154
input_contents = (
155-
# Because the AFC input contains the entire curated chat history in
156-
# addition to the new user input, we need to truncate the AFC history
157-
# to deduplicate the existing chat history.
158155
automatic_function_calling_history[len(self._curated_history) :]
159156
if automatic_function_calling_history
160157
else [user_input]
161158
)
162-
# Appends an empty content when model returns empty response, so that the
163-
# history is always alternating between user and model.
164-
output_contents = (
165-
model_output if model_output else [Content(role="model", parts=[])]
166-
)
159+
160+
if automatic_function_calling_history:
161+
filtered_outputs = []
162+
for output in model_output:
163+
if output not in input_contents:
164+
filtered_outputs.append(output)
165+
output_contents = (
166+
filtered_outputs
167+
if filtered_outputs
168+
else [Content(role="model", parts=[])]
169+
)
170+
else:
171+
output_contents = (
172+
model_output if model_output else [Content(role="model", parts=[])]
173+
)
174+
167175
self._comprehensive_history.extend(input_contents)
168176
self._comprehensive_history.extend(output_contents)
169177
if is_valid:

google/genai/models.py

Lines changed: 62 additions & 113 deletions
Original file line number | Diff line number | Diff line change
@@ -5227,7 +5227,6 @@ def generate_content_stream(
52275227
)
52285228
automatic_function_calling_history: list[types.Content] = []
52295229
chunk = None
5230-
func_response_parts = None
52315230
i = 0
52325231
while remaining_remote_calls_afc > 0:
52335232
i += 1
@@ -5237,47 +5236,30 @@ def generate_content_stream(
52375236

52385237
function_map = _extra_utils.get_function_map(parsed_config)
52395238

5240-
if i == 1:
5241-
# First request gets a function call.
5242-
# Then get function response parts.
5243-
# Yield chunks only if there's no function response parts.
5244-
for chunk in response:
5245-
if not function_map:
5246-
contents = _extra_utils.append_chunk_contents(contents, chunk) # type: ignore[assignment]
5247-
yield chunk
5248-
else:
5249-
if (
5250-
not chunk.candidates
5251-
or not chunk.candidates[0].content
5252-
or not chunk.candidates[0].content.parts
5253-
):
5254-
break
5255-
func_response_parts = _extra_utils.get_function_response_parts(
5256-
chunk, function_map
5257-
)
5258-
if not func_response_parts:
5259-
contents = _extra_utils.append_chunk_contents(contents, chunk) # type: ignore[assignment]
5260-
yield chunk
5239+
func_response_parts: list[types.Part] = []
5240+
for chunk in response:
5241+
if _extra_utils.should_append_afc_history(parsed_config):
5242+
chunk.automatic_function_calling_history = (
5243+
automatic_function_calling_history
5244+
)
5245+
5246+
contents = _extra_utils.append_chunk_contents(contents, chunk) # type: ignore[assignment]
5247+
yield chunk
52615248

5262-
else:
5263-
# Second request and beyond, yield chunks.
5264-
for chunk in response:
5265-
if _extra_utils.should_append_afc_history(parsed_config):
5266-
chunk.automatic_function_calling_history = (
5267-
automatic_function_calling_history
5268-
)
5269-
contents = _extra_utils.append_chunk_contents(contents, chunk) # type: ignore[assignment]
5270-
yield chunk
52715249
if (
52725250
chunk is None
52735251
or not chunk.candidates
52745252
or not chunk.candidates[0].content
52755253
or not chunk.candidates[0].content.parts
52765254
):
5277-
break
5278-
func_response_parts = _extra_utils.get_function_response_parts(
5279-
chunk, function_map
5280-
)
5255+
continue
5256+
5257+
if function_map:
5258+
func_response_parts_in_chunk = (
5259+
_extra_utils.get_function_response_parts(chunk, function_map)
5260+
)
5261+
if func_response_parts_in_chunk:
5262+
func_response_parts.extend(func_response_parts_in_chunk)
52815263

52825264
if not function_map:
52835265
break
@@ -5288,22 +5270,16 @@ def generate_content_stream(
52885270
if remaining_remote_calls_afc == 0:
52895271
logger.info('Reached max remote calls for automatic function calling.')
52905272

5291-
# Append function response parts to contents for the next request.
5292-
if chunk is not None and chunk.candidates is not None:
5293-
func_call_content = chunk.candidates[0].content
5294-
func_response_content = types.Content(
5295-
role='user',
5296-
parts=func_response_parts,
5297-
)
5298-
contents = t.t_contents(contents) # type: ignore[assignment]
5299-
if not automatic_function_calling_history:
5300-
automatic_function_calling_history.extend(contents) # type: ignore[arg-type]
5301-
if isinstance(contents, list) and func_call_content is not None:
5302-
contents.append(func_call_content) # type: ignore[arg-type]
5303-
contents.append(func_response_content) # type: ignore[arg-type]
5304-
if func_call_content is not None:
5305-
automatic_function_calling_history.append(func_call_content)
5306-
automatic_function_calling_history.append(func_response_content)
5273+
func_response_content = types.Content(
5274+
role='user',
5275+
parts=func_response_parts,
5276+
)
5277+
contents = t.t_contents(contents) # type: ignore[assignment]
5278+
contents.append(func_response_content) # type: ignore[arg-type, union-attr]
5279+
5280+
# Update AFC history - at the end of each iteration, it should match contents exactly
5281+
# using list to make a value copy instead of assigning by reference
5282+
automatic_function_calling_history = list(contents) # type: ignore[arg-type]
53075283

53085284
def generate_images(
53095285
self,
@@ -7042,96 +7018,69 @@ async def async_generator(model, contents, config): # type: ignore[no-untyped-d
70427018
f'AFC is enabled with max remote calls: {remaining_remote_calls_afc}.'
70437019
)
70447020
automatic_function_calling_history: list[types.Content] = []
7045-
func_response_parts = None
7046-
chunk = None
70477021
i = 0
70487022
while remaining_remote_calls_afc > 0:
70497023
i += 1
70507024
response = await self._generate_content_stream(
70517025
model=model, contents=contents, config=config
70527026
)
7053-
# TODO: b/453739108 - make AFC logic more robust like the other 3 methods.
7054-
if i > 1:
7055-
logger.info(f'AFC remote call {i} is done.')
7056-
remaining_remote_calls_afc -= 1
7057-
if i > 1 and remaining_remote_calls_afc == 0:
7058-
logger.info(
7059-
'Reached max remote calls for automatic function calling.'
7060-
)
70617027

70627028
function_map = _extra_utils.get_function_map(
70637029
config, mcp_to_genai_tool_adapters, is_caller_method_async=True
70647030
)
70657031

7066-
if i == 1:
7067-
# First request gets a function call.
7068-
# Then get function response parts.
7069-
# Yield chunks only if there's no function response parts.
7070-
async for chunk in response: # type: ignore[attr-defined]
7071-
if not function_map:
7072-
contents = _extra_utils.append_chunk_contents(contents, chunk)
7073-
yield chunk
7074-
else:
7075-
if (
7076-
not chunk.candidates
7077-
or not chunk.candidates[0].content
7078-
or not chunk.candidates[0].content.parts
7079-
):
7080-
break
7081-
func_response_parts = (
7082-
await _extra_utils.get_function_response_parts_async(
7083-
chunk, function_map
7084-
)
7085-
)
7086-
if not func_response_parts:
7087-
contents = _extra_utils.append_chunk_contents(contents, chunk)
7088-
yield chunk
7089-
7090-
else:
7091-
# Second request and beyond, yield chunks.
7092-
async for chunk in response: # type: ignore[attr-defined]
7093-
7094-
if _extra_utils.should_append_afc_history(config):
7095-
chunk.automatic_function_calling_history = (
7096-
automatic_function_calling_history
7097-
)
7098-
contents = _extra_utils.append_chunk_contents(contents, chunk)
7099-
yield chunk
7032+
func_response_parts: list[types.Part] = []
7033+
7034+
async for chunk in response: # type: ignore[attr-defined]
7035+
if _extra_utils.should_append_afc_history(config):
7036+
chunk.automatic_function_calling_history = (
7037+
automatic_function_calling_history
7038+
)
7039+
7040+
yield chunk
7041+
7042+
contents = _extra_utils.append_chunk_contents(contents, chunk)
7043+
71007044
if (
71017045
chunk is None
71027046
or not chunk.candidates
71037047
or not chunk.candidates[0].content
71047048
or not chunk.candidates[0].content.parts
71057049
):
7106-
break
7107-
func_response_parts = (
7108-
await _extra_utils.get_function_response_parts_async(
7109-
chunk, function_map
7110-
)
7111-
)
7050+
continue
7051+
7052+
if function_map:
7053+
func_response_parts_in_chunk = (
7054+
await _extra_utils.get_function_response_parts_async(
7055+
chunk, function_map
7056+
)
7057+
)
7058+
if func_response_parts_in_chunk:
7059+
func_response_parts.extend(func_response_parts_in_chunk)
7060+
71127061
if not function_map:
71137062
break
71147063

71157064
if not func_response_parts:
71167065
break
71177066

7118-
if chunk is None:
7119-
continue
7120-
# Append function response parts to contents for the next request.
7121-
func_call_content = chunk.candidates[0].content
7067+
logger.info(f'AFC remote call {i} is done.')
7068+
remaining_remote_calls_afc -= 1
7069+
if remaining_remote_calls_afc == 0:
7070+
logger.info(
7071+
'Reached max remote calls for automatic function calling.'
7072+
)
7073+
71227074
func_response_content = types.Content(
71237075
role='user',
71247076
parts=func_response_parts,
71257077
)
71267078
contents = t.t_contents(contents)
7127-
if not automatic_function_calling_history:
7128-
automatic_function_calling_history.extend(contents)
7129-
if isinstance(contents, list) and func_call_content is not None:
7130-
contents.append(func_call_content)
7131-
contents.append(func_response_content)
7132-
if func_call_content is not None:
7133-
automatic_function_calling_history.append(func_call_content)
7134-
automatic_function_calling_history.append(func_response_content)
7079+
contents.append(func_response_content)
7080+
7081+
# Update AFC history - at the end of each iteration, it should match contents exactly
7082+
# using list to make a value copy instead of assigning by reference
7083+
automatic_function_calling_history = list(contents)
71357084

71367085
return async_generator(model, contents, parsed_config) # type: ignore[no-untyped-call, no-any-return]
71377086

google/genai/tests/afc/test_generate_content_stream_afc.py

Lines changed: 20 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -313,9 +313,11 @@ def test_generate_content_stream_with_function_tools_used(
313313
)
314314

315315
chunk = None
316+
afc_text_present = False
316317
for chunk in stream:
317-
assert chunk.text == TEST_AFC_TEXT_PART.text
318-
318+
if chunk.text and chunk.text == TEST_AFC_TEXT_PART.text:
319+
afc_text_present = True
320+
assert afc_text_present
319321
assert mock_generate_content_stream_with_afc.call_count == 2
320322
assert mock_get_function_response_parts.call_count == 2
321323

@@ -346,9 +348,11 @@ def test_generate_content_stream_with_thought_summaries(
346348
)
347349

348350
chunk = None
351+
afc_text_present = False
349352
for chunk in stream:
350-
assert chunk.text == TEST_AFC_TEXT_PART.text
351-
353+
if chunk.text and chunk.text == TEST_AFC_TEXT_PART.text:
354+
afc_text_present = True
355+
assert afc_text_present
352356
assert mock_generate_content_stream_with_afc.call_count == 2
353357
assert mock_get_function_response_parts.call_count == 2
354358

@@ -449,9 +453,11 @@ async def test_generate_content_stream_with_function_tools_used_async(
449453
)
450454

451455
chunk = None
456+
received_afc_text = False
452457
async for chunk in stream:
453-
assert chunk.text == TEST_AFC_TEXT_PART.text
454-
458+
if chunk.text and chunk.text == TEST_AFC_TEXT_PART.text:
459+
received_afc_text = True
460+
assert received_afc_text
455461
assert mock_generate_content_stream_with_afc_async.call_count == 2
456462

457463
assert mock_get_function_response_parts_async.call_count == 2
@@ -481,9 +487,11 @@ async def test_generate_content_stream_with_function_async_function_used_async(
481487
)
482488

483489
chunk = None
490+
received_afc_text = False
484491
async for chunk in stream:
485-
assert chunk.text == TEST_AFC_TEXT_PART.text
486-
492+
if chunk.text and chunk.text == TEST_AFC_TEXT_PART.text:
493+
received_afc_text = True
494+
assert received_afc_text
487495
assert mock_generate_content_stream_with_afc_async.call_count == 2
488496

489497
assert mock_get_function_response_parts_async.call_count == 2
@@ -516,9 +524,11 @@ async def test_generate_content_stream_with_thought_summaries_async(
516524
)
517525

518526
chunk = None
527+
received_afc_text = False
519528
async for chunk in stream:
520-
assert chunk.text == TEST_AFC_TEXT_PART.text
521-
529+
if chunk.text and chunk.text == TEST_AFC_TEXT_PART.text:
530+
received_afc_text = True
531+
assert received_afc_text
522532
assert mock_generate_content_stream_with_afc_async.call_count == 2
523533

524534
assert mock_get_function_response_parts_async.call_count == 2

0 commit comments

Comments
 (0)