Skip to content

Commit d4a28ca

Browse files
authored
Add support for async context management to ClientSessionGroup
This changes enables context management for setting up and tearing down async exit stacks durring server connection and disconnection respectively. Documentation has been added to show an example use case that demonstrates how `ClientSessionGroup` can be used with `async with`.
1 parent ef3feb2 commit d4a28ca

File tree

2 files changed

+191
-92
lines changed

2 files changed

+191
-92
lines changed

src/mcp/client/session_group.py

+173-82
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
"""
1010

1111
import contextlib
12+
import logging
1213
from collections.abc import Callable
1314
from datetime import timedelta
15+
from types import TracebackType
1416
from typing import Any, TypeAlias
1517

18+
import anyio
1619
from pydantic import BaseModel
20+
from typing_extensions import Self
1721

1822
import mcp
1923
from mcp import types
@@ -72,6 +76,14 @@ class ClientSessionGroup:
7276
For auxiliary handlers, such as resource subscription, this is delegated to
7377
the client and can be accessed via the session. For example:
7478
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
79+
80+
Example Usage:
81+
name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
82+
with async ClientSessionGroup(component_name_hook=name_fn) as group:
83+
for server_params in server_params:
84+
group.connect_to_server(server_param)
85+
...
86+
7587
"""
7688

7789
class _ComponentNames(BaseModel):
@@ -90,6 +102,7 @@ class _ComponentNames(BaseModel):
90102
_sessions: dict[mcp.ClientSession, _ComponentNames]
91103
_tool_to_session: dict[str, mcp.ClientSession]
92104
_exit_stack: contextlib.AsyncExitStack
105+
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
93106

94107
# Optional fn consuming (component_name, serverInfo) for custom names.
95108
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +112,7 @@ class _ComponentNames(BaseModel):
99112

