Skip to content

Commit cca631c

Browse files
committed
Add support for async context management to ClientSessionGroup
This changes enables context management for setting up and tearing down async exit stacks durring server connection and disconnection respectively. Documentation has been added to show an example use case that demonstrates how `ClientSessionGroup` can be used with `async with`.
1 parent ef3feb2 commit cca631c

File tree

2 files changed

+186
-92
lines changed

2 files changed

+186
-92
lines changed

src/mcp/client/session_group.py

+168-82
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
"""
1010

1111
import contextlib
12+
import logging
1213
from collections.abc import Callable
1314
from datetime import timedelta
15+
from types import TracebackType
1416
from typing import Any, TypeAlias
1517

18+
import anyio
1619
from pydantic import BaseModel
20+
from typing_extensions import Self
1721

1822
import mcp
1923
from mcp import types
@@ -72,6 +76,14 @@ class ClientSessionGroup:
7276
For auxiliary handlers, such as resource subscription, this is delegated to
7377
the client and can be accessed via the session. For example:
7478
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
79+
80+
Example Usage:
81+
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
82+
with async ClientSessionGroup(component_name_hook=name_fn) as group:
83+
for server_params in server_params:
84+
group.connect_to_server(server_param)
85+
...
86+
7587
"""
7688

7789
class _ComponentNames(BaseModel):
@@ -90,6 +102,7 @@ class _ComponentNames(BaseModel):
90102
_sessions: dict[mcp.ClientSession, _ComponentNames]
91103
_tool_to_session: dict[str, mcp.ClientSession]
92104
_exit_stack: contextlib.AsyncExitStack
105+
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
93106

94107
# Optional fn consuming (component_name, serverInfo) for custom names.
95108
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +112,7 @@ class _ComponentNames(BaseModel):
99112

