|
1 | 1 | import contextlib
|
2 | 2 | import logging
|
3 |
| -from http import HTTPStatus |
4 |
| -from uuid import uuid4 |
| 3 | +from collections.abc import AsyncIterator |
5 | 4 |
|
6 | 5 | import anyio
|
7 | 6 | import click
|
8 | 7 | import mcp.types as types
|
9 | 8 | from mcp.server.lowlevel import Server
|
10 |
| -from mcp.server.streamable_http import ( |
11 |
| - MCP_SESSION_ID_HEADER, |
12 |
| - StreamableHTTPServerTransport, |
13 |
| -) |
| 9 | +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager |
14 | 10 | from pydantic import AnyUrl
|
15 | 11 | from starlette.applications import Starlette
|
16 |
| -from starlette.requests import Request |
17 |
| -from starlette.responses import Response |
18 | 12 | from starlette.routing import Mount
|
| 13 | +from starlette.types import Receive, Scope, Send |
19 | 14 |
|
20 | 15 | from .event_store import InMemoryEventStore
|
21 | 16 |
|
22 | 17 | # Configure logging
|
23 | 18 | logger = logging.getLogger(__name__)
|
24 | 19 |
|
25 |
| -# Global task group that will be initialized in the lifespan |
26 |
| -task_group = None |
27 |
| - |
28 |
| -# Event store for resumability |
29 |
| -# The InMemoryEventStore enables resumability support for StreamableHTTP transport. |
30 |
| -# It stores SSE events with unique IDs, allowing clients to: |
31 |
| -# 1. Receive event IDs for each SSE message |
32 |
| -# 2. Resume streams by sending Last-Event-ID in GET requests |
33 |
| -# 3. Replay missed events after reconnection |
34 |
| -# Note: This in-memory implementation is for demonstration ONLY. |
35 |
| -# For production, use a persistent storage solution. |
36 |
| -event_store = InMemoryEventStore() |
37 |
| - |
38 |
| - |
39 |
| -@contextlib.asynccontextmanager |
40 |
| -async def lifespan(app): |
41 |
| - """Application lifespan context manager for managing task group.""" |
42 |
| - global task_group |
43 |
| - |
44 |
| - async with anyio.create_task_group() as tg: |
45 |
| - task_group = tg |
46 |
| - logger.info("Application started, task group initialized!") |
47 |
| - try: |
48 |
| - yield |
49 |
| - finally: |
50 |
| - logger.info("Application shutting down, cleaning up resources...") |
51 |
| - if task_group: |
52 |
| - tg.cancel_scope.cancel() |
53 |
| - task_group = None |
54 |
| - logger.info("Resources cleaned up successfully.") |
55 |
| - |
56 | 20 |
|
57 | 21 | @click.command()
|
58 | 22 | @click.option("--port", default=3000, help="Port to listen on for HTTP")
|
@@ -156,60 +120,38 @@ async def list_tools() -> list[types.Tool]:
|
156 | 120 | )
|
157 | 121 | ]
|
158 | 122 |
|
159 |
| - # We need to store the server instances between requests |
160 |
| - server_instances = {} |
161 |
| - # Lock to prevent race conditions when creating new sessions |
162 |
| - session_creation_lock = anyio.Lock() |
| 123 | + # Create event store for resumability |
| 124 | + # The InMemoryEventStore enables resumability support for StreamableHTTP transport. |
| 125 | + # It stores SSE events with unique IDs, allowing clients to: |
| 126 | + # 1. Receive event IDs for each SSE message |
| 127 | + # 2. Resume streams by sending Last-Event-ID in GET requests |
| 128 | + # 3. Replay missed events after reconnection |
| 129 | + # Note: This in-memory implementation is for demonstration ONLY. |
| 130 | + # For production, use a persistent storage solution. |
| 131 | + event_store = InMemoryEventStore() |
| 132 | + |
| 133 | + # Create the session manager with our app and event store |
| 134 | + session_manager = StreamableHTTPSessionManager( |
| 135 | + app=app, |
| 136 | + event_store=event_store, # Enable resumability |
| 137 | + json_response=json_response, |
| 138 | + ) |
163 | 139 |
|
164 | 140 | # ASGI handler for streamable HTTP connections
|
165 |
| - async def handle_streamable_http(scope, receive, send): |
166 |
| - request = Request(scope, receive) |
167 |
| - request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER) |
168 |
| - if ( |
169 |
| - request_mcp_session_id is not None |
170 |
| - and request_mcp_session_id in server_instances |
171 |
| - ): |
172 |
| - transport = server_instances[request_mcp_session_id] |
173 |
| - logger.debug("Session already exists, handling request directly") |
174 |
| - await transport.handle_request(scope, receive, send) |
175 |
| - elif request_mcp_session_id is None: |
176 |
| - # try to establish new session |
177 |
| - logger.debug("Creating new transport") |
178 |
| - # Use lock to prevent race conditions when creating new sessions |
179 |
| - async with session_creation_lock: |
180 |
| - new_session_id = uuid4().hex |
181 |
| - http_transport = StreamableHTTPServerTransport( |
182 |
| - mcp_session_id=new_session_id, |
183 |
| - is_json_response_enabled=json_response, |
184 |
| - event_store=event_store, # Enable resumability |
185 |
| - ) |
186 |
| - server_instances[http_transport.mcp_session_id] = http_transport |
187 |
| - logger.info(f"Created new transport with session ID: {new_session_id}") |
188 |
| - |
189 |
| - async def run_server(task_status=None): |
190 |
| - async with http_transport.connect() as streams: |
191 |
| - read_stream, write_stream = streams |
192 |
| - if task_status: |
193 |
| - task_status.started() |
194 |
| - await app.run( |
195 |
| - read_stream, |
196 |
| - write_stream, |
197 |
| - app.create_initialization_options(), |
198 |
| - ) |
199 |
| - |
200 |
| - if not task_group: |
201 |
| - raise RuntimeError("Task group is not initialized") |
202 |
| - |
203 |
| - await task_group.start(run_server) |
204 |
| - |
205 |
| - # Handle the HTTP request and return the response |
206 |
| - await http_transport.handle_request(scope, receive, send) |
207 |
| - else: |
208 |
| - response = Response( |
209 |
| - "Bad Request: No valid session ID provided", |
210 |
| - status_code=HTTPStatus.BAD_REQUEST, |
211 |
| - ) |
212 |
| - await response(scope, receive, send) |
| 141 | + async def handle_streamable_http( |
| 142 | + scope: Scope, receive: Receive, send: Send |
| 143 | + ) -> None: |
| 144 | + await session_manager.handle_request(scope, receive, send) |
| 145 | + |
| 146 | + @contextlib.asynccontextmanager |
| 147 | + async def lifespan(app: Starlette) -> AsyncIterator[None]: |
| 148 | + """Context manager for managing session manager lifecycle.""" |
| 149 | + async with session_manager.run(): |
| 150 | + logger.info("Application started with StreamableHTTP session manager!") |
| 151 | + try: |
| 152 | + yield |
| 153 | + finally: |
| 154 | + logger.info("Application shutting down...") |
213 | 155 |
|
214 | 156 | # Create an ASGI application using the transport
|
215 | 157 | starlette_app = Starlette(
|
|
0 commit comments