|
4 | 4 | from unittest.mock import MagicMock, patch |
5 | 5 |
|
6 | 6 | import pytest |
| 7 | +from mcp.types import Tool as MCPTool |
7 | 8 |
|
8 | 9 | from strands.tools.mcp import MCPClient |
9 | 10 | from strands.tools.mcp.mcp_agent_tool import MCPAgentTool |
@@ -41,12 +42,18 @@ def mock_agent_tool(mock_mcp_tool): |
41 | 42 | return agent_tool |
42 | 43 |
|
43 | 44 |
|
44 | | -def create_mock_tool(name: str) -> MagicMock: |
| 45 | +def create_mock_tool(tool_name: str, mcp_tool_name: str | None = None) -> MagicMock: |
45 | 46 | """Helper to create mock tools with specific names.""" |
46 | 47 | tool = MagicMock(spec=MCPAgentTool) |
47 | | - tool.tool_name = name |
48 | | - tool.mcp_tool = MagicMock() |
49 | | - tool.mcp_tool.name = name |
| 48 | + tool.tool_name = tool_name |
| 49 | + tool.tool_spec = { |
| 50 | + "name": tool_name, |
| 51 | + "description": f"Description for {tool_name}", |
| 52 | + "inputSchema": {"json": {"type": "object", "properties": {}}}, |
| 53 | + } |
| 54 | + tool.mcp_tool = MagicMock(spec=MCPTool) |
| 55 | + tool.mcp_tool.name = mcp_tool_name or tool_name |
| 56 | + tool.mcp_tool.description = f"Description for {tool_name}" |
50 | 57 | return tool |
51 | 58 |
|
52 | 59 |
|
@@ -146,8 +153,8 @@ async def test_load_tools_handles_pagination(mock_transport): |
146 | 153 | # Should have called list_tools_sync twice |
147 | 154 | assert mock_list_tools.call_count == 2 |
148 | 155 | # First call with no token, second call with "page2" token |
149 | | - mock_list_tools.assert_any_call(None) |
150 | | - mock_list_tools.assert_any_call("page2") |
| 156 | + mock_list_tools.assert_any_call(None, prefix=None) |
| 157 | + mock_list_tools.assert_any_call("page2", prefix=None) |
151 | 158 |
|
152 | 159 | assert len(tools) == 2 |
153 | 160 | assert tools[0] is tool1 |
@@ -236,31 +243,44 @@ async def test_rejected_filter_string_match(mock_transport): |
236 | 243 | @pytest.mark.asyncio |
237 | 244 | async def test_prefix_renames_tools(mock_transport): |
238 | 245 | """Test that prefix properly renames tools.""" |
239 | | - original_tool = create_mock_tool("original_name") |
240 | | - original_tool.mcp_client = MagicMock() |
| 246 | + # Create a mock MCP tool (not MCPAgentTool) |
| 247 | + mock_mcp_tool = MagicMock() |
| 248 | + mock_mcp_tool.name = "original_name" |
241 | 249 |
|
242 | 250 | client = MCPClient(mock_transport, prefix="prefix") |
243 | 251 | client._tool_provider_started = True |
244 | 252 |
|
| 253 | + # Mock the session active state |
| 254 | + mock_thread = MagicMock() |
| 255 | + mock_thread.is_alive.return_value = True |
| 256 | + client._background_thread = mock_thread |
| 257 | + |
245 | 258 | with ( |
246 | | - patch.object(client, "list_tools_sync") as mock_list_tools, |
| 259 | + patch.object(client, "_invoke_on_background_thread") as mock_invoke, |
247 | 260 | patch("strands.tools.mcp.mcp_client.MCPAgentTool") as mock_agent_tool_class, |
248 | 261 | ): |
249 | | - mock_list_tools.return_value = PaginatedList([original_tool]) |
| 262 | + # Mock the MCP server response |
| 263 | + mock_list_tools_result = MagicMock() |
| 264 | + mock_list_tools_result.tools = [mock_mcp_tool] |
| 265 | + mock_list_tools_result.nextCursor = None |
250 | 266 |
|
251 | | - new_tool = MagicMock(spec=MCPAgentTool) |
252 | | - new_tool.tool_name = "prefix_original_name" |
253 | | - mock_agent_tool_class.return_value = new_tool |
| 267 | + mock_future = MagicMock() |
| 268 | + mock_future.result.return_value = mock_list_tools_result |
| 269 | + mock_invoke.return_value = mock_future |
254 | 270 |
|
255 | | - tools = await client.load_tools() |
| 271 | + # Mock MCPAgentTool creation |
| 272 | + mock_agent_tool = MagicMock(spec=MCPAgentTool) |
| 273 | + mock_agent_tool.tool_name = "prefix_original_name" |
| 274 | + mock_agent_tool_class.return_value = mock_agent_tool |
256 | 275 |
|
257 | | - # Should create new MCPAgentTool with prefixed name |
258 | | - mock_agent_tool_class.assert_called_once_with( |
259 | | - original_tool.mcp_tool, original_tool.mcp_client, name_override="prefix_original_name" |
260 | | - ) |
| 276 | + # Call list_tools_sync directly to test prefix functionality |
| 277 | + result = client.list_tools_sync(prefix="prefix") |
261 | 278 |
|
262 | | - assert len(tools) == 1 |
263 | | - assert tools[0] is new_tool |
| 279 | + # Should create MCPAgentTool with prefixed name |
| 280 | + mock_agent_tool_class.assert_called_once_with(mock_mcp_tool, client, name_override="prefix_original_name") |
| 281 | + |
| 282 | + assert len(result) == 1 |
| 283 | + assert result[0] is mock_agent_tool |
264 | 284 |
|
265 | 285 |
|
266 | 286 | @pytest.mark.asyncio |
@@ -318,3 +338,41 @@ async def test_remove_consumer_cleanup_failure(mock_transport): |
318 | 338 |
|
319 | 339 | with pytest.raises(ToolProviderException, match="Failed to cleanup MCP client: Cleanup failed"): |
320 | 340 | await client.remove_consumer("consumer1") |
| 341 | + |
| 342 | + |
| 343 | +def test_mcp_client_reuse_across_multiple_agents(mock_transport): |
| 344 | + """Test that a single MCPClient can be used across multiple agents.""" |
| 345 | + from strands import Agent |
| 346 | + |
| 347 | + tool1 = create_mock_tool(tool_name="shared_echo", mcp_tool_name="echo") |
| 348 | + client = MCPClient(mock_transport, tool_filters={"allowed": ["echo"]}, prefix="shared") |
| 349 | + |
| 350 | + with ( |
| 351 | + patch.object(client, "list_tools_sync") as mock_list_tools, |
| 352 | + patch.object(client, "start") as mock_start, |
| 353 | + patch.object(client, "stop") as mock_stop, |
| 354 | + ): |
| 355 | + mock_list_tools.return_value = PaginatedList([tool1]) |
| 356 | + |
| 357 | + # Create two agents with the same client |
| 358 | + agent_1 = Agent(tools=[client]) |
| 359 | + agent_2 = Agent(tools=[client]) |
| 360 | + |
| 361 | + # Both agents should have the same tool |
| 362 | + assert "shared_echo" in agent_1.tool_names |
| 363 | + assert "shared_echo" in agent_2.tool_names |
| 364 | + assert agent_1.tool_names == agent_2.tool_names |
| 365 | + |
| 366 | + # Client should only be started once |
| 367 | + mock_start.assert_called_once() |
| 368 | + |
| 369 | + # First agent cleanup - client should remain active |
| 370 | + agent_1.cleanup() |
| 371 | + mock_stop.assert_not_called() # Should not stop yet |
| 372 | + |
| 373 | + # Second agent should still work |
| 374 | + assert "shared_echo" in agent_2.tool_names |
| 375 | + |
| 376 | + # Final cleanup when last agent is removed |
| 377 | + agent_2.cleanup() |
| 378 | + mock_stop.assert_called_once() # Now it should stop |
0 commit comments