Skip to content

Commit 6f1a234

Browse files
committed
add auth param support code
1 parent db1d9a6 commit 6f1a234

File tree

1 file changed

+234
-0
lines changed
  • packages/toolbox-core/src/toolbox_core/mcp_transport

1 file changed

+234
-0
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
from abc import ABC, abstractmethod
17+
from typing import Any, Mapping, Optional
18+
19+
from aiohttp import ClientSession
20+
21+
from ..itransport import ITransport
22+
from ..protocol import (
23+
AdditionalPropertiesSchema,
24+
ManifestSchema,
25+
ParameterSchema,
26+
Protocol,
27+
ToolSchema,
28+
)
29+
30+
31+
class _McpHttpTransportBase(ITransport, ABC):
32+
"""Base transport for MCP protocols."""
33+
34+
def __init__(
35+
self,
36+
base_url: str,
37+
session: Optional[ClientSession] = None,
38+
protocol: Protocol = Protocol.MCP,
39+
):
40+
self._mcp_base_url = base_url + "/mcp/"
41+
self._protocol_version = protocol.value
42+
self._server_version: Optional[str] = None
43+
44+
self._manage_session = session is None
45+
self._session = session or ClientSession()
46+
self._init_lock = asyncio.Lock()
47+
self._init_task: Optional[asyncio.Task] = None
48+
49+
async def _ensure_initialized(self):
50+
"""Ensures the session is initialized before making requests."""
51+
async with self._init_lock:
52+
if self._init_task is None:
53+
self._init_task = asyncio.create_task(self._initialize_session())
54+
await self._init_task
55+
56+
@property
57+
def base_url(self) -> str:
58+
return self._mcp_base_url
59+
60+
def _convert_tool_schema(self, tool_data: dict) -> ToolSchema:
61+
meta = tool_data.get("_meta", {})
62+
param_auth = meta.get("toolbox/authParams", {})
63+
invoke_auth = meta.get("toolbox/authInvoke", [])
64+
65+
parameters = []
66+
input_schema = tool_data.get("inputSchema", {})
67+
properties = input_schema.get("properties", {})
68+
required = input_schema.get("required", [])
69+
70+
71+
for name, schema in properties.items():
72+
additional_props = schema.get("additionalProperties")
73+
if isinstance(additional_props, dict):
74+
additional_props = AdditionalPropertiesSchema(
75+
type=additional_props["type"]
76+
)
77+
else:
78+
additional_props = True
79+
80+
auth_sources = param_auth.get(name)
81+
parameters.append(
82+
ParameterSchema(
83+
name=name,
84+
type=schema["type"],
85+
description=schema.get("description", ""),
86+
required=name in required,
87+
additionalProperties=additional_props,
88+
)
89+
)
90+
91+
return ToolSchema(
92+
description=tool_data["description"],
93+
parameters=parameters,
94+
authRequired=invoke_auth,
95+
)
96+
97+
async def _list_tools(
98+
self,
99+
toolset_name: Optional[str] = None,
100+
headers: Optional[Mapping[str, str]] = None,
101+
) -> Any:
102+
"""Private helper to fetch the raw tool list from the server."""
103+
if toolset_name:
104+
url = self._mcp_base_url + toolset_name
105+
else:
106+
url = self._mcp_base_url
107+
return await self._send_request(
108+
url=url, method="tools/list", params={}, headers=headers
109+
)
110+
111+
async def tool_get(
112+
self, tool_name: str, headers: Optional[Mapping[str, str]] = None
113+
) -> ManifestSchema:
114+
"""Gets a single tool from the server by listing all and filtering."""
115+
await self._ensure_initialized()
116+
117+
if self._server_version is None:
118+
raise RuntimeError("Server version not available.")
119+
120+
result = await self._list_tools(headers=headers)
121+
tool_def = None
122+
for tool in result.get("tools", []):
123+
if tool.get("name") == tool_name:
124+
tool_def = self._convert_tool_schema(tool)
125+
break
126+
127+
if tool_def is None:
128+
raise ValueError(f"Tool '{tool_name}' not found.")
129+
130+
tool_details = ManifestSchema(
131+
serverVersion=self._server_version,
132+
tools={tool_name: tool_def},
133+
)
134+
return tool_details
135+
136+
async def tools_list(
137+
self,
138+
toolset_name: Optional[str] = None,
139+
headers: Optional[Mapping[str, str]] = None,
140+
) -> ManifestSchema:
141+
"""Lists available tools from the server using the MCP protocol."""
142+
await self._ensure_initialized()
143+
144+
if self._server_version is None:
145+
raise RuntimeError("Server version not available.")
146+
147+
result = await self._list_tools(toolset_name, headers)
148+
tools = result.get("tools")
149+
150+
return ManifestSchema(
151+
serverVersion=self._server_version,
152+
tools={tool["name"]: self._convert_tool_schema(tool) for tool in tools},
153+
)
154+
155+
async def tool_invoke(
156+
self, tool_name: str, arguments: dict, headers: Optional[Mapping[str, str]]
157+
) -> str:
158+
"""Invokes a specific tool on the server using the MCP protocol."""
159+
await self._ensure_initialized()
160+
161+
url = self._mcp_base_url
162+
params = {"name": tool_name, "arguments": arguments}
163+
result = await self._send_request(
164+
url=url, method="tools/call", params=params, headers=headers
165+
)
166+
all_content = result.get("content", result)
167+
content_str = "".join(
168+
content.get("text", "")
169+
for content in all_content
170+
if isinstance(content, dict)
171+
)
172+
return content_str or "null"
173+
174+
async def close(self):
175+
async with self._init_lock:
176+
if self._init_task:
177+
try:
178+
await self._init_task
179+
except Exception:
180+
# If initialization failed, we can still try to close.
181+
pass
182+
if self._manage_session and self._session and not self._session.closed:
183+
await self._session.close()
184+
185+
async def _perform_initialization_and_negotiation(
186+
self, params: dict, headers: Optional[Mapping[str, str]] = None
187+
) -> Any:
188+
"""Performs the common initialization and version negotiation logic."""
189+
initialize_result = await self._send_request(
190+
url=self._mcp_base_url, method="initialize", params=params, headers=headers
191+
)
192+
193+
server_info = initialize_result.get("serverInfo")
194+
if not server_info:
195+
raise RuntimeError("Server info not found in initialize response")
196+
197+
self._server_version = server_info.get("version")
198+
if not self._server_version:
199+
raise RuntimeError("Server version not found in initialize response")
200+
201+
server_protocol_version = initialize_result.get("protocolVersion")
202+
if server_protocol_version:
203+
if server_protocol_version != self._protocol_version:
204+
raise RuntimeError(
205+
"MCP version mismatch: client does not support server version"
206+
f" {server_protocol_version}"
207+
)
208+
else:
209+
if self._manage_session:
210+
await self.close()
211+
raise RuntimeError("MCP Protocol version not found in initialize response")
212+
213+
server_capabilities = initialize_result.get("capabilities")
214+
if not server_capabilities or "tools" not in server_capabilities:
215+
if self._manage_session:
216+
await self.close()
217+
raise RuntimeError("Server does not support the 'tools' capability.")
218+
return initialize_result
219+
220+
@abstractmethod
221+
async def _initialize_session(self):
222+
"""Initializes the MCP session."""
223+
pass
224+
225+
@abstractmethod
226+
async def _send_request(
227+
self,
228+
url: str,
229+
method: str,
230+
params: dict,
231+
headers: Optional[Mapping[str, str]] = None,
232+
) -> Any:
233+
"""Sends a JSON-RPC request to the MCP server."""
234+
pass

0 commit comments

Comments
 (0)