|
9 | 9 | from mcp.server.lowlevel.server import Server |
10 | 10 | from mcp.shared.exceptions import McpError |
11 | 11 | from mcp.shared.memory import create_connected_server_and_client_session |
| 12 | +from mcp.shared.message import SessionMessage |
12 | 13 | from mcp.types import ( |
| 14 | + LATEST_PROTOCOL_VERSION, |
13 | 15 | CallToolRequest, |
14 | 16 | CallToolRequestParams, |
15 | 17 | CallToolResult, |
16 | 18 | CancelledNotification, |
17 | 19 | CancelledNotificationParams, |
| 20 | + ClientCapabilities, |
18 | 21 | ClientNotification, |
19 | 22 | ClientRequest, |
| 23 | + Implementation, |
| 24 | + InitializeRequestParams, |
| 25 | + JSONRPCNotification, |
| 26 | + JSONRPCRequest, |
20 | 27 | Tool, |
21 | 28 | ) |
22 | 29 |
|
@@ -108,3 +115,156 @@ async def first_request(): |
108 | 115 | assert isinstance(content, types.TextContent) |
109 | 116 | assert content.text == "Call number: 2" |
110 | 117 | assert call_count == 2 |
| 118 | + |
| 119 | + |
@pytest.mark.anyio
async def test_server_cancels_in_flight_handlers_on_transport_close():
    """Transport closing mid-request must make server.run() cancel in-flight
    handlers rather than join on them.

    Without the cancel, the task group waits for the handler, which then tries
    to respond through a write stream that _receive_loop already closed,
    raising ClosedResourceError and crashing server.run() with exit code 1.

    This drives server.run() with raw memory streams because InMemoryTransport
    wraps it in its own finally-cancel (_memory.py) which masks the bug.
    """
    started = anyio.Event()
    cancelled = anyio.Event()
    run_finished = anyio.Event()

    server = Server("test")

    @server.call_tool()
    async def slow_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]:
        started.set()
        try:
            await anyio.sleep_forever()
        finally:
            cancelled.set()
        # sleep_forever only ever exits via cancellation, so this is dead code.
        raise AssertionError  # pragma: no cover

    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

    async def drive_server():
        await server.run(server_read, server_write, server.create_initialization_options())
        run_finished.set()

    def wrap(msg) -> SessionMessage:
        return SessionMessage(message=types.JSONRPCMessage(msg))

    initialize_request = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=InitializeRequestParams(
            protocolVersion=LATEST_PROTOCOL_VERSION,
            capabilities=ClientCapabilities(),
            clientInfo=Implementation(name="test", version="1.0"),
        ).model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized_notification = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")
    tool_call = JSONRPCRequest(
        jsonrpc="2.0",
        id=2,
        method="tools/call",
        params=CallToolRequestParams(name="slow", arguments={}).model_dump(by_alias=True, mode="json"),
    )

    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
            tg.start_soon(drive_server)

            await to_server.send(wrap(initialize_request))
            await from_server.receive()  # init response
            await to_server.send(wrap(initialized_notification))
            await to_server.send(wrap(tool_call))

            await started.wait()

            # Close the server's input stream — this is what stdin EOF does.
            # server.run()'s incoming_messages loop ends, finally-cancel fires,
            # handler gets CancelledError, server.run() returns.
            await to_server.aclose()

            await run_finished.wait()

            assert cancelled.is_set()
| 192 | + |
| 193 | + |
@pytest.mark.anyio
async def test_server_handles_transport_close_with_pending_server_to_client_requests():
    """server.run() must still exit cleanly when the transport closes while
    handlers are blocked on server→client requests (sampling, roots, elicitation).

    Two bugs covered:
    1. _receive_loop's finally iterates _response_streams with await checkpoints
       inside; the woken handler's send_request finally pops from that dict
       before the next __next__() — RuntimeError: dictionary changed size.
    2. The woken handler's MCPError is caught in _handle_request, which falls
       through to respond() against a write stream _receive_loop already closed.
    """
    started_count = 0
    both_running = anyio.Event()
    run_finished = anyio.Event()

    server = Server("test")

    @server.call_tool()
    async def blocking_tool(name: str, arguments: dict[str, Any] | None) -> list[types.TextContent]:
        nonlocal started_count
        started_count += 1
        if started_count == 2:
            both_running.set()
        # Blocks on send_request waiting for a client response that never comes.
        # _receive_loop's finally will wake this with CONNECTION_CLOSED.
        await server.request_context.session.list_roots()
        raise AssertionError  # pragma: no cover

    to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10)
    server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10)

    async def drive_server():
        await server.run(server_read, server_write, server.create_initialization_options())
        run_finished.set()

    def wrap(msg) -> SessionMessage:
        return SessionMessage(message=types.JSONRPCMessage(msg))

    initialize_request = JSONRPCRequest(
        jsonrpc="2.0",
        id=1,
        method="initialize",
        params=InitializeRequestParams(
            protocolVersion=LATEST_PROTOCOL_VERSION,
            capabilities=ClientCapabilities(),
            clientInfo=Implementation(name="test", version="1.0"),
        ).model_dump(by_alias=True, mode="json", exclude_none=True),
    )
    initialized_notification = JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")

    with anyio.fail_after(5):
        async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server:
            tg.start_soon(drive_server)

            await to_server.send(wrap(initialize_request))
            await from_server.receive()  # init response
            await to_server.send(wrap(initialized_notification))

            # Two tool calls → two handlers → two _response_streams entries.
            for request_id in (2, 3):
                await to_server.send(
                    wrap(
                        JSONRPCRequest(
                            jsonrpc="2.0",
                            id=request_id,
                            method="tools/call",
                            params=CallToolRequestParams(name="t", arguments={}).model_dump(
                                by_alias=True, mode="json"
                            ),
                        )
                    )
                )

            await both_running.wait()
            # Drain the two roots/list requests so send_request's _write_stream.send()
            # completes and both handlers are parked at response_stream_reader.receive().
            await from_server.receive()
            await from_server.receive()

            await to_server.aclose()

            # Without the fixes: RuntimeError (dict mutation) or ClosedResourceError
            # (respond after write-stream close) escapes run_server and this hangs.
            await run_finished.wait()
0 commit comments