Skip to content

Commit fcbdd79

Browse files
committed
Add support for async context management to ClientSessionGroup
This changes enables context management for setting up and tearing down async exit stacks durring server connection and disconnection respectively. Documentation has been added to show an example use case that demonstrates how `ClientSessionGroup` can be used with `async with`.
1 parent ef3feb2 commit fcbdd79

File tree

2 files changed

+181
-92
lines changed

2 files changed

+181
-92
lines changed

src/mcp/client/session_group.py

+163-82
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
"""
1010

1111
import contextlib
12+
import logging
1213
from collections.abc import Callable
1314
from datetime import timedelta
15+
from types import TracebackType
1416
from typing import Any, TypeAlias
1517

1618
from pydantic import BaseModel
19+
from typing_extensions import Self
1720

1821
import mcp
1922
from mcp import types
@@ -72,6 +75,14 @@ class ClientSessionGroup:
7275
For auxiliary handlers, such as resource subscription, this is delegated to
7376
the client and can be accessed via the session. For example:
7477
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
78+
79+
Example Usage:
80+
name_fn = lambda name, server_info: f"{(server_info.name)}.{name}"
81+
with async ClientSessionGroup(component_name_hook=name_fn) as group:
82+
for server_params in server_params:
83+
group.connect_to_server(server_param)
84+
...
85+
7586
"""
7687

7788
class _ComponentNames(BaseModel):
@@ -90,6 +101,7 @@ class _ComponentNames(BaseModel):
90101
_sessions: dict[mcp.ClientSession, _ComponentNames]
91102
_tool_to_session: dict[str, mcp.ClientSession]
92103
_exit_stack: contextlib.AsyncExitStack
104+
_session_exit_stacks: dict[mcp.ClientSession, contextlib.AsyncExitStack]
93105

94106
# Optional fn consuming (component_name, serverInfo) for custom names.
95107
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +111,7 @@ class _ComponentNames(BaseModel):
99111

