Skip to content

Commit 6dc0ae1

Browse files
committed
add unit test
1 parent 6f1a234 commit 6dc0ae1

File tree

1 file changed

+235
-0
lines changed

1 file changed

+235
-0
lines changed
Lines changed: 235 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,235 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
from typing import Any
17+
from unittest.mock import AsyncMock, Mock
18+
19+
import pytest
20+
import pytest_asyncio
21+
from aiohttp import ClientSession
22+
23+
from toolbox_core.mcp_transport.base import _McpHttpTransportBase
24+
from toolbox_core.protocol import ManifestSchema, ToolSchema
25+
26+
27+
class ConcreteTransport(_McpHttpTransportBase):
28+
"""A concrete class for testing the abstract base class."""
29+
30+
async def _initialize_session(self):
31+
pass # Will be mocked
32+
33+
async def _send_request(self, *args, **kwargs) -> Any:
34+
pass # Will be mocked
35+
36+
37+
def create_fake_initialize_response(
38+
server_version="1.0.0", protocol_version="2025-06-18", capabilities={"tools": {}}
39+
):
40+
return {
41+
"serverInfo": {"version": server_version},
42+
"protocolVersion": protocol_version,
43+
"capabilities": capabilities,
44+
}
45+
46+
47+
def create_fake_tools_list_response():
48+
return {
49+
"tools": [
50+
{
51+
"name": "get_weather",
52+
"description": "Gets the weather.",
53+
"inputSchema": {
54+
"type": "object",
55+
"properties": {
56+
"location": {"type": "string", "description": "The location."}
57+
},
58+
"required": ["location"],
59+
},
60+
},
61+
{
62+
"name": "send_email",
63+
"description": "Sends an email.",
64+
"inputSchema": {
65+
"type": "object",
66+
"properties": {
67+
"recipient": {"type": "string"},
68+
"body": {"type": "string"},
69+
},
70+
},
71+
},
72+
]
73+
}
74+
75+
76+
@pytest_asyncio.fixture
77+
async def transport(mocker):
78+
"""
79+
A pytest fixture that creates and tears down a ConcreteTransport instance
80+
for each test that uses it.
81+
"""
82+
base_url = "http://fake-server.com"
83+
transport_instance = ConcreteTransport(base_url)
84+
mocker.patch.object(
85+
transport_instance, "_initialize_session", new_callable=AsyncMock
86+
)
87+
mocker.patch.object(transport_instance, "_send_request", new_callable=AsyncMock)
88+
89+
yield transport_instance
90+
await transport_instance.close()
91+
92+
93+
class TestMcpHttpTransportBase:
94+
95+
@pytest.mark.asyncio
96+
async def test_initialization(self, transport):
97+
"""Test constructor properties."""
98+
assert transport.base_url == "http://fake-server.com/mcp/"
99+
assert transport._manage_session is True
100+
assert isinstance(transport._session, ClientSession)
101+
102+
@pytest.mark.asyncio
103+
async def test_initialization_with_external_session(self):
104+
"""Test that an external session is used and not managed."""
105+
mock_session = AsyncMock(spec=ClientSession)
106+
transport = ConcreteTransport("http://fake-server.com", session=mock_session)
107+
assert transport._manage_session is False
108+
assert transport._session is mock_session
109+
await transport.close()
110+
111+
@pytest.mark.asyncio
112+
async def test_ensure_initialized_is_called(self, transport):
113+
"""Test that public methods trigger initialization."""
114+
115+
async def init_side_effect():
116+
transport._server_version = "1.0.0"
117+
118+
transport._initialize_session.side_effect = init_side_effect
119+
transport._send_request.return_value = create_fake_tools_list_response()
120+
121+
await transport.tools_list()
122+
transport._initialize_session.assert_called_once()
123+
124+
@pytest.mark.asyncio
125+
async def test_initialization_is_only_run_once(self, transport):
126+
"""Test the lock ensures initialization only happens once with concurrent calls."""
127+
init_started = asyncio.Event()
128+
129+
async def slow_init():
130+
init_started.set()
131+
transport._server_version = "1.0.0"
132+
await asyncio.sleep(0.01)
133+
134+
transport._initialize_session.side_effect = slow_init
135+
transport._send_request.return_value = create_fake_tools_list_response()
136+
137+
task1 = asyncio.create_task(transport.tools_list())
138+
await init_started.wait()
139+
task2 = asyncio.create_task(transport.tools_list())
140+
await asyncio.gather(task1, task2)
141+
142+
transport._initialize_session.assert_called_once()
143+
144+
def test_convert_tool_schema(self, transport):
145+
"""Test the conversion from MCP tool schema to internal ToolSchema."""
146+
tool_data = {
147+
"name": "get_weather",
148+
"description": "A test tool.",
149+
"inputSchema": {
150+
"type": "object",
151+
"properties": {
152+
"location": {"type": "string", "description": "The city."},
153+
"unit": {"type": "string"},
154+
},
155+
"required": ["location"],
156+
},
157+
}
158+
tool_schema = transport._convert_tool_schema(tool_data)
159+
assert tool_schema.description == "A test tool."
160+
location_param = next(p for p in tool_schema.parameters if p.name == "location")
161+
assert location_param.required is True
162+
assert location_param.description == "The city."
163+
164+
def test_convert_tool_schema_with_auth(self, transport):
165+
"""Test schema conversion with authentication metadata."""
166+
tool_data = {
167+
"name": "drive_tool",
168+
"description": "A tool that requires auth.",
169+
"inputSchema": {"type": "object", "properties": {}},
170+
"_meta": {
171+
"toolbox/authInvoke": ["google"],
172+
},
173+
}
174+
tool_schema = transport._convert_tool_schema(tool_data)
175+
assert tool_schema.authRequired == ["google"]
176+
177+
@pytest.mark.asyncio
178+
async def test_tools_list_success(self, transport):
179+
transport._server_version = "1.0.0"
180+
transport._init_task = asyncio.create_task(asyncio.sleep(0))
181+
transport._send_request.return_value = create_fake_tools_list_response()
182+
manifest = await transport.tools_list()
183+
transport._send_request.assert_called_once_with(
184+
url=transport.base_url, method="tools/list", params={}, headers=None
185+
)
186+
assert isinstance(manifest, ManifestSchema)
187+
188+
@pytest.mark.asyncio
189+
async def test_tool_get_success(self, transport):
190+
transport._server_version = "1.0.0"
191+
transport._init_task = asyncio.create_task(asyncio.sleep(0))
192+
transport._send_request.return_value = create_fake_tools_list_response()
193+
manifest = await transport.tool_get("get_weather")
194+
assert len(manifest.tools) == 1
195+
196+
@pytest.mark.asyncio
197+
async def test_tool_get_not_found(self, transport):
198+
transport._server_version = "1.0.0"
199+
transport._init_task = asyncio.create_task(asyncio.sleep(0))
200+
transport._send_request.return_value = create_fake_tools_list_response()
201+
with pytest.raises(ValueError, match="Tool 'non_existent_tool' not found."):
202+
await transport.tool_get("non_existent_tool")
203+
204+
@pytest.mark.asyncio
205+
async def test_tool_invoke_success(self, transport):
206+
transport._init_task = asyncio.create_task(asyncio.sleep(0))
207+
transport._send_request.return_value = {
208+
"content": [{"type": "text", "text": "The weather is sunny."}]
209+
}
210+
result = await transport.tool_invoke(
211+
"get_weather", {"location": "London"}, headers={"X-Test": "true"}
212+
)
213+
assert result == "The weather is sunny."
214+
215+
@pytest.mark.asyncio
216+
async def test_perform_initialization_and_negotiation_failure(self, transport):
217+
transport._send_request.return_value = {}
218+
with pytest.raises(RuntimeError, match="Server info not found"):
219+
await transport._perform_initialization_and_negotiation({})
220+
221+
@pytest.mark.asyncio
222+
async def test_close_managed_session(self, mocker):
223+
mock_close = mocker.patch("aiohttp.ClientSession.close", new_callable=AsyncMock)
224+
transport = ConcreteTransport("http://fake-server.com")
225+
transport._init_task = asyncio.create_task(asyncio.sleep(0))
226+
await transport.close()
227+
mock_close.assert_called_once()
228+
229+
@pytest.mark.asyncio
230+
async def test_close_unmanaged_session(self):
231+
mock_session = AsyncMock(spec=ClientSession)
232+
transport = ConcreteTransport("http://fake-server.com", session=mock_session)
233+
transport._init_task = asyncio.create_task(asyncio.sleep(0))
234+
await transport.close()
235+
mock_session.close.assert_not_called()

0 commit comments

Comments
 (0)