100113
def __init__(
101114
self,
102-
exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(),
115+
exit_stack: contextlib.AsyncExitStack | None = None,
103116
component_name_hook: _ComponentNameHook | None = None,
104117
) -> None:
105118
"""Initializes the MCP client."""
@@ -110,9 +123,41 @@ def __init__(
110123

111124
self._sessions = {}
112125
self._tool_to_session = {}
113-
self._exit_stack = exit_stack
126+
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
127+
self._session_exit_stacks = {}
114128
self._component_name_hook = component_name_hook
115129

130+
async def __aenter__(self) -> Self:
131+
# If ClientSessionGroup itself is managing the lifecycle of _exit_stack
132+
# (i.e., it created it), it should enter it.
133+
# If _exit_stack was passed in, it's assumed the caller manages
134+
# its entry/exit.
135+
# For simplicity and consistency with how AsyncExitStack is often used when
136+
# provided as a dependency, we might not need to enter it here if it's
137+
# managed externally. However, if this class is the primary owner, entering it
138+
# ensures its 'aclose' is called even if passed in. Let's assume the
139+
# passed-in stack is already entered by the caller if needed.
140+
# For now, we just return self as the main stack's lifecycle is tied to aclose.
141+
return self
142+
143+
async def __aexit__(
144+
self,
145+
_exc_type: type[BaseException] | None,
146+
_exc_val: BaseException | None,
147+
_exc_tb: TracebackType | None,
148+
) -> bool | None:
149+
"""Closes session exit stacks and main exit stack upon completion."""
150+
async with anyio.create_task_group() as tg:
151+
for exit_stack in self._session_exit_stacks.values():
152+
tg.start_soon(exit_stack.aclose)
153+
await self._exit_stack.aclose()
154+
return None
155+
156+
@property
157+
def sessions(self) -> list[mcp.ClientSession]:
158+
"""Returns the list of sessions being managed."""
159+
return list(self._sessions.keys())
160+
116161
@property
117162
def prompts(self) -> dict[str, types.Prompt]:
118163
"""Returns the prompts as a dictionary of names to prompts."""
@@ -131,33 +176,45 @@ def tools(self) -> dict[str, types.Tool]:
131176
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
132177
"""Executes a tool given its name and arguments."""
133178
session = self._tool_to_session[name]
134-
return await session.call_tool(name, args)
179+
session_tool_name = self.tools[name].name
180+
return await session.call_tool(session_tool_name, args)
135181

136-
def disconnect_from_server(self, session: mcp.ClientSession) -> None:
182+
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
137183
"""Disconnects from a single MCP server."""
138184

139-
if session not in self._sessions:
185+
session_known_for_components = session in self._sessions
186+
session_known_for_stack = session in self._session_exit_stacks
187+
188+
if not session_known_for_components and not session_known_for_stack:
140189
raise McpError(
141190
types.ErrorData(
142191
code=types.INVALID_PARAMS,
143-
message="Provided session is not being managed.",
192+
message="Provided session is not managed or already disconnected.",
144193
)
145194
)
146-
component_names = self._sessions[session]
147-
148-
# Remove prompts associated with the session.
149-
for name in component_names.prompts:
150-
del self._prompts[name]
151195

152-
# Remove resources associated with the session.
153-
for name in component_names.resources:
154-
del self._resources[name]
155-
156-
# Remove tools associated with the session.
157-
for name in component_names.tools:
158-
del self._tools[name]
159-
160-
del self._sessions[session]
196+
if session_known_for_components:
197+
component_names = self._sessions.pop(session) # Pop from _sessions tracking
198+
199+
# Remove prompts associated with the session.
200+
for name in component_names.prompts:
201+
if name in self._prompts:
202+
del self._prompts[name]
203+
# Remove resources associated with the session.
204+
for name in component_names.resources:
205+
if name in self._resources:
206+
del self._resources[name]
207+
# Remove tools associated with the session.
208+
for name in component_names.tools:
209+
if name in self._tools:
210+
del self._tools[name]
211+
if name in self._tool_to_session:
212+
del self._tool_to_session[name]
213+
214+
# Clean up the session's resources via its dedicated exit stack
215+
if session_known_for_stack:
216+
session_stack_to_close = self._session_exit_stacks.pop(session)
217+
await session_stack_to_close.aclose()
161218

162219
async def connect_to_server(
163220
self,
@@ -181,47 +238,66 @@ async def connect_to_server(
181238
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
182239

183240
# Query the server for its prompts and aggregate to list.
184-
prompts = (await session.list_prompts()).prompts
185-
for prompt in prompts:
186-
name = self._component_name(prompt.name, server_info)
187-
if name in self._prompts:
188-
raise McpError(
189-
types.ErrorData(
190-
code=types.INVALID_PARAMS,
191-
message=f"{name} already exists in group prompts.",
192-
)
193-
)
194-
prompts_temp[name] = prompt
195-
component_names.prompts.add(name)
241+
try:
242+
prompts = (await session.list_prompts()).prompts
243+
for prompt in prompts:
244+
name = self._component_name(prompt.name, server_info)
245+
prompts_temp[name] = prompt
246+
component_names.prompts.add(name)
247+
except McpError as err:
248+
logging.warning(f"Could not fetch prompts: {err}")
196249

197250
# Query the server for its resources and aggregate to list.
198-
resources = (await session.list_resources()).resources
199-
for resource in resources:
200-
name = self._component_name(resource.name, server_info)
201-
if name in self._resources:
202-
raise McpError(
203-
types.ErrorData(
204-
code=types.INVALID_PARAMS,
205-
message=f"{name} already exists in group resources.",
206-
)
207-
)
208-
resources_temp[name] = resource
209-
component_names.resources.add(name)
251+
try:
252+
resources = (await session.list_resources()).resources
253+
for resource in resources:
254+
name = self._component_name(resource.name, server_info)
255+
resources_temp[name] = resource
256+
component_names.resources.add(name)
257+
except McpError as err:
258+
logging.warning(f"Could not fetch resources: {err}")
210259

211260
# Query the server for its tools and aggregate to list.
212-
tools = (await session.list_tools()).tools
213-
for tool in tools:
214-
name = self._component_name(tool.name, server_info)
215-
if name in self._tools:
216-
raise McpError(
217-
types.ErrorData(
218-
code=types.INVALID_PARAMS,
219-
message=f"{name} already exists in group tools.",
220-
)
261+
try:
262+
tools = (await session.list_tools()).tools
263+
for tool in tools:
264+
name = self._component_name(tool.name, server_info)
265+
tools_temp[name] = tool
266+
tool_to_session_temp[name] = session
267+
component_names.tools.add(name)
268+
except McpError as err:
269+
logging.warning(f"Could not fetch tools: {err}")
270+
271+
# Clean up exit stack for session if we couldn't retrieve anything
272+
# from the server.
273+
if not any((prompts_temp, resources_temp, tools_temp)):
274+
del self._session_exit_stacks[session]
275+
276+
# Check for duplicates.
277+
matching_prompts = prompts_temp.keys() & self._prompts.keys()
278+
if matching_prompts:
279+
raise McpError(
280+
types.ErrorData(
281+
code=types.INVALID_PARAMS,
282+
message=f"{matching_prompts} already exist in group prompts.",
221283
)
222-
tools_temp[name] = tool
223-
tool_to_session_temp[name] = session
224-
component_names.tools.add(name)
284+
)
285+
matching_resources = resources_temp.keys() & self._resources.keys()
286+
if matching_resources:
287+
raise McpError(
288+
types.ErrorData(
289+
code=types.INVALID_PARAMS,
290+
message=f"{matching_resources} already exist in group resources.",
291+
)
292+
)
293+
matching_tools = tools_temp.keys() & self._tools.keys()
294+
if matching_tools:
295+
raise McpError(
296+
types.ErrorData(
297+
code=types.INVALID_PARAMS,
298+
message=f"{matching_tools} already exist in group tools.",
299+
)
300+
)
225301

226302
# Aggregate components.
227303
self._sessions[session] = component_names
@@ -237,33 +313,48 @@ async def _establish_session(
237313
) -> tuple[types.Implementation, mcp.ClientSession]:
238314
"""Establish a client session to an MCP server."""
239315

240-
# Create read and write streams that facilitate io with the server.
241-
if isinstance(server_params, StdioServerParameters):
242-
client = mcp.stdio_client(server_params)
243-
read, write = await self._exit_stack.enter_async_context(client)
244-
elif isinstance(server_params, SseServerParameters):
245-
client = sse_client(
246-
url=server_params.url,
247-
headers=server_params.headers,
248-
timeout=server_params.timeout,
249-
sse_read_timeout=server_params.sse_read_timeout,
250-
)
251-
read, write = await self._exit_stack.enter_async_context(client)
252-
else:
253-
client = streamablehttp_client(
254-
url=server_params.url,
255-
headers=server_params.headers,
256-
timeout=server_params.timeout,
257-
sse_read_timeout=server_params.sse_read_timeout,
258-
terminate_on_close=server_params.terminate_on_close,
259-
)
260-
read, write, _ = await self._exit_stack.enter_async_context(client)
316+
session_specific_stack = contextlib.AsyncExitStack()
317+
try:
318+
# Create read and write streams that facilitate io with the server.
319+
if isinstance(server_params, StdioServerParameters):
320+
client = mcp.stdio_client(server_params)
321+
read, write = await self._exit_stack.enter_async_context(client)
322+
elif isinstance(server_params, SseServerParameters):
323+
client = sse_client(
324+
url=server_params.url,
325+
headers=server_params.headers,
326+
timeout=server_params.timeout,
327+
sse_read_timeout=server_params.sse_read_timeout,
328+
)
329+
read, write = await self._exit_stack.enter_async_context(client)
330+
else:
331+
client = streamablehttp_client(
332+
url=server_params.url,
333+
headers=server_params.headers,
334+
timeout=server_params.timeout,
335+
sse_read_timeout=server_params.sse_read_timeout,
336+
terminate_on_close=server_params.terminate_on_close,
337+
)
338+
read, write, _ = await self._exit_stack.enter_async_context(client)
261339

262-
session = await self._exit_stack.enter_async_context(
263-
mcp.ClientSession(read, write)
264-
)
265-
result = await session.initialize()
266-
return result.serverInfo, session
340+
session = await self._exit_stack.enter_async_context(
341+
mcp.ClientSession(read, write)
342+
)
343+
result = await session.initialize()
344+
345+
# Session successfully initialized.
346+
# Store its stack and register the stack with the main group stack.
347+
self._session_exit_stacks[session] = session_specific_stack
348+
# session_specific_stack itself becomes a resource managed by the
349+
# main _exit_stack.
350+
await self._exit_stack.enter_async_context(session_specific_stack)
351+
352+
return result.serverInfo, session
353+
except Exception:
354+
# If anything during this setup fails, ensure the session-specific
355+
# stack is closed.
356+
await session_specific_stack.aclose()
357+
raise
267358

268359
def _component_name(self, name: str, server_info: types.Implementation) -> str:
269360
if self._component_name_hook:

0 commit comments

Comments
 (0)