1+ # Copyright 2025 Google LLC
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ import asyncio
16+ from typing import Any
17+ from unittest .mock import AsyncMock , Mock
18+
19+ import pytest
20+ import pytest_asyncio
21+ from aiohttp import ClientSession
22+
23+ from toolbox_core .mcp_transport .base import _McpHttpTransportBase
24+ from toolbox_core .protocol import ManifestSchema , ToolSchema
25+
26+
27+ class ConcreteTransport (_McpHttpTransportBase ):
28+ """A concrete class for testing the abstract base class."""
29+
30+ async def _initialize_session (self ):
31+ pass # Will be mocked
32+
33+ async def _send_request (self , * args , ** kwargs ) -> Any :
34+ pass # Will be mocked
35+
36+
37+ def create_fake_initialize_response (
38+ server_version = "1.0.0" , protocol_version = "2025-06-18" , capabilities = {"tools" : {}}
39+ ):
40+ return {
41+ "serverInfo" : {"version" : server_version },
42+ "protocolVersion" : protocol_version ,
43+ "capabilities" : capabilities ,
44+ }
45+
46+
47+ def create_fake_tools_list_response ():
48+ return {
49+ "tools" : [
50+ {
51+ "name" : "get_weather" ,
52+ "description" : "Gets the weather." ,
53+ "inputSchema" : {
54+ "type" : "object" ,
55+ "properties" : {
56+ "location" : {"type" : "string" , "description" : "The location." }
57+ },
58+ "required" : ["location" ],
59+ },
60+ },
61+ {
62+ "name" : "send_email" ,
63+ "description" : "Sends an email." ,
64+ "inputSchema" : {
65+ "type" : "object" ,
66+ "properties" : {
67+ "recipient" : {"type" : "string" },
68+ "body" : {"type" : "string" },
69+ },
70+ },
71+ },
72+ ]
73+ }
74+
75+
76+ @pytest_asyncio .fixture
77+ async def transport (mocker ):
78+ """
79+ A pytest fixture that creates and tears down a ConcreteTransport instance
80+ for each test that uses it.
81+ """
82+ base_url = "http://fake-server.com"
83+ transport_instance = ConcreteTransport (base_url )
84+ mocker .patch .object (
85+ transport_instance , "_initialize_session" , new_callable = AsyncMock
86+ )
87+ mocker .patch .object (transport_instance , "_send_request" , new_callable = AsyncMock )
88+
89+ yield transport_instance
90+ await transport_instance .close ()
91+
92+
93+ class TestMcpHttpTransportBase :
94+
95+ @pytest .mark .asyncio
96+ async def test_initialization (self , transport ):
97+ """Test constructor properties."""
98+ assert transport .base_url == "http://fake-server.com/mcp/"
99+ assert transport ._manage_session is True
100+ assert isinstance (transport ._session , ClientSession )
101+
102+ @pytest .mark .asyncio
103+ async def test_initialization_with_external_session (self ):
104+ """Test that an external session is used and not managed."""
105+ mock_session = AsyncMock (spec = ClientSession )
106+ transport = ConcreteTransport ("http://fake-server.com" , session = mock_session )
107+ assert transport ._manage_session is False
108+ assert transport ._session is mock_session
109+ await transport .close ()
110+
111+ @pytest .mark .asyncio
112+ async def test_ensure_initialized_is_called (self , transport ):
113+ """Test that public methods trigger initialization."""
114+
115+ async def init_side_effect ():
116+ transport ._server_version = "1.0.0"
117+
118+ transport ._initialize_session .side_effect = init_side_effect
119+ transport ._send_request .return_value = create_fake_tools_list_response ()
120+
121+ await transport .tools_list ()
122+ transport ._initialize_session .assert_called_once ()
123+
124+ @pytest .mark .asyncio
125+ async def test_initialization_is_only_run_once (self , transport ):
126+ """Test the lock ensures initialization only happens once with concurrent calls."""
127+ init_started = asyncio .Event ()
128+
129+ async def slow_init ():
130+ init_started .set ()
131+ transport ._server_version = "1.0.0"
132+ await asyncio .sleep (0.01 )
133+
134+ transport ._initialize_session .side_effect = slow_init
135+ transport ._send_request .return_value = create_fake_tools_list_response ()
136+
137+ task1 = asyncio .create_task (transport .tools_list ())
138+ await init_started .wait ()
139+ task2 = asyncio .create_task (transport .tools_list ())
140+ await asyncio .gather (task1 , task2 )
141+
142+ transport ._initialize_session .assert_called_once ()
143+
144+ def test_convert_tool_schema (self , transport ):
145+ """Test the conversion from MCP tool schema to internal ToolSchema."""
146+ tool_data = {
147+ "name" : "get_weather" ,
148+ "description" : "A test tool." ,
149+ "inputSchema" : {
150+ "type" : "object" ,
151+ "properties" : {
152+ "location" : {"type" : "string" , "description" : "The city." },
153+ "unit" : {"type" : "string" },
154+ },
155+ "required" : ["location" ],
156+ },
157+ }
158+ tool_schema = transport ._convert_tool_schema (tool_data )
159+ assert tool_schema .description == "A test tool."
160+ location_param = next (p for p in tool_schema .parameters if p .name == "location" )
161+ assert location_param .required is True
162+ assert location_param .description == "The city."
163+
164+ def test_convert_tool_schema_with_auth (self , transport ):
165+ """Test schema conversion with authentication metadata."""
166+ tool_data = {
167+ "name" : "drive_tool" ,
168+ "description" : "A tool that requires auth." ,
169+ "inputSchema" : {"type" : "object" , "properties" : {}},
170+ "_meta" : {
171+ "toolbox/authInvoke" : ["google" ],
172+ },
173+ }
174+ tool_schema = transport ._convert_tool_schema (tool_data )
175+ assert tool_schema .authRequired == ["google" ]
176+
177+ @pytest .mark .asyncio
178+ async def test_tools_list_success (self , transport ):
179+ transport ._server_version = "1.0.0"
180+ transport ._init_task = asyncio .create_task (asyncio .sleep (0 ))
181+ transport ._send_request .return_value = create_fake_tools_list_response ()
182+ manifest = await transport .tools_list ()
183+ transport ._send_request .assert_called_once_with (
184+ url = transport .base_url , method = "tools/list" , params = {}, headers = None
185+ )
186+ assert isinstance (manifest , ManifestSchema )
187+
188+ @pytest .mark .asyncio
189+ async def test_tool_get_success (self , transport ):
190+ transport ._server_version = "1.0.0"
191+ transport ._init_task = asyncio .create_task (asyncio .sleep (0 ))
192+ transport ._send_request .return_value = create_fake_tools_list_response ()
193+ manifest = await transport .tool_get ("get_weather" )
194+ assert len (manifest .tools ) == 1
195+
196+ @pytest .mark .asyncio
197+ async def test_tool_get_not_found (self , transport ):
198+ transport ._server_version = "1.0.0"
199+ transport ._init_task = asyncio .create_task (asyncio .sleep (0 ))
200+ transport ._send_request .return_value = create_fake_tools_list_response ()
201+ with pytest .raises (ValueError , match = "Tool 'non_existent_tool' not found." ):
202+ await transport .tool_get ("non_existent_tool" )
203+
204+ @pytest .mark .asyncio
205+ async def test_tool_invoke_success (self , transport ):
206+ transport ._init_task = asyncio .create_task (asyncio .sleep (0 ))
207+ transport ._send_request .return_value = {
208+ "content" : [{"type" : "text" , "text" : "The weather is sunny." }]
209+ }
210+ result = await transport .tool_invoke (
211+ "get_weather" , {"location" : "London" }, headers = {"X-Test" : "true" }
212+ )
213+ assert result == "The weather is sunny."
214+
215+ @pytest .mark .asyncio
216+ async def test_perform_initialization_and_negotiation_failure (self , transport ):
217+ transport ._send_request .return_value = {}
218+ with pytest .raises (RuntimeError , match = "Server info not found" ):
219+ await transport ._perform_initialization_and_negotiation ({})
220+
221+ @pytest .mark .asyncio
222+ async def test_close_managed_session (self , mocker ):
223+ mock_close = mocker .patch ("aiohttp.ClientSession.close" , new_callable = AsyncMock )
224+ transport = ConcreteTransport ("http://fake-server.com" )
225+ transport ._init_task = asyncio .create_task (asyncio .sleep (0 ))
226+ await transport .close ()
227+ mock_close .assert_called_once ()
228+
229+ @pytest .mark .asyncio
230+ async def test_close_unmanaged_session (self ):
231+ mock_session = AsyncMock (spec = ClientSession )
232+ transport = ConcreteTransport ("http://fake-server.com" , session = mock_session )
233+ transport ._init_task = asyncio .create_task (asyncio .sleep (0 ))
234+ await transport .close ()
235+ mock_session .close .assert_not_called ()
0 commit comments