Skip to content

Commit 7b6a903

Browse files
authored
Create ClientSessionGroup for managing multiple session connections. (#639)
1 parent fdb538b commit 7b6a903

File tree

3 files changed

+771
-0
lines changed

3 files changed

+771
-0
lines changed

src/mcp/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .client.session import ClientSession
2+
from .client.session_group import ClientSessionGroup
23
from .client.stdio import StdioServerParameters, stdio_client
34
from .server.session import ServerSession
45
from .server.stdio import stdio_server
@@ -63,6 +64,7 @@
6364
"ClientRequest",
6465
"ClientResult",
6566
"ClientSession",
67+
"ClientSessionGroup",
6668
"CreateMessageRequest",
6769
"CreateMessageResult",
6870
"ErrorData",

src/mcp/client/session_group.py

+372
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,372 @@
1+
"""
2+
SessionGroup concurrently manages multiple MCP session connections.
3+
4+
Tools, resources, and prompts are aggregated across servers. Servers may
5+
be connected to or disconnected from at any point after initialization.
6+
7+
This abstractions can handle naming collisions using a custom user-provided
8+
hook.
9+
"""
10+
11+
import contextlib
12+
import logging
13+
from collections.abc import Callable
14+
from datetime import timedelta
15+
from types import TracebackType
16+
from typing import Any, TypeAlias
17+
18+
import anyio
19+
from pydantic import BaseModel
20+
from typing_extensions import Self
21+
22+
import mcp
23+
from mcp import types
24+
from mcp.client.sse import sse_client
25+
from mcp.client.stdio import StdioServerParameters
26+
from mcp.client.streamable_http import streamablehttp_client
27+
from mcp.shared.exceptions import McpError
28+
29+
30+
class SseServerParameters(BaseModel):
31+
"""Parameters for intializing a sse_client."""
32+
33+
# The endpoint URL.
34+
url: str
35+
36+
# Optional headers to include in requests.
37+
headers: dict[str, Any] | None = None
38+
39+
# HTTP timeout for regular operations.
40+
timeout: float = 5
41+
42+
# Timeout for SSE read operations.
43+
sse_read_timeout: float = 60 * 5
44+
45+
46+
class StreamableHttpParameters(BaseModel):
47+
"""Parameters for intializing a streamablehttp_client."""
48+
49+
# The endpoint URL.
50+
url: str
51+
52+
# Optional headers to include in requests.
53+
headers: dict[str, Any] | None = None
54+
55+
# HTTP timeout for regular operations.
56+
timeout: timedelta = timedelta(seconds=30)
57+
58+
# Timeout for SSE read operations.
59+
sse_read_timeout: timedelta = timedelta(seconds=60 * 5)
60+
61+
# Close the client session when the transport closes.
62+
terminate_on_close: bool = True
63+
64+
65+
ServerParameters: TypeAlias = (
66+
StdioServerParameters | SseServerParameters | StreamableHttpParameters
67+
)
68+
69+
70+
class ClientSessionGroup:
71+
"""Client for managing connections to multiple MCP servers.
72+
73+
This class is responsible for encapsulating management of server connections.
74+
It aggregates tools, resources, and prompts from all connected servers.
75+
76+
For auxiliary handlers, such as resource subscription, this is delegated to
77+
the client and can be accessed via the session.
78+
79+
Example Usage:
80+
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
81+
async with ClientSessionGroup(component_name_hook=name_fn) as group:
82+
for server_params in server_params:
83+
group.connect_to_server(server_param)
84+
...
85+
86+
"""
87+
88+
class _ComponentNames(BaseModel):
89+
"""Used for reverse index to find components."""
90+
91+
prompts: set[str] = set()
92+
resources: set[str] = set()
93+
tools: set[str] = set()
94+
95+
# Standard MCP components.
96+
_prompts: dict[str, types.Prompt]
97+
_resources: dict[str, types.Resource]
98+
_tools: dict[str, types.Tool]
99+
100+
# Client-server connection management.
101+
_sessions: dict[mcp.ClientSession, _ComponentNames]
102+
_tool_to_session: dict[str, mcp.ClientSession]
103+
_exit_stack: contextlib.AsyncExitStack
104+
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
105+
106+
# Optional fn consuming (component_name, serverInfo) for custom names.
107+
# This is provide a means to mitigate naming conflicts across servers.
108+
# Example: (tool_name, serverInfo) => "{result.serverInfo.name}.{tool_name}"
109+
_ComponentNameHook: TypeAlias = Callable[[str, types.Implementation], str]
110+
_component_name_hook: _ComponentNameHook | None
111+
112+
def __init__(
113+
self,
114+
exit_stack: contextlib.AsyncExitStack | None = None,
115+
component_name_hook: _ComponentNameHook | None = None,
116+
) -> None:
117+
"""Initializes the MCP client."""
118+
119+
self._tools = {}
120+
self._resources = {}
121+
self._prompts = {}
122+
123+
self._sessions = {}
124+
self._tool_to_session = {}
125+
if exit_stack is None:
126+
self._exit_stack = contextlib.AsyncExitStack()
127+
self._owns_exit_stack = True
128+
else:
129+
self._exit_stack = exit_stack
130+
self._owns_exit_stack = False
131+
self._session_exit_stacks = {}
132+
self._component_name_hook = component_name_hook
133+
134+
async def __aenter__(self) -> Self:
135+
# Enter the exit stack only if we created it ourselves
136+
if self._owns_exit_stack:
137+
await self._exit_stack.__aenter__()
138+
return self
139+
140+
async def __aexit__(
141+
self,
142+
_exc_type: type[BaseException] | None,
143+
_exc_val: BaseException | None,
144+
_exc_tb: TracebackType | None,
145+
) -> bool | None:
146+
"""Closes session exit stacks and main exit stack upon completion."""
147+
148+
# Concurrently close session stacks.
149+
async with anyio.create_task_group() as tg:
150+
for exit_stack in self._session_exit_stacks.values():
151+
tg.start_soon(exit_stack.aclose)
152+
153+
# Only close the main exit stack if we created it
154+
if self._owns_exit_stack:
155+
await self._exit_stack.aclose()
156+
157+
@property
158+
def sessions(self) -> list[mcp.ClientSession]:
159+
"""Returns the list of sessions being managed."""
160+
return list(self._sessions.keys())
161+
162+
@property
163+
def prompts(self) -> dict[str, types.Prompt]:
164+
"""Returns the prompts as a dictionary of names to prompts."""
165+
return self._prompts
166+
167+
@property
168+
def resources(self) -> dict[str, types.Resource]:
169+
"""Returns the resources as a dictionary of names to resources."""
170+
return self._resources
171+
172+
@property
173+
def tools(self) -> dict[str, types.Tool]:
174+
"""Returns the tools as a dictionary of names to tools."""
175+
return self._tools
176+
177+
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
178+
"""Executes a tool given its name and arguments."""
179+
session = self._tool_to_session[name]
180+
session_tool_name = self.tools[name].name
181+
return await session.call_tool(session_tool_name, args)
182+
183+
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
184+
"""Disconnects from a single MCP server."""
185+
186+
session_known_for_components = session in self._sessions
187+
session_known_for_stack = session in self._session_exit_stacks
188+
189+
if not session_known_for_components and not session_known_for_stack:
190+
raise McpError(
191+
types.ErrorData(
192+
code=types.INVALID_PARAMS,
193+
message="Provided session is not managed or already disconnected.",
194+
)
195+
)
196+
197+
if session_known_for_components:
198+
component_names = self._sessions.pop(session) # Pop from _sessions tracking
199+
200+
# Remove prompts associated with the session.
201+
for name in component_names.prompts:
202+
if name in self._prompts:
203+
del self._prompts[name]
204+
# Remove resources associated with the session.
205+
for name in component_names.resources:
206+
if name in self._resources:
207+
del self._resources[name]
208+
# Remove tools associated with the session.
209+
for name in component_names.tools:
210+
if name in self._tools:
211+
del self._tools[name]
212+
if name in self._tool_to_session:
213+
del self._tool_to_session[name]
214+
215+
# Clean up the session's resources via its dedicated exit stack
216+
if session_known_for_stack:
217+
session_stack_to_close = self._session_exit_stacks.pop(session)
218+
await session_stack_to_close.aclose()
219+
220+
async def connect_with_session(
221+
self, server_info: types.Implementation, session: mcp.ClientSession
222+
) -> mcp.ClientSession:
223+
"""Connects to a single MCP server."""
224+
await self._aggregate_components(server_info, session)
225+
return session
226+
227+
async def connect_to_server(
228+
self,
229+
server_params: ServerParameters,
230+
) -> mcp.ClientSession:
231+
"""Connects to a single MCP server."""
232+
server_info, session = await self._establish_session(server_params)
233+
return await self.connect_with_session(server_info, session)
234+
235+
async def _establish_session(
236+
self, server_params: ServerParameters
237+
) -> tuple[types.Implementation, mcp.ClientSession]:
238+
"""Establish a client session to an MCP server."""
239+
240+
session_stack = contextlib.AsyncExitStack()
241+
try:
242+
# Create read and write streams that facilitate io with the server.
243+
if isinstance(server_params, StdioServerParameters):
244+
client = mcp.stdio_client(server_params)
245+
read, write = await session_stack.enter_async_context(client)
246+
elif isinstance(server_params, SseServerParameters):
247+
client = sse_client(
248+
url=server_params.url,
249+
headers=server_params.headers,
250+
timeout=server_params.timeout,
251+
sse_read_timeout=server_params.sse_read_timeout,
252+
)
253+
read, write = await session_stack.enter_async_context(client)
254+
else:
255+
client = streamablehttp_client(
256+
url=server_params.url,
257+
headers=server_params.headers,
258+
timeout=server_params.timeout,
259+
sse_read_timeout=server_params.sse_read_timeout,
260+
terminate_on_close=server_params.terminate_on_close,
261+
)
262+
read, write, _ = await session_stack.enter_async_context(client)
263+
264+
session = await session_stack.enter_async_context(
265+
mcp.ClientSession(read, write)
266+
)
267+
result = await session.initialize()
268+
269+
# Session successfully initialized.
270+
# Store its stack and register the stack with the main group stack.
271+
self._session_exit_stacks[session] = session_stack
272+
# session_stack itself becomes a resource managed by the
273+
# main _exit_stack.
274+
await self._exit_stack.enter_async_context(session_stack)
275+
276+
return result.serverInfo, session
277+
except Exception:
278+
# If anything during this setup fails, ensure the session-specific
279+
# stack is closed.
280+
await session_stack.aclose()
281+
raise
282+
283+
async def _aggregate_components(
284+
self, server_info: types.Implementation, session: mcp.ClientSession
285+
) -> None:
286+
"""Aggregates prompts, resources, and tools from a given session."""
287+
288+
# Create a reverse index so we can find all prompts, resources, and
289+
# tools belonging to this session. Used for removing components from
290+
# the session group via self.disconnect_from_server.
291+
component_names = self._ComponentNames()
292+
293+
# Temporary components dicts. We do not want to modify the aggregate
294+
# lists in case of an intermediate failure.
295+
prompts_temp: dict[str, types.Prompt] = {}
296+
resources_temp: dict[str, types.Resource] = {}
297+
tools_temp: dict[str, types.Tool] = {}
298+
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
299+
300+
# Query the server for its prompts and aggregate to list.
301+
try:
302+
prompts = (await session.list_prompts()).prompts
303+
for prompt in prompts:
304+
name = self._component_name(prompt.name, server_info)
305+
prompts_temp[name] = prompt
306+
component_names.prompts.add(name)
307+
except McpError as err:
308+
logging.warning(f"Could not fetch prompts: {err}")
309+
310+
# Query the server for its resources and aggregate to list.
311+
try:
312+
resources = (await session.list_resources()).resources
313+
for resource in resources:
314+
name = self._component_name(resource.name, server_info)
315+
resources_temp[name] = resource
316+
component_names.resources.add(name)
317+
except McpError as err:
318+
logging.warning(f"Could not fetch resources: {err}")
319+
320+
# Query the server for its tools and aggregate to list.
321+
try:
322+
tools = (await session.list_tools()).tools
323+
for tool in tools:
324+
name = self._component_name(tool.name, server_info)
325+
tools_temp[name] = tool
326+
tool_to_session_temp[name] = session
327+
component_names.tools.add(name)
328+
except McpError as err:
329+
logging.warning(f"Could not fetch tools: {err}")
330+
331+
# Clean up exit stack for session if we couldn't retrieve anything
332+
# from the server.
333+
if not any((prompts_temp, resources_temp, tools_temp)):
334+
del self._session_exit_stacks[session]
335+
336+
# Check for duplicates.
337+
matching_prompts = prompts_temp.keys() & self._prompts.keys()
338+
if matching_prompts:
339+
raise McpError(
340+
types.ErrorData(
341+
code=types.INVALID_PARAMS,
342+
message=f"{matching_prompts} already exist in group prompts.",
343+
)
344+
)
345+
matching_resources = resources_temp.keys() & self._resources.keys()
346+
if matching_resources:
347+
raise McpError(
348+
types.ErrorData(
349+
code=types.INVALID_PARAMS,
350+
message=f"{matching_resources} already exist in group resources.",
351+
)
352+
)
353+
matching_tools = tools_temp.keys() & self._tools.keys()
354+
if matching_tools:
355+
raise McpError(
356+
types.ErrorData(
357+
code=types.INVALID_PARAMS,
358+
message=f"{matching_tools} already exist in group tools.",
359+
)
360+
)
361+
362+
# Aggregate components.
363+
self._sessions[session] = component_names
364+
self._prompts.update(prompts_temp)
365+
self._resources.update(resources_temp)
366+
self._tools.update(tools_temp)
367+
self._tool_to_session.update(tool_to_session_temp)
368+
369+
def _component_name(self, name: str, server_info: types.Implementation) -> str:
370+
if self._component_name_hook:
371+
return self._component_name_hook(name, server_info)
372+
return name

0 commit comments

Comments
 (0)