diff --git a/src/google/adk/tools/mcp_tool/mcp_session_manager.py b/src/google/adk/tools/mcp_tool/mcp_session_manager.py index fbe843a510..44e03928e2 100644 --- a/src/google/adk/tools/mcp_tool/mcp_session_manager.py +++ b/src/google/adk/tools/mcp_tool/mcp_session_manager.py @@ -37,6 +37,7 @@ from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client from mcp.client.streamable_http import streamablehttp_client + from mcp.types import EmptyResult except ImportError as e: if sys.version_info < (3, 10): @@ -240,7 +241,7 @@ def _merge_headers( return base_headers - def _is_session_disconnected(self, session: ClientSession) -> bool: + async def _is_session_disconnected(self, session: ClientSession) -> bool: """Checks if a session is disconnected or closed. Args: @@ -249,7 +250,19 @@ def _is_session_disconnected(self, session: ClientSession) -> bool: Returns: True if the session is disconnected, False otherwise. """ - return session._read_stream._closed or session._write_stream._closed + + try: + response = await asyncio.wait_for(session.send_ping(), timeout=5.0) + if isinstance(response, EmptyResult): + return False + else: + logger.debug(f'Session ping returns illegal response {response}, treating as disconnected') + return True + + except (asyncio.TimeoutError, anyio.ClosedResourceError, Exception) as e: + logger.debug(f'Session ping failed with error {e}, treating as disconnected') + return True + def _create_client(self, merged_headers: Optional[Dict[str, str]] = None): """Creates an MCP client based on the connection parameters. @@ -324,7 +337,7 @@ async def create_session( session, exit_stack = self._sessions[session_key] # Check if the existing session is still connected - if not self._is_session_disconnected(session): + if not await self._is_session_disconnected(session): # Session is still good, return it return session else: diff --git a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py index 559e51719a..64d3598500 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_session_manager.py @@ -70,6 +70,7 @@ def __init__(self): self._read_stream._closed = False self._write_stream._closed = False self.initialize = AsyncMock() + self.send_ping = AsyncMock() class MockAsyncExitStack: @@ -204,19 +205,77 @@ def test_merge_headers_sse(self): } assert merged == expected - def test_is_session_disconnected(self): - """Test session disconnection detection.""" + @pytest.mark.asyncio + async def test_is_session_disconnected_success(self): + """Test session disconnection detection when ping succeeds.""" + from mcp.types import EmptyResult + manager = MCPSessionManager(self.mock_stdio_connection_params) # Create mock session session = MockClientSession() + session.send_ping.return_value = EmptyResult() + + # Session is connected (ping returns EmptyResult) + is_disconnected = await manager._is_session_disconnected(session) + assert not is_disconnected + session.send_ping.assert_called_once() + + @pytest.mark.asyncio + async def test_is_session_disconnected_timeout(self): + """Test session disconnection detection when ping times out.""" + import asyncio + + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session that times out + session = MockClientSession() + session.send_ping.side_effect = asyncio.TimeoutError() + + # Session is disconnected (ping times out) + is_disconnected = await manager._is_session_disconnected(session) + assert is_disconnected + + @pytest.mark.asyncio + async def test_is_session_disconnected_closed_resource(self): + """Test session disconnection detection when resource is closed.""" + import anyio + + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session that raises ClosedResourceError + session = MockClientSession() + session.send_ping.side_effect = anyio.ClosedResourceError() + + # Session is disconnected (resource closed) + is_disconnected = await manager._is_session_disconnected(session) + assert is_disconnected + + @pytest.mark.asyncio + async def test_is_session_disconnected_exception(self): + """Test session disconnection detection when ping raises exception.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session that raises exception + session = MockClientSession() + session.send_ping.side_effect = Exception("Connection error") - # Not disconnected - assert not manager._is_session_disconnected(session) + # Session is disconnected (ping failed) + is_disconnected = await manager._is_session_disconnected(session) + assert is_disconnected - # Disconnected - read stream closed - session._read_stream._closed = True - assert manager._is_session_disconnected(session) + @pytest.mark.asyncio + async def test_is_session_disconnected_invalid_response(self): + """Test session disconnection detection when ping returns invalid response.""" + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock session that returns unexpected response + session = MockClientSession() + session.send_ping.return_value = {"invalid": "response"} + + # Session is disconnected (invalid response) + is_disconnected = await manager._is_session_disconnected(session) + assert is_disconnected @pytest.mark.asyncio async def test_create_session_stdio_new(self): @@ -259,6 +318,8 @@ async def test_create_session_stdio_new(self): @pytest.mark.asyncio async def test_create_session_reuse_existing(self): """Test reusing an existing connected session.""" + from mcp.types import EmptyResult + manager = MCPSessionManager(self.mock_stdio_connection_params) # Create mock existing session @@ -266,9 +327,8 @@ async def test_create_session_reuse_existing(self): existing_exit_stack = MockAsyncExitStack() manager._sessions["stdio_session"] = (existing_session, existing_exit_stack) - # Session is connected - existing_session._read_stream._closed = False - existing_session._write_stream._closed = False + # Session is connected (ping succeeds) + existing_session.send_ping.return_value = EmptyResult() session = await manager.create_session() @@ -276,9 +336,64 @@ async def test_create_session_reuse_existing(self): assert session == existing_session assert len(manager._sessions) == 1 - # Should not create new session + # Should not create new session (initialize should not be called again) existing_session.initialize.assert_not_called() + # Should have checked if session is connected + existing_session.send_ping.assert_called_once() + + @pytest.mark.asyncio + async def test_create_session_replace_disconnected(self): + """Test replacing a disconnected session with a new one.""" + import asyncio + + manager = MCPSessionManager(self.mock_stdio_connection_params) + + # Create mock existing session that is disconnected + old_session = MockClientSession() + old_exit_stack = MockAsyncExitStack() + manager._sessions["stdio_session"] = (old_session, old_exit_stack) + + # Old session is disconnected (ping times out) + old_session.send_ping.side_effect = asyncio.TimeoutError() + + # Create new session + new_session = MockClientSession() + new_exit_stack = MockAsyncExitStack() + + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.stdio_client" + ) as mock_stdio: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack" + ) as mock_exit_stack_class: + with patch( + "google.adk.tools.mcp_tool.mcp_session_manager.ClientSession" + ) as mock_session_class: + + # Setup mocks for new session creation + mock_exit_stack_class.return_value = new_exit_stack + mock_stdio.return_value = AsyncMock() + new_exit_stack.enter_async_context.side_effect = [ + ("read", "write"), # First call returns transports + new_session, # Second call returns session + ] + mock_session_class.return_value = new_session + + # Create session (should replace disconnected one) + session = await manager.create_session() + + # Should return new session, not old one + assert session == new_session + assert session != old_session + assert len(manager._sessions) == 1 + + # Old session should have been cleaned up + old_exit_stack.aclose.assert_called_once() + + # New session should have been initialized + new_session.initialize.assert_called_once() + @pytest.mark.asyncio async def test_close_success(self): """Test successful cleanup of all sessions."""