9
9
"""
10
10
11
11
import contextlib
12
+ import logging
12
13
from collections .abc import Callable
13
14
from datetime import timedelta
15
+ from types import TracebackType
14
16
from typing import Any , TypeAlias
15
17
18
+ import anyio
16
19
from pydantic import BaseModel
20
+ from typing_extensions import Self
17
21
18
22
import mcp
19
23
from mcp import types
@@ -72,6 +76,14 @@ class ClientSessionGroup:
72
76
For auxiliary handlers, such as resource subscription, this is delegated to
73
77
the client and can be accessed via the session. For example:
74
78
mcp_session_group.get_session("server_name").subscribe_to_resource(...)
79
+
80
+ Example Usage:
81
+ name_fn = lambda name, server_info: f"{(server_info.name)}-{name}"
82
+ with async ClientSessionGroup(component_name_hook=name_fn) as group:
83
+ for server_params in server_params:
84
+ group.connect_to_server(server_param)
85
+ ...
86
+
75
87
"""
76
88
77
89
class _ComponentNames (BaseModel ):
@@ -90,6 +102,7 @@ class _ComponentNames(BaseModel):
90
102
_sessions : dict [mcp .ClientSession , _ComponentNames ]
91
103
_tool_to_session : dict [str , mcp .ClientSession ]
92
104
_exit_stack : contextlib .AsyncExitStack
105
+ _session_exit_stacks : dict [mcp .ClientSession , contextlib .AsyncExitStack ]
93
106
94
107
# Optional fn consuming (component_name, serverInfo) for custom names.
95
108
# This is provide a means to mitigate naming conflicts across servers.
@@ -99,7 +112,7 @@ class _ComponentNames(BaseModel):
99
112
100
113
def __init__ (
101
114
self ,
102
- exit_stack : contextlib .AsyncExitStack = contextlib . AsyncExitStack () ,
115
+ exit_stack : contextlib .AsyncExitStack | None = None ,
103
116
component_name_hook : _ComponentNameHook | None = None ,
104
117
) -> None :
105
118
"""Initializes the MCP client."""
@@ -110,9 +123,41 @@ def __init__(
110
123
111
124
self ._sessions = {}
112
125
self ._tool_to_session = {}
113
- self ._exit_stack = exit_stack
126
+ self ._exit_stack = exit_stack or contextlib .AsyncExitStack ()
127
+ self ._session_exit_stacks = {}
114
128
self ._component_name_hook = component_name_hook
115
129
130
+ async def __aenter__ (self ) -> Self :
131
+ # If ClientSessionGroup itself is managing the lifecycle of _exit_stack
132
+ # (i.e., it created it), it should enter it.
133
+ # If _exit_stack was passed in, it's assumed the caller manages
134
+ # its entry/exit.
135
+ # For simplicity and consistency with how AsyncExitStack is often used when
136
+ # provided as a dependency, we might not need to enter it here if it's
137
+ # managed externally. However, if this class is the primary owner, entering it
138
+ # ensures its 'aclose' is called even if passed in. Let's assume the
139
+ # passed-in stack is already entered by the caller if needed.
140
+ # For now, we just return self as the main stack's lifecycle is tied to aclose.
141
+ return self
142
+
143
+ async def __aexit__ (
144
+ self ,
145
+ _exc_type : type [BaseException ] | None ,
146
+ _exc_val : BaseException | None ,
147
+ _exc_tb : TracebackType | None ,
148
+ ) -> bool | None :
149
+ """Closes session exit stacks and main exit stack upon completion."""
150
+ async with anyio .create_task_group () as tg :
151
+ for exit_stack in self ._session_exit_stacks .values ():
152
+ tg .start_soon (exit_stack .aclose )
153
+ await self ._exit_stack .aclose ()
154
+ return None
155
+
156
+ @property
157
+ def sessions (self ) -> list [mcp .ClientSession ]:
158
+ """Returns the list of sessions being managed."""
159
+ return list (self ._sessions .keys ())
160
+
116
161
@property
117
162
def prompts (self ) -> dict [str , types .Prompt ]:
118
163
"""Returns the prompts as a dictionary of names to prompts."""
@@ -131,33 +176,45 @@ def tools(self) -> dict[str, types.Tool]:
131
176
async def call_tool (self , name : str , args : dict [str , Any ]) -> types .CallToolResult :
132
177
"""Executes a tool given its name and arguments."""
133
178
session = self ._tool_to_session [name ]
134
- return await session .call_tool (name , args )
179
+ session_tool_name = self .tools [name ].name
180
+ return await session .call_tool (session_tool_name , args )
135
181
136
- def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
182
+ async def disconnect_from_server (self , session : mcp .ClientSession ) -> None :
137
183
"""Disconnects from a single MCP server."""
138
184
139
- if session not in self ._sessions :
185
+ session_known_for_components = session in self ._sessions
186
+ session_known_for_stack = session in self ._session_exit_stacks
187
+
188
+ if not session_known_for_components and not session_known_for_stack :
140
189
raise McpError (
141
190
types .ErrorData (
142
191
code = types .INVALID_PARAMS ,
143
- message = "Provided session is not being managed." ,
192
+ message = "Provided session is not managed or already disconnected ." ,
144
193
)
145
194
)
146
- component_names = self ._sessions [session ]
147
-
148
- # Remove prompts associated with the session.
149
- for name in component_names .prompts :
150
- del self ._prompts [name ]
151
195
152
- # Remove resources associated with the session.
153
- for name in component_names .resources :
154
- del self ._resources [name ]
155
-
156
- # Remove tools associated with the session.
157
- for name in component_names .tools :
158
- del self ._tools [name ]
159
-
160
- del self ._sessions [session ]
196
+ if session_known_for_components :
197
+ component_names = self ._sessions .pop (session ) # Pop from _sessions tracking
198
+
199
+ # Remove prompts associated with the session.
200
+ for name in component_names .prompts :
201
+ if name in self ._prompts :
202
+ del self ._prompts [name ]
203
+ # Remove resources associated with the session.
204
+ for name in component_names .resources :
205
+ if name in self ._resources :
206
+ del self ._resources [name ]
207
+ # Remove tools associated with the session.
208
+ for name in component_names .tools :
209
+ if name in self ._tools :
210
+ del self ._tools [name ]
211
+ if name in self ._tool_to_session :
212
+ del self ._tool_to_session [name ]
213
+
214
+ # Clean up the session's resources via its dedicated exit stack
215
+ if session_known_for_stack :
216
+ session_stack_to_close = self ._session_exit_stacks .pop (session )
217
+ await session_stack_to_close .aclose ()
161
218
162
219
async def connect_to_server (
163
220
self ,
@@ -181,47 +238,66 @@ async def connect_to_server(
181
238
tool_to_session_temp : dict [str , mcp .ClientSession ] = {}
182
239
183
240
# Query the server for its prompts and aggregate to list.
184
- prompts = (await session .list_prompts ()).prompts
185
- for prompt in prompts :
186
- name = self ._component_name (prompt .name , server_info )
187
- if name in self ._prompts :
188
- raise McpError (
189
- types .ErrorData (
190
- code = types .INVALID_PARAMS ,
191
- message = f"{ name } already exists in group prompts." ,
192
- )
193
- )
194
- prompts_temp [name ] = prompt
195
- component_names .prompts .add (name )
241
+ try :
242
+ prompts = (await session .list_prompts ()).prompts
243
+ for prompt in prompts :
244
+ name = self ._component_name (prompt .name , server_info )
245
+ prompts_temp [name ] = prompt
246
+ component_names .prompts .add (name )
247
+ except McpError as err :
248
+ logging .warning (f"Could not fetch prompts: { err } " )
196
249
197
250
# Query the server for its resources and aggregate to list.
198
- resources = (await session .list_resources ()).resources
199
- for resource in resources :
200
- name = self ._component_name (resource .name , server_info )
201
- if name in self ._resources :
202
- raise McpError (
203
- types .ErrorData (
204
- code = types .INVALID_PARAMS ,
205
- message = f"{ name } already exists in group resources." ,
206
- )
207
- )
208
- resources_temp [name ] = resource
209
- component_names .resources .add (name )
251
+ try :
252
+ resources = (await session .list_resources ()).resources
253
+ for resource in resources :
254
+ name = self ._component_name (resource .name , server_info )
255
+ resources_temp [name ] = resource
256
+ component_names .resources .add (name )
257
+ except McpError as err :
258
+ logging .warning (f"Could not fetch resources: { err } " )
210
259
211
260
# Query the server for its tools and aggregate to list.
212
- tools = (await session .list_tools ()).tools
213
- for tool in tools :
214
- name = self ._component_name (tool .name , server_info )
215
- if name in self ._tools :
216
- raise McpError (
217
- types .ErrorData (
218
- code = types .INVALID_PARAMS ,
219
- message = f"{ name } already exists in group tools." ,
220
- )
261
+ try :
262
+ tools = (await session .list_tools ()).tools
263
+ for tool in tools :
264
+ name = self ._component_name (tool .name , server_info )
265
+ tools_temp [name ] = tool
266
+ tool_to_session_temp [name ] = session
267
+ component_names .tools .add (name )
268
+ except McpError as err :
269
+ logging .warning (f"Could not fetch tools: { err } " )
270
+
271
+ # Clean up exit stack for session if we couldn't retrieve anything
272
+ # from the server.
273
+ if not any ((prompts_temp , resources_temp , tools_temp )):
274
+ del self ._session_exit_stacks [session ]
275
+
276
+ # Check for duplicates.
277
+ matching_prompts = prompts_temp .keys () & self ._prompts .keys ()
278
+ if matching_prompts :
279
+ raise McpError (
280
+ types .ErrorData (
281
+ code = types .INVALID_PARAMS ,
282
+ message = f"{ matching_prompts } already exist in group prompts." ,
221
283
)
222
- tools_temp [name ] = tool
223
- tool_to_session_temp [name ] = session
224
- component_names .tools .add (name )
284
+ )
285
+ matching_resources = resources_temp .keys () & self ._resources .keys ()
286
+ if matching_resources :
287
+ raise McpError (
288
+ types .ErrorData (
289
+ code = types .INVALID_PARAMS ,
290
+ message = f"{ matching_resources } already exist in group resources." ,
291
+ )
292
+ )
293
+ matching_tools = tools_temp .keys () & self ._tools .keys ()
294
+ if matching_tools :
295
+ raise McpError (
296
+ types .ErrorData (
297
+ code = types .INVALID_PARAMS ,
298
+ message = f"{ matching_tools } already exist in group tools." ,
299
+ )
300
+ )
225
301
226
302
# Aggregate components.
227
303
self ._sessions [session ] = component_names
@@ -237,33 +313,48 @@ async def _establish_session(
237
313
) -> tuple [types .Implementation , mcp .ClientSession ]:
238
314
"""Establish a client session to an MCP server."""
239
315
240
- # Create read and write streams that facilitate io with the server.
241
- if isinstance (server_params , StdioServerParameters ):
242
- client = mcp .stdio_client (server_params )
243
- read , write = await self ._exit_stack .enter_async_context (client )
244
- elif isinstance (server_params , SseServerParameters ):
245
- client = sse_client (
246
- url = server_params .url ,
247
- headers = server_params .headers ,
248
- timeout = server_params .timeout ,
249
- sse_read_timeout = server_params .sse_read_timeout ,
250
- )
251
- read , write = await self ._exit_stack .enter_async_context (client )
252
- else :
253
- client = streamablehttp_client (
254
- url = server_params .url ,
255
- headers = server_params .headers ,
256
- timeout = server_params .timeout ,
257
- sse_read_timeout = server_params .sse_read_timeout ,
258
- terminate_on_close = server_params .terminate_on_close ,
259
- )
260
- read , write , _ = await self ._exit_stack .enter_async_context (client )
316
+ session_specific_stack = contextlib .AsyncExitStack ()
317
+ try :
318
+ # Create read and write streams that facilitate io with the server.
319
+ if isinstance (server_params , StdioServerParameters ):
320
+ client = mcp .stdio_client (server_params )
321
+ read , write = await self ._exit_stack .enter_async_context (client )
322
+ elif isinstance (server_params , SseServerParameters ):
323
+ client = sse_client (
324
+ url = server_params .url ,
325
+ headers = server_params .headers ,
326
+ timeout = server_params .timeout ,
327
+ sse_read_timeout = server_params .sse_read_timeout ,
328
+ )
329
+ read , write = await self ._exit_stack .enter_async_context (client )
330
+ else :
331
+ client = streamablehttp_client (
332
+ url = server_params .url ,
333
+ headers = server_params .headers ,
334
+ timeout = server_params .timeout ,
335
+ sse_read_timeout = server_params .sse_read_timeout ,
336
+ terminate_on_close = server_params .terminate_on_close ,
337
+ )
338
+ read , write , _ = await self ._exit_stack .enter_async_context (client )
261
339
262
- session = await self ._exit_stack .enter_async_context (
263
- mcp .ClientSession (read , write )
264
- )
265
- result = await session .initialize ()
266
- return result .serverInfo , session
340
+ session = await self ._exit_stack .enter_async_context (
341
+ mcp .ClientSession (read , write )
342
+ )
343
+ result = await session .initialize ()
344
+
345
+ # Session successfully initialized.
346
+ # Store its stack and register the stack with the main group stack.
347
+ self ._session_exit_stacks [session ] = session_specific_stack
348
+ # session_specific_stack itself becomes a resource managed by the
349
+ # main _exit_stack.
350
+ await self ._exit_stack .enter_async_context (session_specific_stack )
351
+
352
+ return result .serverInfo , session
353
+ except Exception :
354
+ # If anything during this setup fails, ensure the session-specific
355
+ # stack is closed.
356
+ await session_specific_stack .aclose ()
357
+ raise
267
358
268
359
def _component_name (self , name : str , server_info : types .Implementation ) -> str :
269
360
if self ._component_name_hook :
0 commit comments