-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathserver.py
205 lines (182 loc) · 7.06 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import contextlib
import logging
from http import HTTPStatus
from uuid import uuid4
import anyio
import click
import mcp.types as types
from mcp.server.lowlevel import Server
from mcp.server.streamableHttp import (
MCP_SESSION_ID_HEADER,
StreamableHTTPServerTransport,
)
from pydantic import AnyUrl
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount
# Module-wide logger. Handlers and level are configured later by
# logging.basicConfig() in main(); this line only obtains the logger.
logger = logging.getLogger(__name__)

# Task group owned by the application lifespan: assigned on startup, reset to
# None on shutdown, and used by the request handler to start per-session
# server tasks. It stays None until the application has started.
task_group: "anyio.abc.TaskGroup | None" = None
@contextlib.asynccontextmanager
async def lifespan(app):
    """Own the task group that per-session MCP server tasks run in.

    The group is published through the module-level ``task_group`` so the
    ASGI request handler can start server tasks in it. On shutdown every task
    in the group is cancelled, and the "cleaned up" message is logged only
    after the task group has actually finished unwinding.

    Args:
        app: The Starlette application. It is not used here, but the
            lifespan callable receives it.
    """
    global task_group

    async with anyio.create_task_group() as tg:
        task_group = tg
        logger.info("Application started, task group initialized!")
        try:
            yield
        finally:
            logger.info("Application shutting down, cleaning up resources...")
            # Cancel the group's scope so long-running session tasks stop.
            # Otherwise, leaving the `async with` would wait for them to end
            # on their own.
            tg.cancel_scope.cancel()
            task_group = None
    # Only reached once every task in the group has exited.
    logger.info("Resources cleaned up successfully.")