100113
def __init__(
101114
self,
102-
exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(),
115+
exit_stack: contextlib.AsyncExitStack | None = None,
103116
component_name_hook: _ComponentNameHook | None = None,
104117
) -> None:
105118
"""Initializes the MCP client."""
@@ -110,9 +123,36 @@ def __init__(
110123

111124
self._sessions = {}
112125
self._tool_to_session = {}
113-
self._exit_stack = exit_stack
126+
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
127+
self._session_exit_stacks = {}
114128
self._component_name_hook = component_name_hook
115129

130+
async def __aenter__(self) -> Self:
131+
# If ClientSessionGroup itself is managing the lifecycle of _exit_stack
132+
# (i.e., it created it), it should enter it.
133+
# If _exit_stack was passed in, it's assumed the caller manages
134+
# its entry/exit.
135+
# For simplicity and consistency with how AsyncExitStack is often used when
136+
# provided as a dependency, we might not need to enter it here if it's
137+
# managed externally. However, if this class is the primary owner, entering it
138+
# ensures its 'aclose' is called even if passed in. Let's assume the
139+
# passed-in stack is already entered by the caller if needed.
140+
# For now, we just return self as the main stack's lifecycle is tied to aclose.
141+
return self
142+
143+
async def __aexit__(
144+
self,
145+
_exc_type: type[BaseException] | None,
146+
_exc_val: BaseException | None,
147+
_exc_tb: TracebackType | None,
148+
) -> bool | None:
149+
"""Closes session exit stacks and main exit stack upon completion."""
150+
async with anyio.create_task_group() as tg:
151+
for exit_stack in self._session_exit_stacks.values():
152+
tg.start_soon(exit_stack.aclose)
153+
await self._exit_stack.aclose()
154+
return None
155+
116156
@property
117157
def prompts(self) -> dict[str, types.Prompt]:
118158
"""Returns the prompts as a dictionary of names to prompts."""
@@ -131,33 +171,45 @@ def tools(self) -> dict[str, types.Tool]:
131171
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
132172
"""Executes a tool given its name and arguments."""
133173
session = self._tool_to_session[name]
134-
return await session.call_tool(name, args)
174+
session_tool_name = self.tools[name].name
175+
return await session.call_tool(session_tool_name, args)
135176

136-
def disconnect_from_server(self, session: mcp.ClientSession) -> None:
177+
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
137178
"""Disconnects from a single MCP server."""
138179

139-
if session not in self._sessions:
180+
session_known_for_components = session in self._sessions
181+
session_known_for_stack = session in self._session_exit_stacks
182+
183+
if not session_known_for_components and not session_known_for_stack:
140184
raise McpError(
141185
types.ErrorData(
142186
code=types.INVALID_PARAMS,
143-
message="Provided session is not being managed.",
187+
message="Provided session is not managed or already disconnected.",
144188
)
145189
)
146-
component_names = self._sessions[session]
147-
148-
# Remove prompts associated with the session.
149-
for name in component_names.prompts:
150-
del self._prompts[name]
151-
152-
# Remove resources associated with the session.
153-
for name in component_names.resources:
154-
del self._resources[name]
155-
156-
# Remove tools associated with the session.
157-
for name in component_names.tools:
158-
del self._tools[name]
159190

160-
del self._sessions[session]
191+
if session_known_for_components:
192+
component_names = self._sessions.pop(session) # Pop from _sessions tracking
193+
194+
# Remove prompts associated with the session.
195+
for name in component_names.prompts:
196+
if name in self._prompts:
197+
del self._prompts[name]
198+
# Remove resources associated with the session.
199+
for name in component_names.resources:
200+
if name in self._resources:
201+
del self._resources[name]
202+
# Remove tools associated with the session.
203+
for name in component_names.tools:
204+
if name in self._tools:
205+
del self._tools[name]
206+
if name in self._tool_to_session:
207+
del self._tool_to_session[name]
208+
209+
# Clean up the session's resources via its dedicated exit stack
210+
if session_known_for_stack:
211+
session_stack_to_close = self._session_exit_stacks.pop(session)
212+
await session_stack_to_close.aclose()
161213

162214
async def connect_to_server(
163215
self,
@@ -181,47 +233,66 @@ async def connect_to_server(
181233
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
182234

183235
# Query the server for its prompts and aggregate to list.
184-
prompts = (await session.list_prompts()).prompts
185-
for prompt in prompts:
186-
name = self._component_name(prompt.name, server_info)
187-
if name in self._prompts:
188-
raise McpError(
189-
types.ErrorData(
190-
code=types.INVALID_PARAMS,
191-
message=f"{name} already exists in group prompts.",
192-
)
193-
)
194-
prompts_temp[name] = prompt
195-
component_names.prompts.add(name)
236+
try:
237+
prompts = (await session.list_prompts()).prompts
238+
for prompt in prompts:
239+
name = self._component_name(prompt.name, server_info)
240+
prompts_temp[name] = prompt
241+
component_names.prompts.add(name)
242+
except McpError as err:
243+
logging.warning(f"Could not fetch prompts: {err}")
196244

197245
# Query the server for its resources and aggregate to list.
198-
resources = (await session.list_resources()).resources
199-
for resource in resources:
200-
name = self._component_name(resource.name, server_info)
201-
if name in self._resources:
202-
raise McpError(
203-
types.ErrorData(
204-
code=types.INVALID_PARAMS,
205-
message=f"{name} already exists in group resources.",
206-
)
207-
)
208-
resources_temp[name] = resource
209-
component_names.resources.add(name)
246+
try:
247+
resources = (await session.list_resources()).resources
248+
for resource in resources:
249+
name = self._component_name(resource.name, server_info)
250+
resources_temp[name] = resource
251+
component_names.resources.add(name)
252+
except McpError as err:
253+
logging.warning(f"Could not fetch resources: {err}")
210254

211255
# Query the server for its tools and aggregate to list.
212-
tools = (await session.list_tools()).tools
213-
for tool in tools:
214-
name = self._component_name(tool.name, server_info)
215-
if name in self._tools:
216-
raise McpError(
217-
types.ErrorData(
218-
code=types.INVALID_PARAMS,
219-
message=f"{name} already exists in group tools.",
220-
)
256+
try:
257+
tools = (await session.list_tools()).tools
258+
for tool in tools:
259+
name = self._component_name(tool.name, server_info)
260+
tools_temp[name] = tool
261+
tool_to_session_temp[name] = session
262+
component_names.tools.add(name)
263+
except McpError as err:
264+
logging.warning(f"Could not fetch tools: {err}")
265+
266+
# Clean up exit stack for session if we couldn't retrieve anything
267+
# from the server.
268+
if not any((prompts_temp, resources_temp, tools_temp)):
269+
del self._session_exit_stacks[session]
270+
271+
# Check for duplicates.
272+
matching_prompts = prompts_temp.keys() & self._prompts.keys()
273+
if matching_prompts:
274+
raise McpError(
275+
types.ErrorData(
276+
code=types.INVALID_PARAMS,
277+
message=f"{matching_prompts} already exist in group prompts.",
221278
)
222-
tools_temp[name] = tool
223-
tool_to_session_temp[name] = session
224-
component_names.tools.add(name)
279+
)
280+
matching_resources = resources_temp.keys() & self._resources.keys()
281+
if matching_resources:
282+
raise McpError(
283+
types.ErrorData(
284+
code=types.INVALID_PARAMS,
285+
message=f"{matching_resources} already exist in group resources.",
286+
)
287+
)
288+
matching_tools = tools_temp.keys() & self._tools.keys()
289+
if matching_tools:
290+
raise McpError(
291+
types.ErrorData(
292+
code=types.INVALID_PARAMS,
293+
message=f"{matching_tools} already exist in group tools.",
294+
)
295+
)
225296

226297
# Aggregate components.
227298
self._sessions[session] = component_names
@@ -237,33 +308,48 @@ async def _establish_session(
237308
) -> tuple[types.Implementation, mcp.ClientSession]:
238309
"""Establish a client session to an MCP server."""
239310

240-
# Create read and write streams that facilitate io with the server.
241-
if isinstance(server_params, StdioServerParameters):
242-
client = mcp.stdio_client(server_params)
243-
read, write = await self._exit_stack.enter_async_context(client)
244-
elif isinstance(server_params, SseServerParameters):
245-
client = sse_client(
246-
url=server_params.url,
247-
headers=server_params.headers,
248-
timeout=server_params.timeout,
249-
sse_read_timeout=server_params.sse_read_timeout,
250-
)
251-
read, write = await self._exit_stack.enter_async_context(client)
252-
else:
253-
client = streamablehttp_client(
254-
url=server_params.url,
255-
headers=server_params.headers,
256-
timeout=server_params.timeout,
257-
sse_read_timeout=server_params.sse_read_timeout,
258-
terminate_on_close=server_params.terminate_on_close,
259-
)
260-
read, write, _ = await self._exit_stack.enter_async_context(client)
311+
session_specific_stack = contextlib.AsyncExitStack()
312+
try:
313+
# Create read and write streams that facilitate io with the server.
314+
if isinstance(server_params, StdioServerParameters):
315+
client = mcp.stdio_client(server_params)
316+
read, write = await self._exit_stack.enter_async_context(client)
317+
elif isinstance(server_params, SseServerParameters):
318+
client = sse_client(
319+
url=server_params.url,
320+
headers=server_params.headers,
321+
timeout=server_params.timeout,
322+
sse_read_timeout=server_params.sse_read_timeout,
323+
)
324+
read, write = await self._exit_stack.enter_async_context(client)
325+
else:
326+
client = streamablehttp_client(
327+
url=server_params.url,
328+
headers=server_params.headers,
329+
timeout=server_params.timeout,
330+
sse_read_timeout=server_params.sse_read_timeout,
331+
terminate_on_close=server_params.terminate_on_close,
332+
)
333+
read, write, _ = await self._exit_stack.enter_async_context(client)
261334

262-
session = await self._exit_stack.enter_async_context(
263-
mcp.ClientSession(read, write)
264-
)
265-
result = await session.initialize()
266-
return result.serverInfo, session
335+
session = await self._exit_stack.enter_async_context(
336+
mcp.ClientSession(read, write)
337+
)
338+
result = await session.initialize()
339+
340+
# Session successfully initialized.
341+
# Store its stack and register the stack with the main group stack.
342+
self._session_exit_stacks[session] = session_specific_stack
343+
# session_specific_stack itself becomes a resource managed by the
344+
# main _exit_stack.
345+
await self._exit_stack.enter_async_context(session_specific_stack)
346+
347+
return result.serverInfo, session
348+
except Exception:
349+
# If anything during this setup fails, ensure the session-specific
350+
# stack is closed.
351+
await session_specific_stack.aclose()
352+
raise
267353

268354
def _component_name(self, name: str, server_info: types.Implementation) -> str:
269355
if self._component_name_hook:

0 commit comments

Comments
 (0)