Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions src/google/adk/tools/mcp_tool/mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import EmptyResult
except ImportError as e:

if sys.version_info < (3, 10):
Expand Down Expand Up @@ -240,7 +241,7 @@ def _merge_headers(

return base_headers

def _is_session_disconnected(self, session: ClientSession) -> bool:
async def _is_session_disconnected(self, session: ClientSession) -> bool:
"""Checks if a session is disconnected or closed.

Args:
Expand All @@ -249,7 +250,19 @@ def _is_session_disconnected(self, session: ClientSession) -> bool:
Returns:
True if the session is disconnected, False otherwise.
"""
return session._read_stream._closed or session._write_stream._closed

try:
response = await asyncio.wait_for(session.send_ping(), timeout=5.0)
if isinstance(response, EmptyResult):
return False
else:
logger.debug(f'Session ping returns illegal response {response}, treating as disconnected')
return True

except (asyncio.TimeoutError, anyio.ClosedResourceError, Exception) as e:
logger.debug(f'Session ping failed with error {e}, treating as disconnected')
return True


def _create_client(self, merged_headers: Optional[Dict[str, str]] = None):
"""Creates an MCP client based on the connection parameters.
Expand Down Expand Up @@ -324,7 +337,7 @@ async def create_session(
session, exit_stack = self._sessions[session_key]

# Check if the existing session is still connected
if not self._is_session_disconnected(session):
if not await self._is_session_disconnected(session):
# Session is still good, return it
return session
else:
Expand Down
137 changes: 126 additions & 11 deletions tests/unittests/tools/mcp_tool/test_mcp_session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __init__(self):
self._read_stream._closed = False
self._write_stream._closed = False
self.initialize = AsyncMock()
self.send_ping = AsyncMock()


class MockAsyncExitStack:
Expand Down Expand Up @@ -204,19 +205,77 @@ def test_merge_headers_sse(self):
}
assert merged == expected

def test_is_session_disconnected(self):
"""Test session disconnection detection."""
@pytest.mark.asyncio
async def test_is_session_disconnected_success(self):
"""Test session disconnection detection when ping succeeds."""
from mcp.types import EmptyResult

manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock session
session = MockClientSession()
session.send_ping.return_value = EmptyResult()

# Session is connected (ping returns EmptyResult)
is_disconnected = await manager._is_session_disconnected(session)
assert not is_disconnected
session.send_ping.assert_called_once()

@pytest.mark.asyncio
async def test_is_session_disconnected_timeout(self):
"""Test session disconnection detection when ping times out."""
import asyncio

manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock session that times out
session = MockClientSession()
session.send_ping.side_effect = asyncio.TimeoutError()

# Session is disconnected (ping times out)
is_disconnected = await manager._is_session_disconnected(session)
assert is_disconnected

@pytest.mark.asyncio
async def test_is_session_disconnected_closed_resource(self):
"""Test session disconnection detection when resource is closed."""
import anyio

manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock session that raises ClosedResourceError
session = MockClientSession()
session.send_ping.side_effect = anyio.ClosedResourceError()

# Session is disconnected (resource closed)
is_disconnected = await manager._is_session_disconnected(session)
assert is_disconnected

@pytest.mark.asyncio
async def test_is_session_disconnected_exception(self):
"""Test session disconnection detection when ping raises exception."""
manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock session that raises exception
session = MockClientSession()
session.send_ping.side_effect = Exception("Connection error")

# Not disconnected
assert not manager._is_session_disconnected(session)
# Session is disconnected (ping failed)
is_disconnected = await manager._is_session_disconnected(session)
assert is_disconnected

# Disconnected - read stream closed
session._read_stream._closed = True
assert manager._is_session_disconnected(session)
@pytest.mark.asyncio
async def test_is_session_disconnected_invalid_response(self):
"""Test session disconnection detection when ping returns invalid response."""
manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock session that returns unexpected response
session = MockClientSession()
session.send_ping.return_value = {"invalid": "response"}

# Session is disconnected (invalid response)
is_disconnected = await manager._is_session_disconnected(session)
assert is_disconnected

@pytest.mark.asyncio
async def test_create_session_stdio_new(self):
Expand Down Expand Up @@ -259,26 +318,82 @@ async def test_create_session_stdio_new(self):
@pytest.mark.asyncio
async def test_create_session_reuse_existing(self):
"""Test reusing an existing connected session."""
from mcp.types import EmptyResult

manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock existing session
existing_session = MockClientSession()
existing_exit_stack = MockAsyncExitStack()
manager._sessions["stdio_session"] = (existing_session, existing_exit_stack)

# Session is connected
existing_session._read_stream._closed = False
existing_session._write_stream._closed = False
# Session is connected (ping succeeds)
existing_session.send_ping.return_value = EmptyResult()

session = await manager.create_session()

# Should reuse existing session
assert session == existing_session
assert len(manager._sessions) == 1

# Should not create new session
# Should not create new session (initialize should not be called again)
existing_session.initialize.assert_not_called()

# Should have checked if session is connected
existing_session.send_ping.assert_called_once()

@pytest.mark.asyncio
async def test_create_session_replace_disconnected(self):
"""Test replacing a disconnected session with a new one."""
import asyncio

manager = MCPSessionManager(self.mock_stdio_connection_params)

# Create mock existing session that is disconnected
old_session = MockClientSession()
old_exit_stack = MockAsyncExitStack()
manager._sessions["stdio_session"] = (old_session, old_exit_stack)

# Old session is disconnected (ping times out)
old_session.send_ping.side_effect = asyncio.TimeoutError()

# Create new session
new_session = MockClientSession()
new_exit_stack = MockAsyncExitStack()

with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.stdio_client"
) as mock_stdio:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.AsyncExitStack"
) as mock_exit_stack_class:
with patch(
"google.adk.tools.mcp_tool.mcp_session_manager.ClientSession"
) as mock_session_class:

# Setup mocks for new session creation
mock_exit_stack_class.return_value = new_exit_stack
mock_stdio.return_value = AsyncMock()
new_exit_stack.enter_async_context.side_effect = [
("read", "write"), # First call returns transports
new_session, # Second call returns session
]
mock_session_class.return_value = new_session

# Create session (should replace disconnected one)
session = await manager.create_session()

# Should return new session, not old one
assert session == new_session
assert session != old_session
assert len(manager._sessions) == 1

# Old session should have been cleaned up
old_exit_stack.aclose.assert_called_once()

# New session should have been initialized
new_session.initialize.assert_called_once()

@pytest.mark.asyncio
async def test_close_success(self):
"""Test successful cleanup of all sessions."""
Expand Down