Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 10 additions & 36 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from aiohttp import ClientSession
from deprecated import deprecated

from .protocol import ManifestSchema, ToolSchema
from .itransport import ITransport
from .protocol import ToolSchema
from .tool import ToolboxTool
from .toolbox_transport import ToolboxTransport
from .utils import identify_auth_requirements, resolve_value


Expand All @@ -33,9 +35,7 @@ class ToolboxClient:
is not provided.
"""

__base_url: str
__session: ClientSession
__manage_session: bool
__transport: ITransport

def __init__(
self,
Expand All @@ -56,15 +56,8 @@ def __init__(
should typically be managed externally.
client_headers: Headers to include in each request sent through this client.
"""
self.__base_url = url

# If no aiohttp.ClientSession is provided, make our own
self.__manage_session = False
if session is None:
self.__manage_session = True
session = ClientSession()
self.__session = session

self.__transport = ToolboxTransport(url, session)
self.__client_headers = client_headers if client_headers is not None else {}

def __parse_tool(
Expand Down Expand Up @@ -103,8 +96,7 @@ def __parse_tool(
)

tool = ToolboxTool(
session=self.__session,
base_url=self.__base_url,
transport=self.__transport,
name=name,
description=schema.description,
# create a read-only values to prevent mutation
Expand Down Expand Up @@ -149,8 +141,7 @@ async def close(self):
If the session was provided externally during initialization, the caller
is responsible for its lifecycle.
"""
if self.__manage_session and not self.__session.closed:
await self.__session.close()
await self.__transport.close()

async def load_tool(
self,
Expand Down Expand Up @@ -191,16 +182,7 @@ async def load_tool(
for name, val in self.__client_headers.items()
}

# request the definition of the tool from the server
url = f"{self.__base_url}/api/tool/{name}"
async with self.__session.get(url, headers=resolved_headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
manifest: ManifestSchema = ManifestSchema(**json)
manifest = await self.__transport.tool_get(name, resolved_headers)

# parse the provided definition to a tool
if name not in manifest.tools:
Expand Down Expand Up @@ -274,16 +256,8 @@ async def load_toolset(
header_name: await resolve_value(original_headers[header_name])
for header_name in original_headers
}
# Request the definition of the toolset from the server
url = f"{self.__base_url}/api/toolset/{name or ''}"
async with self.__session.get(url, headers=resolved_headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
manifest: ManifestSchema = ManifestSchema(**json)

manifest = await self.__transport.tools_list(name, resolved_headers)

tools: list[ToolboxTool] = []
overall_used_auth_keys: set[str] = set()
Expand Down
58 changes: 58 additions & 0 deletions packages/toolbox-core/src/toolbox_core/itransport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Mapping, Optional

from .protocol import ManifestSchema


class ITransport(ABC):
"""Defines the contract for a 'smart' transport that handles both
protocol formatting and network communication.
"""

@property
@abstractmethod
def base_url(self) -> str:
"""The base URL for the transport."""
pass

@abstractmethod
async def tool_get(
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
) -> ManifestSchema:
"""Gets a single tool from the server."""
pass

@abstractmethod
async def tools_list(
self,
toolset_name: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
) -> ManifestSchema:
"""Lists available tools from the server."""
pass

@abstractmethod
async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> str:
"""Invokes a specific tool on the server."""
pass

@abstractmethod
async def close(self):
"""Closes any underlying connections."""
pass
48 changes: 12 additions & 36 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
from warnings import warn

from aiohttp import ClientSession

from .itransport import ITransport
from .protocol import ParameterSchema
from .utils import (
create_func_docstring,
Expand All @@ -46,8 +45,7 @@ class ToolboxTool:

def __init__(
self,
session: ClientSession,
base_url: str,
transport: ITransport,
name: str,
description: str,
params: Sequence[ParameterSchema],
Expand All @@ -68,8 +66,7 @@ def __init__(
Toolbox server.

Args:
session: The `aiohttp.ClientSession` used for making API requests.
base_url: The base URL of the Toolbox server API.
transport: The transport used for making API requests.
name: The name of the remote tool.
description: The description of the remote tool.
params: The args of the tool.
Expand All @@ -84,9 +81,7 @@ def __init__(
client_headers: Client specific headers bound to the tool.
"""
# used to invoke the toolbox API
self.__session: ClientSession = session
self.__base_url: str = base_url
self.__url = f"{base_url}/api/tool/{name}/invoke"
self.__transport = transport
self.__description = description
self.__params = params
self.__pydantic_model = params_to_pydantic_model(name, self.__params)
Expand Down Expand Up @@ -120,17 +115,6 @@ def __init__(
# map of client headers to their value/callable/coroutine
self.__client_headers = client_headers

# ID tokens contain sensitive user information (claims). Transmitting
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if (
required_authn_params or required_authz_tokens or client_headers
) and not self.__url.startswith("https://"):
warn(
"Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)

@property
def _name(self) -> str:
return self.__name__
Expand Down Expand Up @@ -171,8 +155,7 @@ def _client_headers(

def __copy(
self,
session: Optional[ClientSession] = None,
base_url: Optional[str] = None,
transport: Optional[ITransport] = None,
name: Optional[str] = None,
description: Optional[str] = None,
params: Optional[Sequence[ParameterSchema]] = None,
Expand All @@ -192,8 +175,7 @@ def __copy(
Creates a copy of the ToolboxTool, overriding specific fields.

Args:
session: The `aiohttp.ClientSession` used for making API requests.
base_url: The base URL of the Toolbox server API.
transport: The transport used for making API requests.
name: The name of the remote tool.
description: The description of the remote tool.
params: The args of the tool.
Expand All @@ -209,8 +191,7 @@ def __copy(
"""
check = lambda val, default: val if val is not None else default
return ToolboxTool(
session=check(session, self.__session),
base_url=check(base_url, self.__base_url),
transport=check(transport, self.__transport),
name=check(name, self.__name__),
description=check(description, self.__description),
params=check(params, self.__params),
Expand Down Expand Up @@ -291,16 +272,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
token_getter
)

async with self.__session.post(
self.__url,
json=payload,
headers=headers,
) as resp:
body = await resp.json()
if not resp.ok:
err = body.get("error", f"unexpected status from server: {resp.status}")
raise Exception(err)
return body.get("result", body)
return await self.__transport.tool_invoke(
self.__name__,
payload,
headers,
)

def add_auth_token_getters(
self,
Expand Down
95 changes: 95 additions & 0 deletions packages/toolbox-core/src/toolbox_core/toolbox_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Mapping, Optional
from warnings import warn

from aiohttp import ClientSession

from .itransport import ITransport
from .protocol import ManifestSchema


class ToolboxTransport(ITransport):
"""Transport for the native Toolbox protocol."""

def __init__(self, base_url: str, session: Optional[ClientSession]):
self.__base_url = base_url

# If no aiohttp.ClientSession is provided, make our own
self.__manage_session = False
if session is not None:
self.__session = session
else:
self.__manage_session = True
self.__session = ClientSession()

@property
def base_url(self) -> str:
"""The base URL for the transport."""
return self.__base_url

async def __get_manifest(
self, url: str, headers: Optional[Mapping[str, str]]
) -> ManifestSchema:
"""Helper method to perform GET requests and parse the ManifestSchema."""
async with self.__session.get(url, headers=headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
return ManifestSchema(**json)

async def tool_get(
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
) -> ManifestSchema:
url = f"{self.__base_url}/api/tool/{tool_name}"
return await self.__get_manifest(url, headers)

async def tools_list(
self,
toolset_name: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
) -> ManifestSchema:
url = f"{self.__base_url}/api/toolset/{toolset_name or ''}"
return await self.__get_manifest(url, headers)

async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> str:
# ID tokens contain sensitive user information (claims). Transmitting
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if self.base_url.startswith("http://") and headers:
warn(
"Sending data token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)
url = f"{self.__base_url}/api/tool/{tool_name}/invoke"
async with self.__session.post(
url,
json=arguments,
headers=headers,
) as resp:
body = await resp.json()
if not resp.ok:
err = body.get("error", f"unexpected status from server: {resp.status}")
raise Exception(err)
return body.get("result")

async def close(self):
if self.__manage_session and not self.__session.closed:
await self.__session.close()
Loading