Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions src/strands/tools/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,12 @@ def __init__(
mcp_instrumentation()
self._session_id = uuid.uuid4()
self._log_debug_with_thread("initializing MCPClient connection")
# Main thread blocks until future completesock
# Main thread blocks until future completes
self._init_future: futures.Future[None] = futures.Future()
# Set within the inner loop as it needs the asyncio loop
self._close_future: asyncio.futures.Future[None] | None = None
self._close_exception: None | Exception = None
# Do not want to block other threads while close event is false
self._close_event = asyncio.Event()
self._transport_callable = transport_callable

self._background_thread: threading.Thread | None = None
Expand Down Expand Up @@ -288,11 +290,12 @@ def stop(
- _background_thread: Thread running the async event loop
- _background_thread_session: MCP ClientSession (auto-closed by context manager)
- _background_thread_event_loop: AsyncIO event loop in background thread
- _close_event: AsyncIO event to signal thread shutdown
- _close_future: AsyncIO future to signal thread shutdown
- _close_exception: Exception that caused the background thread shutdown; None if a normal shutdown occurred.
- _init_future: Future for initialization synchronization

Cleanup order:
1. Signal close event to background thread (if session initialized)
1. Signal close future to background thread (if session initialized)
2. Wait for background thread to complete
3. Reset all state for reuse

Expand All @@ -303,25 +306,26 @@ def stop(
"""
self._log_debug_with_thread("exiting MCPClient context")

# Only try to signal close event if we have a background thread
# Only try to signal close future if we have a background thread
if self._background_thread is not None:
# Signal close event if event loop exists
# Signal close future if event loop exists
if self._background_thread_event_loop is not None:

async def _set_close_event() -> None:
self._close_event.set()
if self._close_future and not self._close_future.done():
self._close_future.set_result(None)

# Not calling _invoke_on_background_thread since the session does not need to exist
# we only need the thread and event loop to exist.
asyncio.run_coroutine_threadsafe(coro=_set_close_event(), loop=self._background_thread_event_loop)

self._log_debug_with_thread("waiting for background thread to join")
self._background_thread.join()

self._log_debug_with_thread("background thread is closed, MCPClient context exited")

# Reset fields to allow instance reuse
self._init_future = futures.Future()
self._close_event = asyncio.Event()
self._background_thread = None
self._background_thread_session = None
self._background_thread_event_loop = None
Expand All @@ -330,6 +334,11 @@ async def _set_close_event() -> None:
self._tool_provider_started = False
self._consumers = set()

if self._close_exception:
exception = self._close_exception
self._close_exception = None
raise RuntimeError("Connection to the MCP server was closed") from exception

def list_tools_sync(
self,
pagination_token: str | None = None,
Expand Down Expand Up @@ -563,6 +572,10 @@ async def _async_background_thread(self) -> None:
signals readiness to the main thread, and waits for a close signal.
"""
self._log_debug_with_thread("starting async background thread for MCP connection")

# Initialized here so that it has the asyncio loop
self._close_future = asyncio.Future()

try:
async with self._transport_callable() as (read_stream, write_stream, *_):
self._log_debug_with_thread("transport connection established")
Expand All @@ -583,15 +596,22 @@ async def _async_background_thread(self) -> None:

self._log_debug_with_thread("waiting for close signal")
# Keep background thread running until signaled to close.
# Thread is not blocked as this is an asyncio.Event not a threading.Event
await self._close_event.wait()
# Thread is not blocked as this a future
await self._close_future

self._log_debug_with_thread("close signal received")
except Exception as e:
# If we encounter an exception and the future is still running,
# it means it was encountered during the initialization phase.
if not self._init_future.done():
self._init_future.set_exception(e)
else:
# _close_future is automatically cancelled by the framework which doesn't provide us with the useful
# exception, so instead we store the exception in a different field where stop() can read it
self._close_exception = e
if self._close_future and not self._close_future.done():
self._close_future.set_result(None)

self._log_debug_with_thread(
"encountered exception on background thread after initialization %s", str(e)
)
Expand All @@ -601,7 +621,7 @@ def _background_task(self) -> None:

This method creates a new event loop for the background thread,
sets it as the current event loop, and runs the async_background_thread
coroutine until completion. In this case "until completion" means until the _close_event is set.
coroutine until completion. In this case "until completion" means until the _close_future is resolved.
This allows for a long-running event loop.
"""
self._log_debug_with_thread("setting up background task event loop")
Expand Down Expand Up @@ -699,9 +719,34 @@ def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None:
)

def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]:
if self._background_thread_session is None or self._background_thread_event_loop is None:
# save a reference to this so that even if it's reset we have the original
close_future = self._close_future

if (
self._background_thread_session is None
or self._background_thread_event_loop is None
or close_future is None
):
raise MCPClientInitializationError("the client session was not initialized")
return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop)

async def run_async() -> T:
# Fix for strands-agents/sdk-python/issues/995 - cancel all pending invocations if/when the session closes
invoke_event = asyncio.create_task(coro)
tasks: list[asyncio.Task | asyncio.Future] = [
invoke_event,
close_future,
]

done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)

if done.pop() == close_future:
self._log_debug_with_thread("event loop for the server closed before the invoke completed")
raise RuntimeError("Connection to the MCP server was closed")
else:
return await invoke_event

invoke_future = asyncio.run_coroutine_threadsafe(coro=run_async(), loop=self._background_thread_event_loop)
return invoke_future

def _should_include_tool(self, tool: MCPAgentTool) -> bool:
"""Check if a tool should be included based on constructor filters."""
Expand Down
67 changes: 67 additions & 0 deletions tests_integ/mcp/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,70 @@ def transport_callback() -> MCPTransport:
result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool")
assert result["status"] == "error"
assert result["content"][0]["text"] == "Tool execution failed: Connection closed"


def start_5xx_proxy_for_tool_calls(target_url: str, proxy_port: int):
"""Starts a proxy that throws a 5XX when a tool call is invoked"""
import aiohttp
from aiohttp import web

async def proxy_handler(request):
url = f"{target_url}{request.path_qs}"

async with aiohttp.ClientSession() as session:
data = await request.read()

if "tools/call" in f"{data}":
return web.Response(status=500, text="Internal Server Error")

async with session.request(
method=request.method, url=url, headers=request.headers, data=data, allow_redirects=False
) as resp:
print(f"Got request to {url} {data}")
response = web.StreamResponse(status=resp.status, headers=resp.headers)
await response.prepare(request)

async for chunk in resp.content.iter_chunked(8192):
await response.write(chunk)

return response

app = web.Application()
app.router.add_route("*", "/{path:.*}", proxy_handler)

web.run_app(app, host="127.0.0.1", port=proxy_port)


@pytest.mark.asyncio
async def test_streamable_http_mcp_client_with_500_error():
import asyncio
import multiprocessing

server_thread = threading.Thread(
target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True
)
server_thread.start()

proxy_process = multiprocessing.Process(
target=start_5xx_proxy_for_tool_calls, kwargs={"target_url": "http://127.0.0.1:8001", "proxy_port": 8002}
)
proxy_process.start()

try:
await asyncio.sleep(2) # wait for server to startup completely

def transport_callback() -> MCPTransport:
return streamablehttp_client(url="http://127.0.0.1:8002/mcp")

streamable_http_client = MCPClient(transport_callback)
with pytest.raises(RuntimeError, match="Connection to the MCP server was closed"):
with streamable_http_client:
result = await streamable_http_client.call_tool_async(
tool_use_id="123", name="calculator", arguments={"x": 3, "y": 4}
)
finally:
proxy_process.terminate()
proxy_process.join()

assert result["status"] == "error"
assert result["content"][0]["text"] == "Tool execution failed: Connection to the MCP server was closed"
Loading