100112
def __init__(
101113
self,
102-
exit_stack: contextlib.AsyncExitStack = contextlib.AsyncExitStack(),
114+
exit_stack: contextlib.AsyncExitStack | None = None,
103115
component_name_hook: _ComponentNameHook | None = None,
104116
) -> None:
105117
"""Initializes the MCP client."""
@@ -110,9 +122,32 @@ def __init__(
110122

111123
self._sessions = {}
112124
self._tool_to_session = {}
113-
self._exit_stack = exit_stack
125+
self._exit_stack = exit_stack or contextlib.AsyncExitStack()
126+
self._session_exit_stacks = {}
114127
self._component_name_hook = component_name_hook
115128

129+
async def __aenter__(self) -> Self:
130+
# If ClientSessionGroup itself is managing the lifecycle of _exit_stack
131+
# (i.e., it created it), it should enter it.
132+
# If _exit_stack was passed in, it's assumed the caller manages
133+
# its entry/exit.
134+
# For simplicity and consistency with how AsyncExitStack is often used when
135+
# provided as a dependency, we might not need to enter it here if it's
136+
# managed externally. However, if this class is the primary owner, entering it
137+
# ensures its 'aclose' is called even if passed in. Let's assume the
138+
# passed-in stack is already entered by the caller if needed.
139+
# For now, we just return self as the main stack's lifecycle is tied to aclose.
140+
return self
141+
142+
async def __aexit__(
143+
self,
144+
_exc_type: type[BaseException] | None,
145+
_exc_val: BaseException | None,
146+
_exc_tb: TracebackType | None,
147+
) -> bool | None:
148+
await self._exit_stack.aclose()
149+
return None # Do not suppress exceptio
150+
116151
@property
117152
def prompts(self) -> dict[str, types.Prompt]:
118153
"""Returns the prompts as a dictionary of names to prompts."""
@@ -131,33 +166,45 @@ def tools(self) -> dict[str, types.Tool]:
131166
async def call_tool(self, name: str, args: dict[str, Any]) -> types.CallToolResult:
132167
"""Executes a tool given its name and arguments."""
133168
session = self._tool_to_session[name]
134-
return await session.call_tool(name, args)
169+
session_tool_name = self.tools[name].name
170+
return await session.call_tool(session_tool_name, args)
135171

136-
def disconnect_from_server(self, session: mcp.ClientSession) -> None:
172+
async def disconnect_from_server(self, session: mcp.ClientSession) -> None:
137173
"""Disconnects from a single MCP server."""
138174

139-
if session not in self._sessions:
175+
session_known_for_components = session in self._sessions
176+
session_known_for_stack = session in self._session_exit_stacks
177+
178+
if not session_known_for_components and not session_known_for_stack:
140179
raise McpError(
141180
types.ErrorData(
142181
code=types.INVALID_PARAMS,
143-
message="Provided session is not being managed.",
182+
message="Provided session is not managed or already disconnected.",
144183
)
145184
)
146-
component_names = self._sessions[session]
147-
148-
# Remove prompts associated with the session.
149-
for name in component_names.prompts:
150-
del self._prompts[name]
151-
152-
# Remove resources associated with the session.
153-
for name in component_names.resources:
154-
del self._resources[name]
155-
156-
# Remove tools associated with the session.
157-
for name in component_names.tools:
158-
del self._tools[name]
159185

160-
del self._sessions[session]
186+
if session_known_for_components:
187+
component_names = self._sessions.pop(session) # Pop from _sessions tracking
188+
189+
# Remove prompts associated with the session.
190+
for name in component_names.prompts:
191+
if name in self._prompts:
192+
del self._prompts[name]
193+
# Remove resources associated with the session.
194+
for name in component_names.resources:
195+
if name in self._resources:
196+
del self._resources[name]
197+
# Remove tools associated with the session.
198+
for name in component_names.tools:
199+
if name in self._tools:
200+
del self._tools[name]
201+
if name in self._tool_to_session:
202+
del self._tool_to_session[name]
203+
204+
# Clean up the session's resources via its dedicated exit stack
205+
if session_known_for_stack:
206+
session_stack_to_close = self._session_exit_stacks.pop(session)
207+
await session_stack_to_close.aclose()
161208

162209
async def connect_to_server(
163210
self,
@@ -181,47 +228,66 @@ async def connect_to_server(
181228
tool_to_session_temp: dict[str, mcp.ClientSession] = {}
182229

183230
# Query the server for its prompts and aggregate to list.
184-
prompts = (await session.list_prompts()).prompts
185-
for prompt in prompts:
186-
name = self._component_name(prompt.name, server_info)
187-
if name in self._prompts:
188-
raise McpError(
189-
types.ErrorData(
190-
code=types.INVALID_PARAMS,
191-
message=f"{name} already exists in group prompts.",
192-
)
193-
)
194-
prompts_temp[name] = prompt
195-
component_names.prompts.add(name)
231+
try:
232+
prompts = (await session.list_prompts()).prompts
233+
for prompt in prompts:
234+
name = self._component_name(prompt.name, server_info)
235+
prompts_temp[name] = prompt
236+
component_names.prompts.add(name)
237+
except McpError as err:
238+
logging.warning(f"Could not fetch prompts: {err}")
196239

197240
# Query the server for its resources and aggregate to list.
198-
resources = (await session.list_resources()).resources
199-
for resource in resources:
200-
name = self._component_name(resource.name, server_info)
201-
if name in self._resources:
202-
raise McpError(
203-
types.ErrorData(
204-
code=types.INVALID_PARAMS,
205-
message=f"{name} already exists in group resources.",
206-
)
207-
)
208-
resources_temp[name] = resource
209-
component_names.resources.add(name)
241+
try:
242+
resources = (await session.list_resources()).resources
243+
for resource in resources:
244+
name = self._component_name(resource.name, server_info)
245+
resources_temp[name] = resource
246+
component_names.resources.add(name)
247+
except McpError as err:
248+
logging.warning(f"Could not fetch resources: {err}")
210249

211250
# Query the server for its tools and aggregate to list.
212-
tools = (await session.list_tools()).tools
213-
for tool in tools:
214-
name = self._component_name(tool.name, server_info)
215-
if name in self._tools:
216-
raise McpError(
217-
types.ErrorData(
218-
code=types.INVALID_PARAMS,
219-
message=f"{name} already exists in group tools.",
220-
)
251+
try:
252+
tools = (await session.list_tools()).tools
253+
for tool in tools:
254+
name = self._component_name(tool.name, server_info)
255+
tools_temp[name] = tool
256+
tool_to_session_temp[name] = session
257+
component_names.tools.add(name)
258+
except McpError as err:
259+
logging.warning(f"Could not fetch tools: {err}")
260+
261+
# Clean up exit stack for session if we couldn't retrieve anything
262+
# from the server.
263+
if not any((prompts_temp, resources_temp, tools_temp)):
264+
del self._session_exit_stacks[session]
265+
266+
# Check for duplicates.
267+
matching_prompts = prompts_temp.keys() & self._prompts.keys()
268+
if matching_prompts:
269+
raise McpError(
270+
types.ErrorData(
271+
code=types.INVALID_PARAMS,
272+
message=f"{matching_prompts} already exist in group prompts.",
221273
)
222-
tools_temp[name] = tool
223-
tool_to_session_temp[name] = session
224-
component_names.tools.add(name)
274+
)
275+
matching_resources = resources_temp.keys() & self._resources.keys()
276+
if matching_resources:
277+
raise McpError(
278+
types.ErrorData(
279+
code=types.INVALID_PARAMS,
280+
message=f"{matching_resources} already exist in group resources.",
281+
)
282+
)
283+
matching_tools = tools_temp.keys() & self._tools.keys()
284+
if matching_tools:
285+
raise McpError(
286+
types.ErrorData(
287+
code=types.INVALID_PARAMS,
288+
message=f"{matching_tools} already exist in group tools.",
289+
)
290+
)
225291

226292
# Aggregate components.
227293
self._sessions[session] = component_names
@@ -237,33 +303,48 @@ async def _establish_session(
237303
) -> tuple[types.Implementation, mcp.ClientSession]:
238304
"""Establish a client session to an MCP server."""
239305

240-
# Create read and write streams that facilitate io with the server.
241-
if isinstance(server_params, StdioServerParameters):
242-
client = mcp.stdio_client(server_params)
243-
read, write = await self._exit_stack.enter_async_context(client)
244-
elif isinstance(server_params, SseServerParameters):
245-
client = sse_client(
246-
url=server_params.url,
247-
headers=server_params.headers,
248-
timeout=server_params.timeout,
249-
sse_read_timeout=server_params.sse_read_timeout,
250-
)
251-
read, write = await self._exit_stack.enter_async_context(client)
252-
else:
253-
client = streamablehttp_client(
254-
url=server_params.url,
255-
headers=server_params.headers,
256-
timeout=server_params.timeout,
257-
sse_read_timeout=server_params.sse_read_timeout,
258-
terminate_on_close=server_params.terminate_on_close,
259-
)
260-
read, write, _ = await self._exit_stack.enter_async_context(client)
306+
session_specific_stack = contextlib.AsyncExitStack()
307+
try:
308+
# Create read and write streams that facilitate io with the server.
309+
if isinstance(server_params, StdioServerParameters):
310+
client = mcp.stdio_client(server_params)
311+
read, write = await self._exit_stack.enter_async_context(client)
312+
elif isinstance(server_params, SseServerParameters):
313+
client = sse_client(
314+
url=server_params.url,
315+
headers=server_params.headers,
316+
timeout=server_params.timeout,
317+
sse_read_timeout=server_params.sse_read_timeout,
318+
)
319+
read, write = await self._exit_stack.enter_async_context(client)
320+
else:
321+
client = streamablehttp_client(
322+
url=server_params.url,
323+
headers=server_params.headers,
324+
timeout=server_params.timeout,
325+
sse_read_timeout=server_params.sse_read_timeout,
326+
terminate_on_close=server_params.terminate_on_close,
327+
)
328+
read, write, _ = await self._exit_stack.enter_async_context(client)
261329

262-
session = await self._exit_stack.enter_async_context(
263-
mcp.ClientSession(read, write)
264-
)
265-
result = await session.initialize()
266-
return result.serverInfo, session
330+
session = await self._exit_stack.enter_async_context(
331+
mcp.ClientSession(read, write)
332+
)
333+
result = await session.initialize()
334+
335+
# Session successfully initialized.
336+
# Store its stack and register the stack with the main group stack.
337+
self._session_exit_stacks[session] = session_specific_stack
338+
# session_specific_stack itself becomes a resource managed by the
339+
# main _exit_stack.
340+
await self._exit_stack.enter_async_context(session_specific_stack)
341+
342+
return result.serverInfo, session
343+
except Exception:
344+
# If anything during this setup fails, ensure the session-specific
345+
# stack is closed.
346+
await session_specific_stack.aclose()
347+
raise
267348

268349
def _component_name(self, name: str, server_info: types.Implementation) -> str:
269350
if self._component_name_hook:

0 commit comments

Comments
 (0)