diff --git a/src/mcp/client/session_group.py b/src/mcp/client/session_group.py index 700b5417f..1ee18b828 100644 --- a/src/mcp/client/session_group.py +++ b/src/mcp/client/session_group.py @@ -16,7 +16,7 @@ from typing import Any, TypeAlias import anyio -from pydantic import BaseModel +from pydantic import AnyUrl, BaseModel from typing_extensions import Self import mcp @@ -98,6 +98,7 @@ class _ComponentNames(BaseModel): # Client-server connection management. _sessions: dict[mcp.ClientSession, _ComponentNames] _tool_to_session: dict[str, mcp.ClientSession] + _resource_to_session: dict[str, mcp.ClientSession] _exit_stack: contextlib.AsyncExitStack _session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack] @@ -114,20 +115,16 @@ def __init__( ) -> None: """Initializes the MCP client.""" - self._tools = {} - self._resources = {} + self._exit_stack = exit_stack or contextlib.AsyncExitStack() + self._owns_exit_stack = exit_stack is None + self._session_exit_stacks = {} + self._component_name_hook = component_name_hook self._prompts = {} - + self._resources = {} + self._tools = {} self._sessions = {} self._tool_to_session = {} - if exit_stack is None: - self._exit_stack = contextlib.AsyncExitStack() - self._owns_exit_stack = True - else: - self._exit_stack = exit_stack - self._owns_exit_stack = False - self._session_exit_stacks = {} - self._component_name_hook = component_name_hook + self._resource_to_session = {} # New mapping async def __aenter__(self) -> Self: # Enter the exit stack only if we created it ourselves @@ -172,12 +169,23 @@ def tools(self) -> dict[str, types.Tool]: """Returns the tools as a dictionary of names to tools.""" return self._tools + @property + def resource_templates(self) -> list[types.ResourceTemplate]: + """Return all unique resource templates from the resources.""" + templates: list[types.ResourceTemplate] = [] + for r in self._resources.values(): + t = getattr(r, "template", None) + if t is not None and t not in templates: + templates.append(t) + return templates + async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult: """Executes a tool given its name and arguments.""" session = self._tool_to_session[name] session_tool_name = self.tools[name].name return await session.call_tool(session_tool_name, args) + async def disconnect_from_server(self, session: mcp.ClientSession) -> None: """Disconnects from a single MCP server.""" @@ -290,8 +298,8 @@ async def _aggregate_components(self, server_info: types.Implementation, session resources_temp: dict[str, types.Resource] = {} tools_temp: dict[str, types.Tool] = {} tool_to_session_temp: dict[str, mcp.ClientSession] = {} + resource_to_session_temp: dict[str, mcp.ClientSession] = {} - # Query the server for its prompts and aggregate to list. try: prompts = (await session.list_prompts()).prompts for prompt in prompts: @@ -308,6 +316,7 @@ async def _aggregate_components(self, server_info: types.Implementation, session name = self._component_name(resource.name, server_info) resources_temp[name] = resource component_names.resources.add(name) + resource_to_session_temp[name] = session except McpError as err: logging.warning(f"Could not fetch resources: {err}") @@ -359,8 +368,70 @@ async def _aggregate_components(self, server_info: types.Implementation, session self._resources.update(resources_temp) self._tools.update(tools_temp) self._tool_to_session.update(tool_to_session_temp) + self._resource_to_session.update(resource_to_session_temp) def _component_name(self, name: str, server_info: types.Implementation) -> str: if self._component_name_hook: return self._component_name_hook(name, server_info) return name + + async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult: + """Read a resource from the appropriate session based on the URI.""" + for name, resource in self._resources.items(): + if resource.uri == uri: + session = self._resource_to_session.get(name) + if session: + return await session.read_resource(uri) + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No session found for resource with URI '{uri}'", + ) + ) + + async def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/subscribe request.""" + for name, resource in self._resources.items(): + if resource.uri == uri: + session = self._resource_to_session[name] + if session: + return await session.subscribe_resource(uri) + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No session found for resource with URI '{uri}'", + ) + ) + + async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult: + """Send a resources/unsubscribe request.""" + # Find the session that owns this resource + for name, resource in self._resources.items(): + if resource.uri == uri: + session = self._resource_to_session.get(name) + if session: + return await session.unsubscribe_resource(uri) + + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No resource found with URI '{uri}'", + ) + ) + + async def get_prompt( + self, name: str, arguments: dict[str, str] | None = None + ) -> types.GetPromptResult: + """Send a prompts/get request.""" + if name in self._prompts: + prompt = self._prompts[name] + session = self._tool_to_session.get(name) + if session: + return await session.get_prompt(prompt.name, arguments) + raise McpError( + types.ErrorData( + code=types.INVALID_PARAMS, + message=f"No prompt found with name '{name}'", + ) + ) + diff --git a/tests/client/test_session_group.py b/tests/client/test_session_group.py index 16a887e00..9058e1258 100644 --- a/tests/client/test_session_group.py +++ b/tests/client/test_session_group.py @@ -2,6 +2,7 @@ from unittest import mock import pytest +from pydantic import AnyUrl import mcp from mcp import types @@ -365,3 +366,83 @@ async def test_establish_session_parameterized( # 3. Assert returned values assert returned_server_info is mock_initialize_result.serverInfo assert returned_session is mock_entered_session + + @pytest.mark.anyio + async def test_read_resource_not_found(self): + """Test reading a non-existent resource from a session group.""" + # --- Mock Dependencies --- + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + test_resource = types.Resource( + name="test_resource", + uri=AnyUrl("test://resource/1"), + description="Test resource", + ) + + # Mock all list methods + mock_session.list_resources.return_value = types.ListResourcesResult( + resources=[test_resource] + ) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) + + # --- Test Setup --- + group = ClientSessionGroup() + group._session_exit_stacks[mock_session] = mock.AsyncMock( + spec=contextlib.AsyncExitStack + ) + await group.connect_with_session( + types.Implementation(name="test_server", version="1.0.0"), mock_session + ) + + # --- Test Execution & Assertions --- + with pytest.raises(ValueError, match="Resource not found: test://nonexistent"): + await group.read_resource(AnyUrl("test://nonexistent")) + + @pytest.mark.anyio + async def test_read_resource_success(self): + """Test successfully reading a resource from a session group.""" + # --- Mock Dependencies --- + mock_session = mock.AsyncMock(spec=mcp.ClientSession) + test_resource = types.Resource( + name="test_resource", + uri=AnyUrl("test://resource/1"), + description="Test resource", + ) + + # Mock all list methods + mock_session.list_resources.return_value = types.ListResourcesResult( + resources=[test_resource] + ) + mock_session.list_prompts.return_value = types.ListPromptsResult(prompts=[]) + mock_session.list_tools.return_value = types.ListToolsResult(tools=[]) + + # Mock the session's read_resource method + mock_read_result = mock.AsyncMock(spec=types.ReadResourceResult) + mock_read_result.contents = [ + types.TextContent(type="text", text="Resource content") + ] + mock_session.read_resource.return_value = mock_read_result + + # --- Test Setup --- + group = ClientSessionGroup() + group._session_exit_stacks[mock_session] = mock.AsyncMock( + spec=contextlib.AsyncExitStack + ) + await group.connect_with_session( + types.Implementation(name="test_server", version="1.0.0"), mock_session + ) + + # Verify resource was added + assert "test_resource" in group._resources + assert group._resources["test_resource"] == test_resource + assert "test_resource" in group._resource_to_session + assert group._resource_to_session["test_resource"] == mock_session + + # --- Test Execution --- + result = await group.read_resource(AnyUrl("test://resource/1")) + + # --- Assertions --- + assert result.contents == [ + types.TextContent(type="text", text="Resource content") + ] + mock_session.read_resource.assert_called_once_with(AnyUrl("test://resource/1"))