Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 42 additions & 1 deletion src/strands/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,30 @@
T = TypeVar("T", bound=pydantic.BaseModel)


def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None:
"""Validate that gemini_tools does not contain FunctionDeclarations.

Gemini-specific tools should only include tools that cannot be represented
as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse).
Standard function calling tools should use the tools interface instead.

Args:
gemini_tools: List of Gemini tools to validate

Raises:
ValueError: If any tool contains function_declarations
"""
for tool in gemini_tools:
# Check if the tool has function_declarations attribute and it's not empty
if hasattr(tool, "function_declarations") and tool.function_declarations:
raise ValueError(
"gemini_tools should not contain FunctionDeclarations. "
"Use the standard tools interface for function calling tools. "
"gemini_tools is reserved for Gemini-specific tools like "
"GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch."
)


class GeminiModel(Model):
"""Google Gemini model provider implementation.

Expand All @@ -40,10 +64,16 @@ class GeminiConfig(TypedDict, total=False):
params: Additional model parameters (e.g., temperature).
For a complete list of supported parameters, see
https://ai.google.dev/api/generate-content#generationconfig.
gemini_tools: Gemini-specific tools that are not FunctionDeclarations
(e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch).
Use the standard tools interface for function calling tools.
For a complete list of supported tools, see
https://ai.google.dev/api/caching#Tool
"""

model_id: Required[str]
params: dict[str, Any]
gemini_tools: list[genai.types.Tool]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pgrayy thoughts on naming here? Would we want to be consistent among providers - maybe something like built_in_tools?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also think built_in_tools is better — it’s more consistent across providers and the meaning is clearer. If you agree, I’ll make the change even though the PR has already been approved.


def __init__(
self,
Expand All @@ -61,6 +91,10 @@ def __init__(
validate_config_keys(model_config, GeminiModel.GeminiConfig)
self.config = GeminiModel.GeminiConfig(**model_config)

# Validate gemini_tools if provided
if "gemini_tools" in self.config:
_validate_gemini_tools(self.config["gemini_tools"])

logger.debug("config=<%s> | initializing", self.config)

self.client_args = client_args or {}
Expand All @@ -72,6 +106,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type:
Args:
**model_config: Configuration overrides.
"""
# Validate gemini_tools if provided
if "gemini_tools" in model_config:
_validate_gemini_tools(model_config["gemini_tools"])

self.config.update(model_config)

@override
Expand Down Expand Up @@ -181,7 +219,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
Return:
Gemini tool list.
"""
return [
tools = [
genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
Expand All @@ -193,6 +231,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
],
),
]
if self.config.get("gemini_tools"):
tools.extend(self.config.get("gemini_tools", []))
return tools

def _format_request_config(
self,
Expand Down
83 changes: 83 additions & 0 deletions tests/strands/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,86 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath
"model": model_id,
}
gemini_client.aio.models.generate_content.assert_called_with(**exp_request)


def test_gemini_tools_validation_rejects_function_declarations(model_id):
tool_with_function_declarations = genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
name="test_function",
description="A test function",
)
]
)

with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations])


def test_gemini_tools_validation_allows_non_function_tools(model_id):
tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch())

model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search])
assert "gemini_tools" in model.config


def test_gemini_tools_validation_on_update_config(model):
tool_with_function_declarations = genai.types.Tool(
function_declarations=[
genai.types.FunctionDeclaration(
name="test_function",
description="A test function",
)
]
)

with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
model.update_config(gemini_tools=[tool_with_function_declarations])


@pytest.mark.asyncio
async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id):
google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch())
model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool])

await anext(model.stream(messages))

exp_request = {
"config": {
"tools": [
{"function_declarations": []},
{"google_search": {}},
]
},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
"model": model_id,
}
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)


@pytest.mark.asyncio
async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id):
code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution())
model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool])

await anext(model.stream(messages, tool_specs=[tool_spec]))

exp_request = {
"config": {
"tools": [
{
"function_declarations": [
{
"description": tool_spec["description"],
"name": tool_spec["name"],
"parameters_json_schema": tool_spec["inputSchema"]["json"],
}
]
},
{"code_execution": {}},
]
},
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
"model": model_id,
}
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
20 changes: 20 additions & 0 deletions tests_integ/models/test_model_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pydantic
import pytest
from google import genai

import strands
from strands import Agent
Expand All @@ -21,6 +22,16 @@ def model():
)


@pytest.fixture
def gemini_tool_model():
return GeminiModel(
client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
model_id="gemini-2.5-flash",
params={"temperature": 0.15}, # Lower temperature for consistent test behavior
gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())],
)


@pytest.fixture
def tools():
@strands.tool
Expand Down Expand Up @@ -175,3 +186,12 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow
tru_color = assistant_agent.structured_output(type(yellow_color), content)
exp_color = yellow_color
assert tru_color == exp_color


def test_agent_with_gemini_code_execution_tool(gemini_tool_model):
# FIXME: Should verify tool usage history, but currently validates by solving a complex calculation
system_prompt = "Execute calculations and output only the numerical result. No explanations or units needed."
agent = Agent(model=gemini_tool_model, system_prompt=system_prompt)
result = agent("Calculate 931567 * 81364")
text = result.message.get("content", [{}])[0].get("text", "")
assert "75796017388" in text