Skip to content

Commit db1d9a6

Browse files
committed
add auth tests
1 parent 2307a11 commit db1d9a6

File tree

2 files changed

+392
-0
lines changed

2 files changed

+392
-0
lines changed
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import asyncio
15+
import os
16+
import uuid
17+
from typing import Any, Mapping, Optional, Union
18+
19+
from aiohttp import ClientSession
20+
21+
from . import version
22+
from .itransport import ITransport
23+
from .protocol import (
24+
AdditionalPropertiesSchema,
25+
ManifestSchema,
26+
ParameterSchema,
27+
Protocol,
28+
ToolSchema,
29+
)
30+
31+
32+
class McpHttpTransport(ITransport):
33+
"""Transport for the MCP protocol."""
34+
35+
def __init__(
36+
self,
37+
base_url: str,
38+
session: Optional[ClientSession] = None,
39+
protocol: Protocol = Protocol.MCP,
40+
):
41+
self.__mcp_base_url = base_url + "/mcp/"
42+
# Will be updated after negotiation
43+
self.__protocol_version = protocol.value
44+
self.__server_version: Optional[str] = None
45+
self.__session_id: Optional[str] = None
46+
47+
self.__manage_session = session is None
48+
self.__session = session or ClientSession()
49+
self.__init_task = asyncio.create_task(self.__initialize_session())
50+
51+
@property
52+
def base_url(self) -> str:
53+
return self.__mcp_base_url
54+
55+
def __convert_tool_schema(self, tool_data: dict) -> ToolSchema:
56+
meta = tool_data.get("_meta", {})
57+
param_auth = meta.get("toolbox/authParam", {})
58+
invoke_auth = meta.get("toolbox/authInvoke", [])
59+
60+
parameters = []
61+
input_schema = tool_data.get("inputSchema", {})
62+
properties = input_schema.get("properties", {})
63+
required = input_schema.get("required", [])
64+
65+
for name, schema in properties.items():
66+
additional_props = schema.get("additionalProperties")
67+
if isinstance(additional_props, dict):
68+
additional_props = AdditionalPropertiesSchema(
69+
type=additional_props["type"]
70+
)
71+
else:
72+
additional_props = True
73+
74+
auth_sources = param_auth.get(name)
75+
76+
parameters.append(
77+
ParameterSchema(
78+
name=name,
79+
type=schema["type"],
80+
description=schema.get("description", ""),
81+
required=name in required,
82+
authSources=auth_sources,
83+
additionalProperties=additional_props,
84+
)
85+
)
86+
87+
return ToolSchema(
88+
description=tool_data["description"],
89+
parameters=parameters,
90+
authRequired=invoke_auth,
91+
)
92+
93+
async def __list_tools(
94+
self,
95+
toolset_name: Optional[str] = None,
96+
headers: Optional[Mapping[str, str]] = None,
97+
) -> Any:
98+
"""Private helper to fetch the raw tool list from the server."""
99+
if toolset_name:
100+
url = self.__mcp_base_url + toolset_name
101+
else:
102+
url = self.__mcp_base_url
103+
return await self.__send_request(
104+
url=url, method="tools/list", params={}, headers=headers
105+
)
106+
107+
async def tool_get(
108+
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
109+
) -> ManifestSchema:
110+
"""Gets a single tool from the server by listing all and filtering."""
111+
await self.__init_task
112+
113+
if self.__server_version is None:
114+
raise RuntimeError("Server version not available.")
115+
116+
result = await self.__list_tools(headers=headers)
117+
tool_def = None
118+
for tool in result.get("tools", []):
119+
if tool.get("name") == tool_name:
120+
tool_def = self.__convert_tool_schema(tool)
121+
break
122+
123+
if tool_def is None:
124+
raise ValueError(f"Tool '{tool_name}' not found.")
125+
126+
tool_details = ManifestSchema(
127+
serverVersion=self.__server_version,
128+
tools={tool_name: tool_def},
129+
)
130+
return tool_details
131+
132+
async def tools_list(
133+
self,
134+
toolset_name: Optional[str] = None,
135+
headers: Optional[Mapping[str, str]] = None,
136+
) -> ManifestSchema:
137+
"""Lists available tools from the server using the MCP protocol."""
138+
await self.__init_task
139+
140+
if self.__server_version is None:
141+
raise RuntimeError("Server version not available.")
142+
143+
result = await self.__list_tools(toolset_name, headers)
144+
tools = result.get("tools")
145+
146+
return ManifestSchema(
147+
serverVersion=self.__server_version,
148+
tools={tool["name"]: self.__convert_tool_schema(tool) for tool in tools},
149+
)
150+
151+
async def tool_invoke(
152+
self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]]
153+
) -> str:
154+
"""Invokes a specific tool on the server using the MCP protocol."""
155+
await self.__init_task
156+
157+
url = self.__mcp_base_url
158+
params = {"name": tool_name, "arguments": arguments}
159+
result = await self.__send_request(
160+
url=url, method="tools/call", params=params, headers=headers
161+
)
162+
all_content = result.get("content", result)
163+
content_str = "".join(
164+
content.get("text", "")
165+
for content in all_content
166+
if isinstance(content, dict)
167+
)
168+
return content_str or "null"
169+
170+
async def close(self):
171+
try:
172+
await self.__init_task
173+
except Exception:
174+
# If initialization failed, we can still try to close the session.
175+
pass
176+
finally:
177+
if self.__manage_session and self.__session and not self.__session.closed:
178+
await self.__session.close()
179+
180+
async def __initialize_session(self):
181+
"""Initializes the MCP session."""
182+
if self.__session is None and self.__manage_session:
183+
self.__session = ClientSession()
184+
185+
url = self.__mcp_base_url
186+
187+
# Perform version negotitation
188+
client_supported_versions = Protocol.get_supported_mcp_versions()
189+
proposed_protocol_version = self.__protocol_version
190+
params = {
191+
"processId": os.getpid(),
192+
"clientInfo": {
193+
"name": "toolbox-python-sdk",
194+
"version": version.__version__,
195+
},
196+
"capabilities": {},
197+
"protocolVersion": proposed_protocol_version,
198+
}
199+
# Send initialize notification
200+
initialize_result = await self.__send_request(
201+
url=url, method="initialize", params=params
202+
)
203+
204+
# Get the session id if the proposed version requires it
205+
if proposed_protocol_version == "2025-03-26":
206+
self.__session_id = initialize_result.get("Mcp-Session-Id")
207+
if not self.__session_id:
208+
if self.__manage_session:
209+
await self.close()
210+
raise RuntimeError(
211+
"Server did not return a Mcp-Session-Id during initialization."
212+
)
213+
server_info = initialize_result.get("serverInfo")
214+
if not server_info:
215+
raise RuntimeError("Server info not found in initialize response")
216+
217+
self.__server_version = server_info.get("version")
218+
if not self.__server_version:
219+
raise RuntimeError("Server version not found in initialize response")
220+
221+
# Perform version negotiation based on server response
222+
server_protcol_version = initialize_result.get("protocolVersion")
223+
if server_protcol_version:
224+
if server_protcol_version not in client_supported_versions:
225+
if self.__manage_session:
226+
await self.close()
227+
raise RuntimeError(
228+
f"MCP version mismatch: client does not support server version {server_protcol_version}"
229+
)
230+
# Update the protocol version to the one agreed upon by the server.
231+
self.__protocol_version = server_protcol_version
232+
else:
233+
if self.__manage_session:
234+
await self.close()
235+
raise RuntimeError("MCP Protocol version not found in initialize response")
236+
237+
server_capabilities = initialize_result.get("capabilities")
238+
if not server_capabilities or "tools" not in server_capabilities:
239+
if self.__manage_session:
240+
await self.close()
241+
raise RuntimeError("Server does not support the 'tools' capability.")
242+
await self.__send_request(
243+
url=url, method="notifications/initialized", params={}
244+
)
245+
246+
async def __send_request(
247+
self,
248+
url: str,
249+
method: str,
250+
params: dict,
251+
headers: Optional[Mapping[str, str]] = None,
252+
) -> Any:
253+
"""Sends a JSON-RPC request to the MCP server."""
254+
255+
request_params = params.copy()
256+
req_headers = dict(headers or {})
257+
258+
# Check based on the NEGOTIATED version (self.__protocol_version)
259+
if (
260+
self.__protocol_version == "2025-03-26"
261+
and method != "initialize"
262+
and self.__session_id
263+
):
264+
request_params["Mcp-Session-Id"] = self.__session_id
265+
266+
if self.__protocol_version == "2025-06-18":
267+
req_headers["MCP-Protocol-Version"] = self.__protocol_version
268+
269+
payload = {
270+
"jsonrpc": "2.0",
271+
"method": method,
272+
"params": request_params,
273+
}
274+
275+
if not method.startswith("notifications/"):
276+
payload["id"] = str(uuid.uuid4())
277+
278+
async with self.__session.post(
279+
url, json=payload, headers=req_headers
280+
) as response:
281+
if not response.ok:
282+
error_text = await response.text()
283+
raise RuntimeError(
284+
f"API request failed with status {response.status} ({response.reason}). Server response: {error_text}"
285+
)
286+
287+
# Handle potential empty body (e.g. 204 No Content for notifications)
288+
if response.status == 204 or response.content.at_eof():
289+
return None
290+
291+
json_response = await response.json()
292+
if "error" in json_response:
293+
error = json_response["error"]
294+
if error["code"] == -32000:
295+
raise RuntimeError(f"MCP version mismatch: {error['message']}")
296+
else:
297+
raise RuntimeError(
298+
f"MCP request failed with code {error['code']}: {error['message']}"
299+
)
300+
return json_response.get("result")

0 commit comments

Comments
 (0)