Merged
@@ -6,6 +6,7 @@
import typing
from typing import (
    TYPE_CHECKING,
    cast,
    Any,
    AsyncGenerator,
    Dict,
@@ -38,6 +39,7 @@
    MessageRole,
    ThinkingBlock,
    TextBlock,
    ToolCallBlock,
)
from llama_index.core.bridge.pydantic import BaseModel, Field, PrivateAttr
from llama_index.core.callbacks import CallbackManager
@@ -376,7 +378,6 @@ def _stream_chat(

        def gen() -> ChatResponseGen:
            content = ""
            existing_tool_calls = []
            thoughts = ""
            for r in response:
                if not r.candidates:
@@ -390,14 +391,11 @@ def gen() -> ChatResponseGen:
                else:
                    content += content_delta
                llama_resp = chat_from_gemini_response(r)
                existing_tool_calls.extend(
                    llama_resp.message.additional_kwargs.get("tool_calls", [])
                )
                llama_resp.delta = content_delta
                llama_resp.message.blocks = [TextBlock(text=content)]
                llama_resp.message.blocks.append(ThinkingBlock(content=thoughts))
                llama_resp.message.additional_kwargs["tool_calls"] = existing_tool_calls
                yield llama_resp
                if content:
                    llama_resp.message.blocks.append(TextBlock(text=content))
                if thoughts:
                    llama_resp.message.blocks.append(ThinkingBlock(content=thoughts))
                yield llama_resp

            if self.use_file_api:
                asyncio.run(
@@ -429,7 +427,6 @@ async def _astream_chat(

        async def gen() -> ChatResponseAsyncGen:
            content = ""
            existing_tool_calls = []
            thoughts = ""
            async for r in await chat.send_message_stream(
                next_msg.parts if isinstance(next_msg, types.Content) else next_msg
@@ -448,19 +445,15 @@ async def gen() -> ChatResponseAsyncGen:
                else:
                    content += content_delta
                llama_resp = chat_from_gemini_response(r)
                existing_tool_calls.extend(
                    llama_resp.message.additional_kwargs.get(
                        "tool_calls", []
                    )
                )
                llama_resp.delta = content_delta
                llama_resp.message.blocks = [TextBlock(text=content)]
                llama_resp.message.blocks.append(
                    ThinkingBlock(content=thoughts)
                )
                llama_resp.message.additional_kwargs["tool_calls"] = (
                    existing_tool_calls
                )
                if content:
                    llama_resp.message.blocks.append(
                        TextBlock(text=content)
                    )
                if thoughts:
                    llama_resp.message.blocks.append(
                        ThinkingBlock(content=thoughts)
                    )
                yield llama_resp

            if self.use_file_api:
@@ -551,7 +544,11 @@ def get_tool_calls_from_response(
        **kwargs: Any,
    ) -> List[ToolSelection]:
        """Predict and call the tool."""
        tool_calls = response.message.additional_kwargs.get("tool_calls", [])
        tool_calls = [
            block
            for block in response.message.blocks
            if isinstance(block, ToolCallBlock)
        ]

        if len(tool_calls) < 1:
            if error_on_no_tool_call:
@@ -565,9 +562,9 @@
        for tool_call in tool_calls:
            tool_selections.append(
                ToolSelection(
                    tool_id=tool_call["name"],
                    tool_name=tool_call["name"],
                    tool_kwargs=tool_call["args"],
                    tool_id=tool_call.tool_name,
                    tool_name=tool_call.tool_name,
                    tool_kwargs=cast(Dict[str, Any], tool_call.tool_kwargs),
                )
            )
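
Net effect of this hunk: get_tool_calls_from_response now reads typed ToolCallBlock entries out of message.blocks instead of the untyped dicts previously stored in additional_kwargs["tool_calls"]. A minimal sketch of the new calling convention follows; the message contents are made up for illustration, and the import path assumes ToolCallBlock is exported from llama_index.core.base.llms.types alongside the other block types, as this diff's imports suggest.

# Sketch only: extracting tool calls from a chat message after this change.
from llama_index.core.base.llms.types import (
    ChatMessage,
    MessageRole,
    ToolCallBlock,
)

message = ChatMessage(
    role=MessageRole.ASSISTANT,
    blocks=[
        ToolCallBlock(
            tool_call_id="call-1",  # hypothetical values
            tool_name="search",
            tool_kwargs={"query": "puppies"},
        )
    ],
)

# The rewritten method does essentially this:
tool_calls = [b for b in message.blocks if isinstance(b, ToolCallBlock)]
assert tool_calls[0].tool_name == "search"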

----------------------------------------

@@ -1,16 +1,9 @@
import asyncio
import json
import logging
from collections.abc import Sequence
from io import BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Union,
    Optional,
    Type,
    Tuple,
)
from typing import TYPE_CHECKING, Any, Dict, Union, Optional, Type, Tuple, cast
import typing

import google.genai.types as types
@@ -29,6 +22,7 @@
    DocumentBlock,
    VideoBlock,
    ThinkingBlock,
    ToolCallBlock,
)
from llama_index.core.program.utils import _repair_incomplete_json
from tenacity import (
@@ -188,16 +182,33 @@ def chat_from_gemini_response(
            )
            additional_kwargs["thought_signatures"].append(part.thought_signature)
        if part.function_call:
            if "tool_calls" not in additional_kwargs:
                additional_kwargs["tool_calls"] = []
            additional_kwargs["tool_calls"].append(
                {
                    "id": part.function_call.id if part.function_call.id else "",
                    "name": part.function_call.name,
                    "args": part.function_call.args,
                    "thought_signature": part.thought_signature,
> Collaborator: I think we might be losing thought signatures here? Although not totally sure either.
>
> Member (author): Yeah, we might be, but ToolCallBlock does not have something that could properly handle the signature, and I did not want to open it to additional_kwargs chaos... Although I can do that if we think it is necessary.
                }
            if (
                part.thought_signature
                not in additional_kwargs["thought_signatures"]
            ):
                additional_kwargs["thought_signatures"].append(
                    part.thought_signature
                )
            content_blocks.append(
                ToolCallBlock(
                    tool_call_id=part.function_call.id or "",
                    tool_name=part.function_call.name or "",
                    tool_kwargs=part.function_call.args or {},
                )
            )
        if part.function_response:
            # follow the same pattern as for transforming a chat message into a gemini message: if it's a function response, package it alone and return it
            additional_kwargs["tool_call_id"] = part.function_response.id
            role = ROLES_FROM_GEMINI[top_candidate.content.role]
            print("RESPONSE", json.dumps(part.function_response.response))
            return ChatResponse(
                message=ChatMessage(
                    role=role, content=json.dumps(part.function_response.response)
                ),
                raw=raw,
                additional_kwargs=additional_kwargs,
            )

    if thought_tokens:
        thinking_blocks = [
            i
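
The review thread above flags that ToolCallBlock has no field for Gemini's thought signature. In the rewritten handler the two pieces therefore travel separately: the call itself becomes a ToolCallBlock in content_blocks, while its signature is appended to additional_kwargs["thought_signatures"]. A minimal sketch of the resulting shape, with made-up values:

# Sketch only, hypothetical values: where each piece of a Gemini
# function_call part ends up after the hunk above.
additional_kwargs = {
    "thought_signatures": [b"opaque-signature-bytes"],  # signature parked here
}
content_blocks = [
    ToolCallBlock(  # the call itself, now a typed content block
        tool_call_id="",
        tool_name="search",
        tool_kwargs={"query": "puppies"},
    ),
]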
@@ -271,6 +282,7 @@ async def chat_message_to_gemini(
    message: ChatMessage, use_file_api: bool = False, client: Optional[Client] = None
) -> Union[types.Content, types.File]:
    """Convert ChatMessages to Gemini-specific history, including ImageDocuments."""
    unique_tool_calls = []
    parts = []
    part = None
    for index, block in enumerate(message.blocks):
@@ -326,6 +338,11 @@ async def chat_message_to_gemini(
                part.thought_signature = block.additional_information.get(
                    "thought_signature", None
                )
        elif isinstance(block, ToolCallBlock):
            part = types.Part.from_function_call(
                name=block.tool_name, args=cast(Dict[str, Any], block.tool_kwargs)
            )
            unique_tool_calls.append((block.tool_name, str(block.tool_kwargs)))
        else:
            msg = f"Unsupported content block type: {type(block).__name__}"
            raise ValueError(msg)
@@ -343,15 +360,20 @@

    for tool_call in message.additional_kwargs.get("tool_calls", []):
        if isinstance(tool_call, dict):
            part = types.Part.from_function_call(
                name=tool_call.get("name"), args=tool_call.get("args")
            )
            part.thought_signature = tool_call.get("thought_signature")
            if (
                tool_call.get("name", ""),
                str(tool_call.get("args", {})),
            ) not in unique_tool_calls:
                part = types.Part.from_function_call(
                    name=tool_call.get("name", ""), args=tool_call.get("args", {})
                )
                part.thought_signature = tool_call.get("thought_signature")
        else:
            part = types.Part.from_function_call(
                name=tool_call.name, args=tool_call.args
            )
            part.thought_signature = tool_call.thought_signature
            if (tool_call.name, str(tool_call.args)) not in unique_tool_calls:
                part = types.Part.from_function_call(
                    name=tool_call.name, args=tool_call.args
                )
                part.thought_signature = tool_call.thought_signature
        parts.append(part)

    # the tool call id is the name of the tool
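
One subtlety in the loop above: unique_tool_calls, filled while converting ToolCallBlocks, is keyed on (tool_name, str(tool_kwargs)) and used to skip legacy additional_kwargs["tool_calls"] entries that describe the same call, so a call recorded in both places yields a single Part. A small sketch of the key check, with made-up values:

# Sketch only: the dedup key used above, with hypothetical values.
unique_tool_calls = [("search", str({"query": "puppies"}))]

legacy_call = {"name": "search", "args": {"query": "puppies"}}
duplicate = (
    legacy_call["name"],
    str(legacy_call["args"]),
) in unique_tool_calls
assert duplicate  # the legacy entry is skipped rather than emitted twice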
----------------------------------------

@@ -27,7 +27,7 @@ dev = [

[project]
name = "llama-index-llms-google-genai"
version = "0.6.2"
version = "0.7.0"
description = "llama-index llms google genai integration"
authors = [{name = "Your Name", email = "[email protected]"}]
requires-python = ">=3.9,<4.0"
@@ -36,7 +36,7 @@ license = "MIT"
dependencies = [
"pillow>=10.2.0",
"google-genai>=1.24.0,<2",
"llama-index-core>=0.14.3,<0.15",
"llama-index-core>=0.14.5,<0.15",
]

[tool.codespell]
----------------------------------------

@@ -11,6 +11,7 @@
    TextBlock,
    VideoBlock,
    ThinkingBlock,
    ToolCallBlock,
)
from llama_index.core.llms.llm import ToolSelection
from llama_index.core.program.function_program import get_function_tool
@@ -564,8 +565,16 @@ def test_tool_required_integration(llm: GoogleGenAI) -> None:
        tools=[search_tool],
        tool_required=True,
    )
    assert response.message.additional_kwargs.get("tool_calls") is not None
    assert len(response.message.additional_kwargs["tool_calls"]) > 0
    assert (
        len(
            [
                block
                for block in response.message.blocks
                if isinstance(block, ToolCallBlock)
            ]
        )
        > 0
    )

    # Test with tool_required=False
    response = llm.chat_with_tools(
@@ -729,6 +738,10 @@ async def test_prepare_chat_params_more_than_2_tool_calls():
                )
            ],
        ),
        ChatMessage(
            blocks=[ToolCallBlock(tool_name="get_available_tools", tool_kwargs={})],
            role=MessageRole.ASSISTANT,
        ),
        ChatMessage(
            content="Let me search for puppies.",
            role=MessageRole.ASSISTANT,
@@ -777,10 +790,11 @@ async def test_prepare_chat_params_more_than_2_tool_calls():
text="The user is asking me for a puppy, so I should search for puppies using the available tools.",
thought=True,
),
types.Part.from_function_call(name="get_available_tools", args={}),
types.Part(text="Let me search for puppies."),
types.Part.from_function_call(name="tool_1", args=None),
types.Part.from_function_call(name="tool_2", args=None),
types.Part.from_function_call(name="tool_3", args=None),
types.Part.from_function_call(name="tool_1", args={}),
types.Part.from_function_call(name="tool_2", args={}),
types.Part.from_function_call(name="tool_3", args={}),
],
role=MessageRole.MODEL,
),
@@ -872,6 +886,10 @@ def test_cached_content_in_response() -> None:
    mock_response.candidates[0].content.parts[0].text = "Test response"
    mock_response.candidates[0].content.parts[0].thought = False
    mock_response.candidates[0].content.parts[0].inline_data = None
    mock_response.candidates[0].content.parts[0].function_call.id = ""
    mock_response.candidates[0].content.parts[0].function_call.name = "hello"
    mock_response.candidates[0].content.parts[0].function_call.args = {}
    mock_response.candidates[0].content.parts[0].function_response = None
    mock_response.prompt_feedback = None
    mock_response.usage_metadata = None
    mock_response.function_calls = None
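
The new function_call / function_response assignments in these mocks exist because chat_from_gemini_response now inspects both attributes on every part. A bare MagicMock auto-creates truthy attributes on first access, so leaving them unset would send the parser down an unintended branch; pinning them to concrete values keeps the tests deterministic. A quick illustration of that MagicMock behavior:

# Why the mocks pin function_call / function_response explicitly:
from unittest.mock import MagicMock

part = MagicMock()
assert part.function_call      # auto-created attribute is truthy by default
part.function_call = None      # pin it so the part is not parsed as a tool call
part.function_response = None
assert not part.function_call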
@@ -899,6 +917,10 @@ def test_cached_content_without_cached_content() -> None:
    mock_response.candidates[0].content.parts[0].text = "Test response"
    mock_response.candidates[0].content.parts[0].thought = False
    mock_response.candidates[0].content.parts[0].inline_data = None
    mock_response.candidates[0].content.parts[0].function_call.id = ""
    mock_response.candidates[0].content.parts[0].function_call.name = "hello"
    mock_response.candidates[0].content.parts[0].function_call.args = {}
    mock_response.candidates[0].content.parts[0].function_response = None
    mock_response.prompt_feedback = None
    mock_response.usage_metadata = None
    mock_response.function_calls = None
@@ -923,9 +945,15 @@ def test_thoughts_in_response() -> None:
    mock_response.candidates[0].content.parts[0].text = "This is a thought."
    mock_response.candidates[0].content.parts[0].inline_data = None
    mock_response.candidates[0].content.parts[0].thought = True
    mock_response.candidates[0].content.parts[0].function_call.id = ""
    mock_response.candidates[0].content.parts[0].function_call.name = "hello"
    mock_response.candidates[0].content.parts[0].function_call.args = {}
    mock_response.candidates[0].content.parts[1].text = "This is not a thought."
    mock_response.candidates[0].content.parts[1].inline_data = None
    mock_response.candidates[0].content.parts[1].thought = None
    mock_response.candidates[0].content.parts[1].function_call = None
    mock_response.candidates[0].content.parts[1].function_response = None
    mock_response.candidates[0].content.parts[0].function_response = None
    mock_response.candidates[0].content.parts[0].model_dump = MagicMock(return_value={})
    mock_response.candidates[0].content.parts[1].model_dump = MagicMock(return_value={})
    mock_response.prompt_feedback = None
@@ -967,6 +995,8 @@ def test_thoughts_without_thought_response() -> None:
    mock_response.candidates[0].content.parts[0].text = "This is not a thought."
    mock_response.candidates[0].content.parts[0].inline_data = None
    mock_response.candidates[0].content.parts[0].thought = None
    mock_response.candidates[0].content.parts[0].function_call = None
    mock_response.candidates[0].content.parts[0].function_response = None
    mock_response.prompt_feedback = None
    mock_response.usage_metadata = None
    mock_response.function_calls = None
@@ -1084,6 +1114,8 @@ def test_built_in_tool_in_response() -> None:
    ].text = "Test response with search results"
    mock_response.candidates[0].content.parts[0].inline_data = None
    mock_response.candidates[0].content.parts[0].thought = None
    mock_response.candidates[0].content.parts[0].function_call = None
    mock_response.candidates[0].content.parts[0].function_response = None
    mock_response.prompt_feedback = None
    mock_response.usage_metadata = MagicMock()
    mock_response.usage_metadata.model_dump.return_value = {
@@ -1523,6 +1555,8 @@ def test_code_execution_response_parts() -> None:
    )
    mock_text_part.inline_data = None
    mock_text_part.thought = None
    mock_text_part.function_call = None
    mock_text_part.function_response = None

    mock_code_part = MagicMock()
    mock_code_part.text = None
@@ -1532,6 +1566,8 @@ def test_code_execution_response_parts() -> None:
"code": "def is_prime(n):\n if n < 2:\n return False\n for i in range(2, int(n**0.5) + 1):\n if n % i == 0:\n return False\n return True\n\nprimes = []\nn = 2\nwhile len(primes) < 50:\n if is_prime(n):\n primes.append(n)\n n += 1\n\nprint(f'Sum of first 50 primes: {sum(primes)}')",
"language": types.Language.PYTHON,
}
mock_code_part.function_call = None
mock_code_part.function_response = None

mock_result_part = MagicMock()
mock_result_part.text = None
@@ -1541,11 +1577,15 @@ def test_code_execution_response_parts() -> None:
"outcome": types.Outcome.OUTCOME_OK,
"output": "Sum of first 50 primes: 5117",
}
mock_result_part.function_call = None
mock_result_part.function_response = None

mock_final_text_part = MagicMock()
mock_final_text_part.text = "The sum of the first 50 prime numbers is 5117."
mock_final_text_part.inline_data = None
mock_final_text_part.thought = None
mock_final_text_part.function_call = None
mock_final_text_part.function_response = None

mock_candidate.content.parts = [
mock_text_part,