1+ # Copyright 2025 Google LLC
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ import asyncio
15+ import os
16+ import uuid
17+ from typing import Any , Mapping , Optional , Union
18+
19+ from aiohttp import ClientSession
20+
21+ from . import version
22+ from .itransport import ITransport
23+ from .protocol import (
24+ AdditionalPropertiesSchema ,
25+ ManifestSchema ,
26+ ParameterSchema ,
27+ Protocol ,
28+ ToolSchema ,
29+ )
30+
31+
32+ class McpHttpTransport (ITransport ):
33+ """Transport for the MCP protocol."""
34+
35+ def __init__ (
36+ self ,
37+ base_url : str ,
38+ session : Optional [ClientSession ] = None ,
39+ protocol : Protocol = Protocol .MCP ,
40+ ):
41+ self .__mcp_base_url = base_url + "/mcp/"
42+ # Will be updated after negotiation
43+ self .__protocol_version = protocol .value
44+ self .__server_version : Optional [str ] = None
45+ self .__session_id : Optional [str ] = None
46+
47+ self .__manage_session = session is None
48+ self .__session = session or ClientSession ()
49+ self .__init_task = asyncio .create_task (self .__initialize_session ())
50+
51+ @property
52+ def base_url (self ) -> str :
53+ return self .__mcp_base_url
54+
55+ def __convert_tool_schema (self , tool_data : dict ) -> ToolSchema :
56+ meta = tool_data .get ("_meta" , {})
57+ param_auth = meta .get ("toolbox/authParam" , {})
58+ invoke_auth = meta .get ("toolbox/authInvoke" , [])
59+
60+ parameters = []
61+ input_schema = tool_data .get ("inputSchema" , {})
62+ properties = input_schema .get ("properties" , {})
63+ required = input_schema .get ("required" , [])
64+
65+ for name , schema in properties .items ():
66+ additional_props = schema .get ("additionalProperties" )
67+ if isinstance (additional_props , dict ):
68+ additional_props = AdditionalPropertiesSchema (
69+ type = additional_props ["type" ]
70+ )
71+ else :
72+ additional_props = True
73+
74+ auth_sources = param_auth .get (name )
75+
76+ parameters .append (
77+ ParameterSchema (
78+ name = name ,
79+ type = schema ["type" ],
80+ description = schema .get ("description" , "" ),
81+ required = name in required ,
82+ authSources = auth_sources ,
83+ additionalProperties = additional_props ,
84+ )
85+ )
86+
87+ return ToolSchema (
88+ description = tool_data ["description" ],
89+ parameters = parameters ,
90+ authRequired = invoke_auth ,
91+ )
92+
93+ async def __list_tools (
94+ self ,
95+ toolset_name : Optional [str ] = None ,
96+ headers : Optional [Mapping [str , str ]] = None ,
97+ ) -> Any :
98+ """Private helper to fetch the raw tool list from the server."""
99+ if toolset_name :
100+ url = self .__mcp_base_url + toolset_name
101+ else :
102+ url = self .__mcp_base_url
103+ return await self .__send_request (
104+ url = url , method = "tools/list" , params = {}, headers = headers
105+ )
106+
107+ async def tool_get (
108+ self , tool_name : str , headers : Optional [Mapping [str , str ]] = None
109+ ) -> ManifestSchema :
110+ """Gets a single tool from the server by listing all and filtering."""
111+ await self .__init_task
112+
113+ if self .__server_version is None :
114+ raise RuntimeError ("Server version not available." )
115+
116+ result = await self .__list_tools (headers = headers )
117+ tool_def = None
118+ for tool in result .get ("tools" , []):
119+ if tool .get ("name" ) == tool_name :
120+ tool_def = self .__convert_tool_schema (tool )
121+ break
122+
123+ if tool_def is None :
124+ raise ValueError (f"Tool '{ tool_name } ' not found." )
125+
126+ tool_details = ManifestSchema (
127+ serverVersion = self .__server_version ,
128+ tools = {tool_name : tool_def },
129+ )
130+ return tool_details
131+
132+ async def tools_list (
133+ self ,
134+ toolset_name : Optional [str ] = None ,
135+ headers : Optional [Mapping [str , str ]] = None ,
136+ ) -> ManifestSchema :
137+ """Lists available tools from the server using the MCP protocol."""
138+ await self .__init_task
139+
140+ if self .__server_version is None :
141+ raise RuntimeError ("Server version not available." )
142+
143+ result = await self .__list_tools (toolset_name , headers )
144+ tools = result .get ("tools" )
145+
146+ return ManifestSchema (
147+ serverVersion = self .__server_version ,
148+ tools = {tool ["name" ]: self .__convert_tool_schema (tool ) for tool in tools },
149+ )
150+
151+ async def tool_invoke (
152+ self , tool_name : str , arguments : dict , headers : Optional [Mapping [str , str ]]
153+ ) -> str :
154+ """Invokes a specific tool on the server using the MCP protocol."""
155+ await self .__init_task
156+
157+ url = self .__mcp_base_url
158+ params = {"name" : tool_name , "arguments" : arguments }
159+ result = await self .__send_request (
160+ url = url , method = "tools/call" , params = params , headers = headers
161+ )
162+ all_content = result .get ("content" , result )
163+ content_str = "" .join (
164+ content .get ("text" , "" )
165+ for content in all_content
166+ if isinstance (content , dict )
167+ )
168+ return content_str or "null"
169+
170+ async def close (self ):
171+ try :
172+ await self .__init_task
173+ except Exception :
174+ # If initialization failed, we can still try to close the session.
175+ pass
176+ finally :
177+ if self .__manage_session and self .__session and not self .__session .closed :
178+ await self .__session .close ()
179+
180+ async def __initialize_session (self ):
181+ """Initializes the MCP session."""
182+ if self .__session is None and self .__manage_session :
183+ self .__session = ClientSession ()
184+
185+ url = self .__mcp_base_url
186+
187+ # Perform version negotitation
188+ client_supported_versions = Protocol .get_supported_mcp_versions ()
189+ proposed_protocol_version = self .__protocol_version
190+ params = {
191+ "processId" : os .getpid (),
192+ "clientInfo" : {
193+ "name" : "toolbox-python-sdk" ,
194+ "version" : version .__version__ ,
195+ },
196+ "capabilities" : {},
197+ "protocolVersion" : proposed_protocol_version ,
198+ }
199+ # Send initialize notification
200+ initialize_result = await self .__send_request (
201+ url = url , method = "initialize" , params = params
202+ )
203+
204+ # Get the session id if the proposed version requires it
205+ if proposed_protocol_version == "2025-03-26" :
206+ self .__session_id = initialize_result .get ("Mcp-Session-Id" )
207+ if not self .__session_id :
208+ if self .__manage_session :
209+ await self .close ()
210+ raise RuntimeError (
211+ "Server did not return a Mcp-Session-Id during initialization."
212+ )
213+ server_info = initialize_result .get ("serverInfo" )
214+ if not server_info :
215+ raise RuntimeError ("Server info not found in initialize response" )
216+
217+ self .__server_version = server_info .get ("version" )
218+ if not self .__server_version :
219+ raise RuntimeError ("Server version not found in initialize response" )
220+
221+ # Perform version negotiation based on server response
222+ server_protcol_version = initialize_result .get ("protocolVersion" )
223+ if server_protcol_version :
224+ if server_protcol_version not in client_supported_versions :
225+ if self .__manage_session :
226+ await self .close ()
227+ raise RuntimeError (
228+ f"MCP version mismatch: client does not support server version { server_protcol_version } "
229+ )
230+ # Update the protocol version to the one agreed upon by the server.
231+ self .__protocol_version = server_protcol_version
232+ else :
233+ if self .__manage_session :
234+ await self .close ()
235+ raise RuntimeError ("MCP Protocol version not found in initialize response" )
236+
237+ server_capabilities = initialize_result .get ("capabilities" )
238+ if not server_capabilities or "tools" not in server_capabilities :
239+ if self .__manage_session :
240+ await self .close ()
241+ raise RuntimeError ("Server does not support the 'tools' capability." )
242+ await self .__send_request (
243+ url = url , method = "notifications/initialized" , params = {}
244+ )
245+
246+ async def __send_request (
247+ self ,
248+ url : str ,
249+ method : str ,
250+ params : dict ,
251+ headers : Optional [Mapping [str , str ]] = None ,
252+ ) -> Any :
253+ """Sends a JSON-RPC request to the MCP server."""
254+
255+ request_params = params .copy ()
256+ req_headers = dict (headers or {})
257+
258+ # Check based on the NEGOTIATED version (self.__protocol_version)
259+ if (
260+ self .__protocol_version == "2025-03-26"
261+ and method != "initialize"
262+ and self .__session_id
263+ ):
264+ request_params ["Mcp-Session-Id" ] = self .__session_id
265+
266+ if self .__protocol_version == "2025-06-18" :
267+ req_headers ["MCP-Protocol-Version" ] = self .__protocol_version
268+
269+ payload = {
270+ "jsonrpc" : "2.0" ,
271+ "method" : method ,
272+ "params" : request_params ,
273+ }
274+
275+ if not method .startswith ("notifications/" ):
276+ payload ["id" ] = str (uuid .uuid4 ())
277+
278+ async with self .__session .post (
279+ url , json = payload , headers = req_headers
280+ ) as response :
281+ if not response .ok :
282+ error_text = await response .text ()
283+ raise RuntimeError (
284+ f"API request failed with status { response .status } ({ response .reason } ). Server response: { error_text } "
285+ )
286+
287+ # Handle potential empty body (e.g. 204 No Content for notifications)
288+ if response .status == 204 or response .content .at_eof ():
289+ return None
290+
291+ json_response = await response .json ()
292+ if "error" in json_response :
293+ error = json_response ["error" ]
294+ if error ["code" ] == - 32000 :
295+ raise RuntimeError (f"MCP version mismatch: { error ['message' ]} " )
296+ else :
297+ raise RuntimeError (
298+ f"MCP request failed with code { error ['code' ]} : { error ['message' ]} "
299+ )
300+ return json_response .get ("result" )
0 commit comments