@click.command()
@click.option("--port", default=3000, help="Port to listen on for HTTP")
@click.option(
    "--log-level",
    default="INFO",
    # Reject unknown level names with a clean usage error instead of an
    # AttributeError traceback from getattr(logging, ...).
    type=click.Choice(
        ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], case_sensitive=False
    ),
    help="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
)
@click.option(
    "--json-response",
    is_flag=True,
    default=False,
    help="Enable JSON responses instead of SSE streams",
)
def main(
    port: int,
    log_level: str,
    json_response: bool,
) -> int:
    """Run the streamable-HTTP MCP notification demo server.

    Builds a low-level MCP ``Server`` exposing one tool,
    ``start-notification-stream``, mounts a session-aware ASGI handler at
    ``/mcp`` and serves it with uvicorn on ``0.0.0.0:<port>``.

    Args:
        port: TCP port uvicorn listens on.
        log_level: Name of the root logging level.
        json_response: When set, transports are created with
            ``is_json_response_enabled=True`` (JSON replies instead of SSE).

    Returns:
        0 once uvicorn returns.
    """
    # Configure logging
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    app = Server("mcp-streamable-http-demo")

    @app.call_tool()
    async def call_tool(
        name: str, arguments: dict
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """Send ``count`` log notifications spaced ``interval`` seconds apart.

        Raises:
            ValueError: If ``name`` is not the tool advertised by list_tools.
        """
        if name != "start-notification-stream":
            raise ValueError(f"Unknown tool: {name}")

        ctx = app.request_context
        interval = arguments.get("interval", 1.0)
        # The input schema types "count" as a JSON number, so clients may
        # legitimately send 5.0; range() requires an int.
        count = int(arguments.get("count", 5))
        caller = arguments.get("caller", "unknown")

        # Send the specified number of notifications with the given interval
        for i in range(count):
            await ctx.session.send_log_message(
                level="info",
                data=f"Notification {i+1}/{count} from caller: {caller}",
                logger="notification_stream",
                # Associates this notification with the original request
                # Ensures notifications are sent to the correct response stream
                # Without this, notifications will either go to:
                # - a standalone SSE stream (if GET request is supported)
                # - nowhere (if GET request isn't supported)
                related_request_id=ctx.request_id,
            )
            if i < count - 1:  # Don't wait after the last notification
                await anyio.sleep(interval)

        # This sends a resource notification through the standalone SSE
        # stream established by a GET request.
        await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
        return [
            types.TextContent(
                type="text",
                text=(
                    f"Sent {count} notifications with {interval}s interval"
                    f" for caller: {caller}"
                ),
            )
        ]

    @app.list_tools()
    async def list_tools() -> list[types.Tool]:
        """Advertise the single notification-stream tool and its input schema."""
        return [
            types.Tool(
                name="start-notification-stream",
                description=(
                    "Sends a stream of notifications with configurable count"
                    " and interval"
                ),
                inputSchema={
                    "type": "object",
                    "required": ["interval", "count", "caller"],
                    "properties": {
                        "interval": {
                            "type": "number",
                            "description": "Interval between notifications in seconds",
                        },
                        "count": {
                            "type": "number",
                            "description": "Number of notifications to send",
                        },
                        "caller": {
                            "type": "string",
                            "description": (
                                "Identifier of the caller to include in notifications"
                            ),
                        },
                    },
                },
            )
        ]

    # Transports for live sessions, keyed by MCP session ID, so later
    # requests carrying the session header reach the same transport.
    # NOTE(review): entries are never removed, so ended sessions accumulate
    # for the life of the process.
    server_instances = {}
    # Serializes creation and registration of new sessions.
    session_creation_lock = anyio.Lock()

    async def handle_streamable_http(scope, receive, send):
        """ASGI handler that routes each request to its session's transport.

        - Known session-ID header: forward to that session's transport.
        - No header: create a transport with a fresh session ID, start an MCP
          server task for it in the lifespan task group, then let the new
          transport handle this request.
        - Unknown session ID: respond 400 Bad Request.
        """
        request = Request(scope, receive)
        request_mcp_session_id = request.headers.get(MCP_SESSION_ID_HEADER)

        if (
            request_mcp_session_id is not None
            and request_mcp_session_id in server_instances
        ):
            transport = server_instances[request_mcp_session_id]
            logger.debug("Session already exists, handling request directly")
            await transport.handle_request(scope, receive, send)
        elif request_mcp_session_id is None:
            # try to establish new session
            logger.debug("Creating new transport")
            # Check before creating anything, so a failure cannot leave an
            # orphaned transport registered in server_instances.
            if not task_group:
                raise RuntimeError("Task group is not initialized")

            async with session_creation_lock:
                new_session_id = uuid4().hex
                http_transport = StreamableHTTPServerTransport(
                    mcp_session_id=new_session_id,
                    is_json_response_enabled=json_response,
                )
                server_instances[http_transport.mcp_session_id] = http_transport

                async def run_server(*, task_status=anyio.TASK_STATUS_IGNORED):
                    # The streams must stay open for the whole session, so
                    # connect() is entered inside this long-lived task.
                    # Entering it in the request handler would close the
                    # streams as soon as the first request returned.
                    async with http_transport.connect() as streams:
                        read_stream, write_stream = streams
                        task_status.started()
                        await app.run(
                            read_stream,
                            write_stream,
                            app.create_initialization_options(),
                        )

                # start() returns once task_status.started() is called, i.e.
                # after the streams are connected.
                await task_group.start(run_server)

            # Handle the HTTP request outside the lock, so one slow
            # initialization does not block other clients from creating
            # sessions.
            await http_transport.handle_request(scope, receive, send)
        else:
            response = Response(
                "Bad Request: No valid session ID provided",
                status_code=HTTPStatus.BAD_REQUEST,
            )
            await response(scope, receive, send)

    # Create an ASGI application using the transport
    # NOTE(review): debug=True makes Starlette return tracebacks on errors;
    # together with binding 0.0.0.0, this is only suitable for local demos.
    starlette_app = Starlette(
        debug=True,
        routes=[
            Mount("/mcp", app=handle_streamable_http),
        ],
        lifespan=lifespan,
    )

    import uvicorn

    uvicorn.run(starlette_app, host="0.0.0.0", port=port)
    return 0