Skip to content

Commit a42f953

Browse files
authored
Add support for async context management to ClientSessionGroup
This changes enables context management for setting up and tearing down async exit stacks durring server connection and disconnection respectively. Documentation has been added to show an example use case that demonstrates how `ClientSessionGroup` can be used with `async with`.
1 parent ef3feb2 commit a42f953

File tree

2 files changed

+195
-95
lines changed

2 files changed

+195
-95
lines changed

src/mcp/client/session_group.py

+177-85
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
"""
1010

1111
import contextlib
12+
import logging
1213
from collections.abc import Callable
1314
from datetime import timedelta
15+
from types import TracebackType
1416
from typing import Any, TypeAlias
1517

18+
import anyio
1619
from pydantic import BaseModel
20+
from typing_extensions import Self
1721

1822
import mcp
1923
from mcp import types
@@ -67,11 +71,18 @@ class ClientSessionGroup:
6771
"""Client for managing connections to multiple MCP servers.
6872
6973
This class is responsible for encapsulating management of server connections.
70-
It it aggregates tools, resources, and prompts from all connected servers.
74+
It aggregates tools, resources, and prompts from all connected servers.
7175
7276
For auxiliary handlers, such as resource subscription, this is delegated to
73-
the client and can be accessed via the session. For example:
74-
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
77+
the client and can be accessed via the session.
78+
79+
Example Usage:
80+
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
81+
async with ClientSessionGroup(component_name_hook=name_fn) as group:
82+
for server_params in server_params:
83+
group.connect_to_server(server_param)
84+
...
85+
7586
"""
7687

7788
class _ComponentNames(BaseModel):
@@ -90,6 +101,7 @@ class _ComponentNames(BaseModel):
90101
_sessions: dict[mcp.ClientSession, _ComponentNames]
91102
_tool_to_session: dict[str, mcp.ClientSession]
92103
_exit_stack: contextlib.AsyncExitStack
104+
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
93105

94106
# Optional fn consuming (component_name, serverInfo) for custom names.
95107
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +111,7 @@ class _ComponentNames(BaseModel):
99111

100112
def __init__(
101113
self,
102-
exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(),
114+
exit_stack: contextlib.AsyncExitStack | None = None,
103115
component_name_hook: _ComponentNameHook | None = None,
104116
) -> None:
105117
"""Initializes the MCP client."""
@@ -110,9 +122,43 @@ def __init__(
110122

111123
self._sessions = {}
112124
self._tool_to_session = {}
113-
self._exit_stack = exit_stack
125+
if exit_stack is None:
126+
self._exit_stack = contextlib.AsyncExitStack()
127+
self._owns_exit_stack = True
128+
else:
129+
self._exit_stack = exit_stack
130+
self._owns_exit_stack = False
131+
self._session_exit_stacks = {}
114132
self._component_name_hook = component_name_hook
115133

134+
async def __aenter__(self) -> Self:
135+
# Enter the exit stack only if we created it ourselves
136+
if self._owns_exit_stack:
137+
await self._exit_stack.__aenter__()
138+
return self
139+
140+
async def __aexit__(
141+
self,
142+
_exc_type: type[BaseException] | None,
143+
_exc_val: BaseException | None,
144+
_exc_tb: TracebackType | None,
145+
) -> bool | None:
146+
"""Closes session exit stacks and main exit stack upon completion."""
147+
148+
# Concurrently close session stacks.
149+
async with anyio.create_task_group() as tg:
150+
for exit_stack in self._session_exit_stacks.values():
151+
tg.start_soon(exit_stack.aclose)
152+
153+
# Only close the main exit stack if we created it
154+
if self._owns_exit_stack:
155+
await self._exit_stack.aclose()
156+
157+
@property
158+
def sessions(self) -> list[mcp.ClientSession]:
159+
"""Returns the list of sessions being managed."""
160+
return list(self._sessions.keys())
161+
116162
@property
117163
def prompts(self) -> dict[str, types.Prompt]:
118164
"""Returns the prompts as a dictionary of names to prompts."""
@@ -131,33 +177,45 @@ def tools(self) -> dict[str, types.Tool]:
131177
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
132178
"""Executes a tool given its name and arguments."""
133179
session = self._tool_to_session[name]
134-
return await session.call_tool(name, args)
180+
session_tool_name = self.tools[name].name
181+
return await session.call_tool(session_tool_name, args)
135182

136-
def disconnect_from_server(self, session: mcp.ClientSession) -> None:
183+
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
137184
"""Disconnects from a single MCP server."""
138185

139-
if session not in self._sessions:
186+
session_known_for_components = session in self._sessions
187+
session_known_for_stack = session in self._session_exit_stacks
188+
189+
if not session_known_for_components and not session_known_for_stack:
140190
raise McpError(
141191
types.ErrorData(
142192
code=types.INVALID_PARAMS,
143-
message="Provided session is not being managed.",
193+
message="Provided session is not managed or already disconnected.",
144194
)
145195
)
146-
component_names = self._sessions[session]
147-
148-
# Remove prompts associated with the session.
149-
for name in component_names.prompts:
150-
del self._prompts[name]
151-
152-
# Remove resources associated with the session.
153-
for name in component_names.resources:
154-
del self._resources[name]
155-
156-
# Remove tools associated with the session.
157-
for name in component_names.tools:
158-
del self._tools[name]
159196

160-
del self._sessions[session]
197+
if session_known_for_components:
198+
component_names = self._sessions.pop(session) # Pop from _sessions tracking
199+
200+
# Remove prompts associated with the session.
201+
for name in component_names.prompts:
202+
if name in self._prompts:
203+
del self._prompts[name]
204+
# Remove resources associated with the session.
205+
for name in component_names.resources:
206+
if name in self._resources:
207+
del self._resources[name]
208+
# Remove tools associated with the session.
209+
for name in component_names.tools:
210+
if name in self._tools:
211+
del self._tools[name]
212+
if name in self._tool_to_session:
213+
del self._tool_to_session[name]
214+
215+
# Clean up the session's resources via its dedicated exit stack
216+
if session_known_for_stack:
217+
session_stack_to_close = self._session_exit_stacks.pop(session)
218+
await session_stack_to_close.aclose()
161219

162220
async def connect_to_server(
163221
self,
@@ -181,47 +239,66 @@ async def connect_to_server(
181239
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
182240

183241
# Query the server for its prompts and aggregate to list.
184-
prompts = (await session.list_prompts()).prompts
185-
for prompt in prompts:
186-
name = self._component_name(prompt.name, server_info)
187-
if name in self._prompts:
188-
raise McpError(
189-
types.ErrorData(
190-
code=types.INVALID_PARAMS,
191-
message=f"{name} already exists in group prompts.",
192-
)
193-
)
194-
prompts_temp[name] = prompt
195-
component_names.prompts.add(name)
242+
try:
243+
prompts = (await session.list_prompts()).prompts
244+
for prompt in prompts:
245+
name = self._component_name(prompt.name, server_info)
246+
prompts_temp[name] = prompt
247+
component_names.prompts.add(name)
248+
except McpError as err:
249+
logging.warning(f"Could not fetch prompts: {err}")
196250

197251
# Query the server for its resources and aggregate to list.
198-
resources = (await session.list_resources()).resources
199-
for resource in resources:
200-
name = self._component_name(resource.name, server_info)
201-
if name in self._resources:
202-
raise McpError(
203-
types.ErrorData(
204-
code=types.INVALID_PARAMS,
205-
message=f"{name} already exists in group resources.",
206-
)
207-
)
208-
resources_temp[name] = resource
209-
component_names.resources.add(name)
252+
try:
253+
resources = (await session.list_resources()).resources
254+
for resource in resources:
255+
name = self._component_name(resource.name, server_info)
256+
resources_temp[name] = resource
257+
component_names.resources.add(name)
258+
except McpError as err:
259+
logging.warning(f"Could not fetch resources: {err}")
210260

211261
# Query the server for its tools and aggregate to list.
212-
tools = (await session.list_tools()).tools
213-
for tool in tools:
214-
name = self._component_name(tool.name, server_info)
215-
if name in self._tools:
216-
raise McpError(
217-
types.ErrorData(
218-
code=types.INVALID_PARAMS,
219-
message=f"{name} already exists in group tools.",
220-
)
262+
try:
263+
tools = (await session.list_tools()).tools
264+
for tool in tools:
265+
name = self._component_name(tool.name, server_info)
266+
tools_temp[name] = tool
267+
tool_to_session_temp[name] = session
268+
component_names.tools.add(name)
269+
except McpError as err:
270+
logging.warning(f"Could not fetch tools: {err}")
271+
272+
# Clean up exit stack for session if we couldn't retrieve anything
273+
# from the server.
274+
if not any((prompts_temp, resources_temp, tools_temp)):
275+
del self._session_exit_stacks[session]
276+
277+
# Check for duplicates.
278+
matching_prompts = prompts_temp.keys() & self._prompts.keys()
279+
if matching_prompts:
280+
raise McpError(
281+
types.ErrorData(
282+
code=types.INVALID_PARAMS,
283+
message=f"{matching_prompts} already exist in group prompts.",
284+
)
285+
)
286+
matching_resources = resources_temp.keys() & self._resources.keys()
287+
if matching_resources:
288+
raise McpError(
289+
types.ErrorData(
290+
code=types.INVALID_PARAMS,
291+
message=f"{matching_resources} already exist in group resources.",
292+
)
293+
)
294+
matching_tools = tools_temp.keys() & self._tools.keys()
295+
if matching_tools:
296+
raise McpError(
297+
types.ErrorData(
298+
code=types.INVALID_PARAMS,
299+
message=f"{matching_tools} already exist in group tools.",
221300
)
222-
tools_temp[name] = tool
223-
tool_to_session_temp[name] = session
224-
component_names.tools.add(name)
301+
)
225302

226303
# Aggregate components.
227304
self._sessions[session] = component_names
@@ -237,33 +314,48 @@ async def _establish_session(
237314
) -> tuple[types.Implementation, mcp.ClientSession]:
238315
"""Establish a client session to an MCP server."""
239316

240-
# Create read and write streams that facilitate io with the server.
241-
if isinstance(server_params, StdioServerParameters):
242-
client = mcp.stdio_client(server_params)
243-
read, write = await self._exit_stack.enter_async_context(client)
244-
elif isinstance(server_params, SseServerParameters):
245-
client = sse_client(
246-
url=server_params.url,
247-
headers=server_params.headers,
248-
timeout=server_params.timeout,
249-
sse_read_timeout=server_params.sse_read_timeout,
250-
)
251-
read, write = await self._exit_stack.enter_async_context(client)
252-
else:
253-
client = streamablehttp_client(
254-
url=server_params.url,
255-
headers=server_params.headers,
256-
timeout=server_params.timeout,
257-
sse_read_timeout=server_params.sse_read_timeout,
258-
terminate_on_close=server_params.terminate_on_close,
259-
)
260-
read, write, _ = await self._exit_stack.enter_async_context(client)
317+
session_stack = contextlib.AsyncExitStack()
318+
try:
319+
# Create read and write streams that facilitate io with the server.
320+
if isinstance(server_params, StdioServerParameters):
321+
client = mcp.stdio_client(server_params)
322+
read, write = await session_stack.enter_async_context(client)
323+
elif isinstance(server_params, SseServerParameters):
324+
client = sse_client(
325+
url=server_params.url,
326+
headers=server_params.headers,
327+
timeout=server_params.timeout,
328+
sse_read_timeout=server_params.sse_read_timeout,
329+
)
330+
read, write = await session_stack.enter_async_context(client)
331+
else:
332+
client = streamablehttp_client(
333+
url=server_params.url,
334+
headers=server_params.headers,
335+
timeout=server_params.timeout,
336+
sse_read_timeout=server_params.sse_read_timeout,
337+
terminate_on_close=server_params.terminate_on_close,
338+
)
339+
read, write, _ = await session_stack.enter_async_context(client)
261340

262-
session = await self._exit_stack.enter_async_context(
263-
mcp.ClientSession(read, write)
264-
)
265-
result = await session.initialize()
266-
return result.serverInfo, session
341+
session = await session_stack.enter_async_context(
342+
mcp.ClientSession(read, write)
343+
)
344+
result = await session.initialize()
345+
346+
# Session successfully initialized.
347+
# Store its stack and register the stack with the main group stack.
348+
self._session_exit_stacks[session] = session_stack
349+
# session_stack itself becomes a resource managed by the
350+
# main _exit_stack.
351+
await self._exit_stack.enter_async_context(session_stack)
352+
353+
return result.serverInfo, session
354+
except Exception:
355+
# If anything during this setup fails, ensure the session-specific
356+
# stack is closed.
357+
await session_stack.aclose()
358+
raise
267359

268360
def _component_name(self, name: str, server_info: types.Implementation) -> str:
269361
if self._component_name_hook:

0 commit comments

Comments
 (0)