Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,7 @@ def __init__(
client_headers: Headers to include in each request sent through this client.
"""

# If no aiohttp.ClientSession is provided, make our own
manage_session = False
if session is None:
manage_session = True
session = ClientSession()
self.__transport = ToolboxTransport(url, session, manage_session)
self.__transport = ToolboxTransport(url, session)
self.__client_headers = client_headers if client_headers is not None else {}

def __parse_tool(
Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-core/src/toolbox_core/itransport.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ async def tools_list(
@abstractmethod
async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> dict:
) -> str:
"""Invokes a specific tool on the server."""
pass

Expand Down
14 changes: 1 addition & 13 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,17 +128,6 @@ def __init__(
# map of client headers to their value/callable/coroutine
self.__client_headers = client_headers

# ID tokens contain sensitive user information (claims). Transmitting
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if self.__transport.base_url.startswith("http://") and (
required_authn_params or required_authz_tokens or client_headers
):
warn(
"Sending ID token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)

@property
def _name(self) -> str:
return self.__name__
Expand Down Expand Up @@ -296,12 +285,11 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
token_getter
)

body = await self.__transport.tool_invoke(
return await self.__transport.tool_invoke(
self.__name__,
payload,
headers,
)
return body.get("result", body)

def add_auth_token_getters(
self,
Expand Down
46 changes: 30 additions & 16 deletions packages/toolbox-core/src/toolbox_core/toolbox_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Mapping, Optional
from warnings import warn

from aiohttp import ClientSession

Expand All @@ -23,20 +24,26 @@
class ToolboxTransport(ITransport):
"""Transport for the native Toolbox protocol."""

def __init__(self, base_url: str, session: ClientSession, manage_session: bool):
def __init__(self, base_url: str, session: Optional[ClientSession]):
self.__base_url = base_url
self.__session = session
self.__manage_session = manage_session

# If no aiohttp.ClientSession is provided, make our own
self.__manage_session = False
if session is not None:
self.__session = session
else:
self.__manage_session = True
self.__session = ClientSession()

@property
def base_url(self) -> str:
"""The base URL for the transport."""
return self.__base_url

async def tool_get(
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
async def _get_manifest(
self, url: str, headers: Optional[Mapping[str, str]]
) -> ManifestSchema:
url = f"{self.__base_url}/api/tool/{tool_name}"
"""Helper method to perform GET requests and parse the ManifestSchema."""
async with self.__session.get(url, headers=headers) as response:
if not response.ok:
error_text = await response.text()
Expand All @@ -46,24 +53,31 @@ async def tool_get(
json = await response.json()
return ManifestSchema(**json)

async def tool_get(
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
) -> ManifestSchema:
url = f"{self.__base_url}/api/tool/{tool_name}"
return await self._get_manifest(url, headers)

async def tools_list(
self,
toolset_name: Optional[str] = None,
headers: Optional[Mapping[str, str]] = None,
) -> ManifestSchema:
url = f"{self.__base_url}/api/toolset/{toolset_name or ''}"
async with self.__session.get(url, headers=headers) as response:
if not response.ok:
error_text = await response.text()
raise RuntimeError(
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
)
json = await response.json()
return ManifestSchema(**json)
return await self._get_manifest(url, headers)

async def tool_invoke(
self, tool_name: str, arguments: dict, headers: Mapping[str, str]
) -> dict:
) -> str:
# ID tokens contain sensitive user information (claims). Transmitting
# these over HTTP exposes the data to interception and unauthorized
# access. Always use HTTPS to ensure secure communication and protect
# user privacy.
if self.base_url.startswith("http://") and headers:
warn(
"Sending data token over HTTP. User data may be exposed. Use HTTPS for secure communication."
)
url = f"{self.__base_url}/api/tool/{tool_name}/invoke"
async with self.__session.post(
url,
Expand All @@ -74,7 +88,7 @@ async def tool_invoke(
if not resp.ok:
err = body.get("error", f"unexpected status from server: {resp.status}")
raise Exception(err)
return body
return body.get("result")

async def close(self):
if self.__manage_session and not self.__session.closed:
Expand Down
145 changes: 9 additions & 136 deletions packages/toolbox-core/tests/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def toolbox_tool(
sample_tool_description: str,
) -> ToolboxTool:
"""Fixture for a ToolboxTool instance with common test setup."""
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
transport = ToolboxTransport(TEST_BASE_URL, http_session)
return ToolboxTool(
transport=transport,
name=TEST_TOOL_NAME,
Expand Down Expand Up @@ -231,7 +231,7 @@ async def test_tool_creation_callable_and_run(

with aioresponses() as m:
m.post(invoke_url, status=200, payload=mock_server_response_body)
transport = ToolboxTransport(base_url, http_session, False)
transport = ToolboxTransport(base_url, http_session)

tool_instance = ToolboxTool(
transport=transport,
Expand Down Expand Up @@ -277,7 +277,7 @@ async def test_tool_run_with_pydantic_validation_error(

with aioresponses() as m:
m.post(invoke_url, status=200, payload={"result": "Should not be called"})
transport = ToolboxTransport(base_url, http_session, False)
transport = ToolboxTransport(base_url, http_session)

tool_instance = ToolboxTool(
transport=transport,
Expand Down Expand Up @@ -368,7 +368,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti
"""Tests basic tool initialization without headers or auth."""
with catch_warnings(record=True) as record:
simplefilter("always")
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)

tool_instance = ToolboxTool(
transport=transport,
Expand Down Expand Up @@ -398,7 +398,7 @@ def test_tool_init_with_client_headers(
http_session, sample_tool_params, sample_tool_description, static_client_header
):
"""Tests tool initialization *with* client headers."""
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
tool_instance = ToolboxTool(
transport=transport,
name=TEST_TOOL_NAME,
Expand All @@ -422,7 +422,7 @@ def test_tool_init_header_auth_conflict(
):
"""Tests ValueError on init if client header conflicts with auth token."""
conflicting_client_header = {auth_header_key: "some-client-value"}
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)

with pytest.raises(
ValueError, match=f"Client header\\(s\\) `{auth_header_key}` already registered"
Expand All @@ -449,7 +449,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header(
Tests ValueError when add_auth_token_getters introduces an auth service
whose token name conflicts with an existing client header.
"""
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
tool_instance = ToolboxTool(
transport=transport,
name="tool_with_client_header",
Expand Down Expand Up @@ -485,7 +485,7 @@ def test_add_auth_token_getters_unused_token(
Tests ValueError when add_auth_token_getters is called with a getter for
an unused authentication service.
"""
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
tool_instance = ToolboxTool(
transport=transport,
name=TEST_TOOL_NAME,
Expand Down Expand Up @@ -514,7 +514,7 @@ def test_add_auth_token_getter_unused_token(
Tests ValueError when add_auth_token_getters is called with a getter for
an unused authentication service.
"""
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)
transport = ToolboxTransport(HTTPS_BASE_URL, http_session)
tool_instance = ToolboxTool(
transport=transport,
name=TEST_TOOL_NAME,
Expand Down Expand Up @@ -622,130 +622,3 @@ def test_toolbox_tool_underscore_client_headers_property(toolbox_tool: ToolboxTo
# Verify immutability
with pytest.raises(TypeError):
client_headers["new_header"] = "new_value"


# --- Test for the HTTP Warning ---
@pytest.mark.parametrize(
"trigger_condition_params",
[
{"client_headers": {"X-Some-Header": "value"}},
{"required_authn_params": {"param1": ["auth-service1"]}},
{"required_authz_tokens": ["auth-service2"]},
{
"client_headers": {"X-Some-Header": "value"},
"required_authn_params": {"param1": ["auth-service1"]},
},
{
"client_headers": {"X-Some-Header": "value"},
"required_authz_tokens": ["auth-service2"],
},
{
"required_authn_params": {"param1": ["auth-service1"]},
"required_authz_tokens": ["auth-service2"],
},
{
"client_headers": {"X-Some-Header": "value"},
"required_authn_params": {"param1": ["auth-service1"]},
"required_authz_tokens": ["auth-service2"],
},
],
ids=[
"client_headers_only",
"authn_params_only",
"authz_tokens_only",
"headers_and_authn",
"headers_and_authz",
"authn_and_authz",
"all_three_conditions",
],
)
def test_tool_init_http_warning_when_sensitive_info_over_http(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
trigger_condition_params: dict,
):
"""
Tests that a UserWarning is issued if client headers, auth params, or
auth tokens are present and the base_url is HTTP.
"""
expected_warning_message: str = (
"Sending ID token over HTTP. User data may be exposed. "
"Use HTTPS for secure communication."
)
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)
init_kwargs = {
"transport": transport,
"name": "http_warning_tool",
"description": sample_tool_description,
"params": sample_tool_params,
"required_authn_params": {},
"required_authz_tokens": [],
"auth_service_token_getters": {},
"bound_params": {},
"client_headers": {},
}
# Apply the specific conditions for this parametrized test
init_kwargs.update(trigger_condition_params)

with pytest.warns(UserWarning, match=expected_warning_message):
ToolboxTool(**init_kwargs)


def test_tool_init_no_http_warning_if_https(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
static_client_header: dict,
):
"""
Tests that NO UserWarning is issued if client headers are present but
the base_url is HTTPS.
"""
with catch_warnings(record=True) as record:
simplefilter("always")
transport = ToolboxTransport(HTTPS_BASE_URL, http_session, False)

ToolboxTool(
transport=transport,
name="https_tool",
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers=static_client_header,
)
assert (
len(record) == 0
), f"Expected no warnings, but got: {[f'{w.category.__name__}: {w.message}' for w in record]}"


def test_tool_init_no_http_warning_if_no_sensitive_info_on_http(
http_session: ClientSession,
sample_tool_params: list[ParameterSchema],
sample_tool_description: str,
):
"""
Tests that NO UserWarning is issued if the URL is HTTP but there are
no client headers, auth params, or auth tokens.
"""
with catch_warnings(record=True) as record:
simplefilter("always")
transport = ToolboxTransport(TEST_BASE_URL, http_session, False)

ToolboxTool(
transport=transport,
name="http_tool_no_sensitive",
description=sample_tool_description,
params=sample_tool_params,
required_authn_params={},
required_authz_tokens=[],
auth_service_token_getters={},
bound_params={},
client_headers={},
)
assert (
len(record) == 0
), f"Expected no warnings, but got: {[f'{w.category.__name__}: {w.message}' for w in record]}"
Loading