diff --git a/src/client/streamableHttp.test.ts b/src/client/streamableHttp.test.ts index 1572e912..073cc9ac 100644 --- a/src/client/streamableHttp.test.ts +++ b/src/client/streamableHttp.test.ts @@ -1,4 +1,4 @@ -import { StreamableHTTPClientTransport } from "./streamableHttp.js"; +import { StreamableHTTPClientTransport, StreamableHTTPReconnectionOptions } from "./streamableHttp.js"; import { JSONRPCMessage } from "../types.js"; @@ -164,7 +164,7 @@ describe("StreamableHTTPClientTransport", () => { // We expect the 405 error to be caught and handled gracefully // This should not throw an error that breaks the transport await transport.start(); - await expect(transport["_startOrAuthStandaloneSSE"]()).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed"); + await expect(transport["_startOrAuthStandaloneSSE"]({})).resolves.not.toThrow("Failed to open SSE stream: Method Not Allowed"); // Check that GET was attempted expect(global.fetch).toHaveBeenCalledWith( expect.anything(), @@ -208,7 +208,7 @@ describe("StreamableHTTPClientTransport", () => { transport.onmessage = messageSpy; await transport.start(); - await transport["_startOrAuthStandaloneSSE"](); + await transport["_startOrAuthStandaloneSSE"]({}); // Give time for the SSE event to be processed await new Promise(resolve => setTimeout(resolve, 50)); @@ -275,45 +275,62 @@ describe("StreamableHTTPClientTransport", () => { })).toBe(true); }); - it("should include last-event-id header when resuming a broken connection", async () => { - // First make a successful connection that provides an event ID - const encoder = new TextEncoder(); - const stream = new ReadableStream({ - start(controller) { - const event = "id: event-123\nevent: message\ndata: {\"jsonrpc\": \"2.0\", \"method\": \"serverNotification\", \"params\": {}}\n\n"; - controller.enqueue(encoder.encode(event)); - controller.close(); + it("should support custom reconnection options", () => { + // Create a transport with custom reconnection options + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 500, + maxReconnectionDelay: 10000, + reconnectionDelayGrowFactor: 2, + maxRetries: 5, } }); - (global.fetch as jest.Mock).mockResolvedValueOnce({ - ok: true, - status: 200, - headers: new Headers({ "content-type": "text/event-stream" }), - body: stream - }); + // Verify options were set correctly (checking implementation details) + // Access private properties for testing + const transportInstance = transport as unknown as { + _reconnectionOptions: StreamableHTTPReconnectionOptions; + }; + expect(transportInstance._reconnectionOptions.initialReconnectionDelay).toBe(500); + expect(transportInstance._reconnectionOptions.maxRetries).toBe(5); + }); - await transport.start(); - await transport["_startOrAuthStandaloneSSE"](); - await new Promise(resolve => setTimeout(resolve, 50)); + it("should pass lastEventId when reconnecting", async () => { + // Create a fresh transport + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp")); - // Now simulate attempting to reconnect - (global.fetch as jest.Mock).mockResolvedValueOnce({ + // Mock fetch to verify headers sent + const fetchSpy = global.fetch as jest.Mock; + fetchSpy.mockReset(); + fetchSpy.mockResolvedValue({ ok: true, status: 200, headers: new Headers({ "content-type": "text/event-stream" }), - body: null + body: new ReadableStream() }); - await transport["_startOrAuthStandaloneSSE"](); + // Call the reconnect method directly with a lastEventId + await transport.start(); + // Type assertion to access private method + const transportWithPrivateMethods = transport as unknown as { + _startOrAuthStandaloneSSE: (options: { lastEventId?: string }) => Promise + }; + await transportWithPrivateMethods._startOrAuthStandaloneSSE({ lastEventId: "test-event-id" }); - // Check that Last-Event-ID was included - const calls = (global.fetch as jest.Mock).mock.calls; - const lastCall = calls[calls.length - 1]; - expect(lastCall[1].headers.get("last-event-id")).toBe("event-123"); + // Verify fetch was called with the lastEventId header + expect(fetchSpy).toHaveBeenCalled(); + const fetchCall = fetchSpy.mock.calls[0]; + const headers = fetchCall[1].headers; + expect(headers.get("last-event-id")).toBe("test-event-id"); }); it("should throw error when invalid content-type is received", async () => { + // Clear any previous state from other tests + jest.clearAllMocks(); + + // Create a fresh transport instance + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp")); + const message: JSONRPCMessage = { jsonrpc: "2.0", method: "test", @@ -323,7 +340,7 @@ describe("StreamableHTTPClientTransport", () => { const stream = new ReadableStream({ start(controller) { - controller.enqueue("invalid text response"); + controller.enqueue(new TextEncoder().encode("invalid text response")); controller.close(); } }); @@ -365,7 +382,7 @@ describe("StreamableHTTPClientTransport", () => { await transport.start(); - await transport["_startOrAuthStandaloneSSE"](); + await transport["_startOrAuthStandaloneSSE"]({}); expect((actualReqInit.headers as Headers).get("x-custom-header")).toBe("CustomValue"); requestInit.headers["X-Custom-Header"] = "SecondCustomValue"; @@ -375,4 +392,38 @@ describe("StreamableHTTPClientTransport", () => { expect(global.fetch).toHaveBeenCalledTimes(2); }); + + + it("should have exponential backoff with configurable maxRetries", () => { + // This test verifies the maxRetries and backoff calculation directly + + // Create transport with specific options for testing + transport = new StreamableHTTPClientTransport(new URL("http://localhost:1234/mcp"), { + reconnectionOptions: { + initialReconnectionDelay: 100, + maxReconnectionDelay: 5000, + reconnectionDelayGrowFactor: 2, + maxRetries: 3, + } + }); + + // Get access to the internal implementation + const getDelay = transport["_getNextReconnectionDelay"].bind(transport); + + // First retry - should use initial delay + expect(getDelay(0)).toBe(100); + + // Second retry - should double (2^1 * 100 = 200) + expect(getDelay(1)).toBe(200); + + // Third retry - should double again (2^2 * 100 = 400) + expect(getDelay(2)).toBe(400); + + // Fourth retry - should double again (2^3 * 100 = 800) + expect(getDelay(3)).toBe(800); + + // Tenth retry - should be capped at maxReconnectionDelay + expect(getDelay(10)).toBe(5000); + }); + }); diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index ea69ee77..7bb88c0d 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -3,6 +3,14 @@ import { isJSONRPCNotification, JSONRPCMessage, JSONRPCMessageSchema } from "../ import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js"; import { EventSourceParserStream } from "eventsource-parser/stream"; +// Default reconnection options for StreamableHTTP connections +const DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS: StreamableHTTPReconnectionOptions = { + initialReconnectionDelay: 1000, + maxReconnectionDelay: 30000, + reconnectionDelayGrowFactor: 1.5, + maxRetries: 2, +}; + export class StreamableHTTPError extends Error { constructor( public readonly code: number | undefined, @@ -12,6 +20,45 @@ export class StreamableHTTPError extends Error { } } +/** + * Options for starting or authenticating an SSE connection + */ +export interface StartSSEOptions { + /** + * The ID of the last received event, used for resuming a disconnected stream + */ + lastEventId?: string; +} + +/** + * Configuration options for reconnection behavior of the StreamableHTTPClientTransport. + */ +export interface StreamableHTTPReconnectionOptions { + /** + * Maximum backoff time between reconnection attempts in milliseconds. + * Default is 30000 (30 seconds). + */ + maxReconnectionDelay: number; + + /** + * Initial backoff time between reconnection attempts in milliseconds. + * Default is 1000 (1 second). + */ + initialReconnectionDelay: number; + + /** + * The factor by which the reconnection delay increases after each attempt. + * Default is 1.5. + */ + reconnectionDelayGrowFactor: number; + + /** + * Maximum number of reconnection attempts before giving up. + * Default is 2. + */ + maxRetries: number; +} + /** * Configuration options for the `StreamableHTTPClientTransport`. */ @@ -36,6 +83,11 @@ export type StreamableHTTPClientTransportOptions = { * Customizes HTTP requests to the server. */ requestInit?: RequestInit; + + /** + * Options to configure the reconnection behavior. + */ + reconnectionOptions?: StreamableHTTPReconnectionOptions; }; /** @@ -49,7 +101,7 @@ export class StreamableHTTPClientTransport implements Transport { private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; private _sessionId?: string; - private _lastEventId?: string; + private _reconnectionOptions: StreamableHTTPReconnectionOptions; onclose?: () => void; onerror?: (error: Error) => void; @@ -62,6 +114,7 @@ export class StreamableHTTPClientTransport implements Transport { this._url = url; this._requestInit = opts?.requestInit; this._authProvider = opts?.authProvider; + this._reconnectionOptions = opts?.reconnectionOptions ?? DEFAULT_STREAMABLE_HTTP_RECONNECTION_OPTIONS; } private async _authThenStart(): Promise { @@ -81,7 +134,7 @@ export class StreamableHTTPClientTransport implements Transport { throw new UnauthorizedError(); } - return await this._startOrAuthStandaloneSSE(); + return await this._startOrAuthStandaloneSSE({ lastEventId: undefined }); } private async _commonHeaders(): Promise { @@ -102,16 +155,18 @@ export class StreamableHTTPClientTransport implements Transport { ); } - private async _startOrAuthStandaloneSSE(): Promise { + + private async _startOrAuthStandaloneSSE(options: StartSSEOptions): Promise { + const { lastEventId } = options; try { // Try to open an initial SSE stream with GET to listen for server messages // This is optional according to the spec - server may not support it const headers = await this._commonHeaders(); headers.set("Accept", "text/event-stream"); - // Include Last-Event-ID header for resumable streams - if (this._lastEventId) { - headers.set("last-event-id", this._lastEventId); + // Include Last-Event-ID header for resumable streams if provided + if (lastEventId) { + headers.set("last-event-id", lastEventId); } const response = await fetch(this._url, { @@ -137,7 +192,7 @@ export class StreamableHTTPClientTransport implements Transport { `Failed to open SSE stream: ${response.statusText}`, ); } - // Successful connection, handle the SSE stream as a standalone listener + this._handleSseStream(response.body); } catch (error) { this.onerror?.(error as Error); @@ -145,36 +200,111 @@ export class StreamableHTTPClientTransport implements Transport { } } + + /** + * Calculates the next reconnection delay using backoff algorithm + * + * @param attempt Current reconnection attempt count for the specific stream + * @returns Time to wait in milliseconds before next reconnection attempt + */ + private _getNextReconnectionDelay(attempt: number): number { + // Access default values directly, ensuring they're never undefined + const initialDelay = this._reconnectionOptions.initialReconnectionDelay; + const growFactor = this._reconnectionOptions.reconnectionDelayGrowFactor; + const maxDelay = this._reconnectionOptions.maxReconnectionDelay; + + // Cap at maximum delay + return Math.min(initialDelay * Math.pow(growFactor, attempt), maxDelay); + + } + + /** + * Schedule a reconnection attempt with exponential backoff + * + * @param lastEventId The ID of the last received event for resumability + * @param attemptCount Current reconnection attempt count for this specific stream + */ + private _scheduleReconnection(lastEventId: string, attemptCount = 0): void { + // Use provided options or default options + const maxRetries = this._reconnectionOptions.maxRetries; + + // Check if we've exceeded maximum retry attempts + if (maxRetries > 0 && attemptCount >= maxRetries) { + this.onerror?.(new Error(`Maximum reconnection attempts (${maxRetries}) exceeded.`)); + return; + } + + // Calculate next delay based on current attempt count + const delay = this._getNextReconnectionDelay(attemptCount); + + // Schedule the reconnection + setTimeout(() => { + // Use the last event ID to resume where we left off + this._startOrAuthStandaloneSSE({ lastEventId }).catch(error => { + this.onerror?.(new Error(`Failed to reconnect SSE stream: ${error instanceof Error ? error.message : String(error)}`)); + // Schedule another attempt if this one failed, incrementing the attempt counter + this._scheduleReconnection(lastEventId, attemptCount + 1); + }); + }, delay); + } + private _handleSseStream(stream: ReadableStream | null): void { if (!stream) { return; } + let lastEventId: string | undefined; const processStream = async () => { - // Create a pipeline: binary stream -> text decoder -> SSE parser - const eventStream = stream - .pipeThrough(new TextDecoderStream()) - .pipeThrough(new EventSourceParserStream()); - - for await (const event of eventStream) { - // Update last event ID if provided - if (event.id) { - this._lastEventId = event.id; + // this is the closest we can get to trying to catch network errors + // if something happens reader will throw + try { + // Create a pipeline: binary stream -> text decoder -> SSE parser + const reader = stream + .pipeThrough(new TextDecoderStream()) + .pipeThrough(new EventSourceParserStream()) + .getReader(); + + + while (true) { + const { value: event, done } = await reader.read(); + if (done) { + break; + } + + // Update last event ID if provided + if (event.id) { + lastEventId = event.id; + } + + if (!event.event || event.event === "message") { + try { + const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); + this.onmessage?.(message); + } catch (error) { + this.onerror?.(error as Error); + } + } } - // Handle message events (default event type is undefined per docs) - // or explicit 'message' event type - if (!event.event || event.event === "message") { - try { - const message = JSONRPCMessageSchema.parse(JSON.parse(event.data)); - this.onmessage?.(message); - } catch (error) { - this.onerror?.(error as Error); + } catch (error) { + // Handle stream errors - likely a network disconnect + this.onerror?.(new Error(`SSE stream disconnected: ${error}`)); + + // Attempt to reconnect if the stream disconnects unexpectedly and we aren't closing + if (this._abortController && !this._abortController.signal.aborted) { + // Use the exponential backoff reconnection strategy + if (lastEventId !== undefined) { + try { + this._scheduleReconnection(lastEventId, 0); + } + catch (error) { + this.onerror?.(new Error(`Failed to reconnect: ${error instanceof Error ? error.message : String(error)}`)); + + } } } } }; - - processStream().catch(err => this.onerror?.(err)); + processStream(); } async start() { @@ -252,8 +382,8 @@ export class StreamableHTTPClientTransport implements Transport { // if the accepted notification is initialized, we start the SSE stream // if it's supported by the server if (isJSONRPCNotification(message) && message.method === "notifications/initialized") { - // We don't need to handle 405 here anymore as it's handled in _startOrAuthStandaloneSSE - this._startOrAuthStandaloneSSE().catch(err => this.onerror?.(err)); + // Start without a lastEventId since this is a fresh connection + this._startOrAuthStandaloneSSE({ lastEventId: undefined }).catch(err => this.onerror?.(err)); } return; } diff --git a/src/examples/client/simpleStreamableHttp.ts b/src/examples/client/simpleStreamableHttp.ts index 739e1164..923ffbc2 100644 --- a/src/examples/client/simpleStreamableHttp.ts +++ b/src/examples/client/simpleStreamableHttp.ts @@ -1,5 +1,6 @@ import { Client } from '../../client/index.js'; import { StreamableHTTPClientTransport } from '../../client/streamableHttp.js'; +import { createInterface } from 'node:readline'; import { ListToolsRequest, ListToolsResultSchema, @@ -15,139 +16,429 @@ import { ResourceListChangedNotificationSchema, } from '../../types.js'; +// Create readline interface for user input +const readline = createInterface({ + input: process.stdin, + output: process.stdout +}); + +// Track received notifications for debugging resumability +let notificationCount = 0; + +// Global client and transport for interactive commands +let client: Client | null = null; +let transport: StreamableHTTPClientTransport | null = null; +let serverUrl = 'http://localhost:3000/mcp'; + async function main(): Promise { - // Create a new client with streamable HTTP transport - const client = new Client({ - name: 'example-client', - version: '1.0.0' - }); + console.log('MCP Interactive Client'); + console.log('====================='); - const transport = new StreamableHTTPClientTransport( - new URL('http://localhost:3000/mcp') - ); + // Connect to server immediately with default settings + await connect(); - // Connect the client using the transport and initialize the server - await client.connect(transport); - console.log('Connected to MCP server'); + // Print help and start the command loop + printHelp(); + commandLoop(); +} - // Set up notification handlers for server-initiated messages - client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { - console.log(`Notification received: ${notification.params.level} - ${notification.params.data}`); - }); - client.setNotificationHandler(ResourceListChangedNotificationSchema, async (_) => { - console.log(`Resource list changed notification received!`); - const resourcesRequest: ListResourcesRequest = { - method: 'resources/list', - params: {} - }; - const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); - console.log('Available resources count:', resourcesResult.resources.length); - }); +function printHelp(): void { + console.log('\nAvailable commands:'); + console.log(' connect [url] - Connect to MCP server (default: http://localhost:3000/mcp)'); + console.log(' disconnect - Disconnect from server'); + console.log(' reconnect - Reconnect to the server'); + console.log(' list-tools - List available tools'); + console.log(' call-tool [args] - Call a tool with optional JSON arguments'); + console.log(' greet [name] - Call the greet tool'); + console.log(' multi-greet [name] - Call the multi-greet tool with notifications'); + console.log(' start-notifications [interval] [count] - Start periodic notifications'); + console.log(' list-prompts - List available prompts'); + console.log(' get-prompt [name] [args] - Get a prompt with optional JSON arguments'); + console.log(' list-resources - List available resources'); + console.log(' help - Show this help'); + console.log(' quit - Exit the program'); +} - // List and call tools - await listTools(client); +function commandLoop(): void { + readline.question('\n> ', async (input) => { + const args = input.trim().split(/\s+/); + const command = args[0]?.toLowerCase(); - await callGreetTool(client); - await callMultiGreetTool(client); + try { + switch (command) { + case 'connect': + await connect(args[1]); + break; + case 'disconnect': + await disconnect(); + break; - // List available prompts - try { - const promptsRequest: ListPromptsRequest = { - method: 'prompts/list', - params: {} - }; - const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); - console.log('Available prompts:', promptsResult.prompts); - } catch (error) { - console.log(`Prompts not supported by this server (${error})`); + case 'reconnect': + await reconnect(); + break; + + case 'list-tools': + await listTools(); + break; + + case 'call-tool': + if (args.length < 2) { + console.log('Usage: call-tool [args]'); + } else { + const toolName = args[1]; + let toolArgs = {}; + if (args.length > 2) { + try { + toolArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await callTool(toolName, toolArgs); + } + break; + + case 'greet': + await callGreetTool(args[1] || 'MCP User'); + break; + + case 'multi-greet': + await callMultiGreetTool(args[1] || 'MCP User'); + break; + + case 'start-notifications': { + const interval = args[1] ? parseInt(args[1], 10) : 2000; + const count = args[2] ? parseInt(args[2], 10) : 0; + await startNotifications(interval, count); + break; + } + + case 'list-prompts': + await listPrompts(); + break; + + case 'get-prompt': + if (args.length < 2) { + console.log('Usage: get-prompt [args]'); + } else { + const promptName = args[1]; + let promptArgs = {}; + if (args.length > 2) { + try { + promptArgs = JSON.parse(args.slice(2).join(' ')); + } catch { + console.log('Invalid JSON arguments. Using empty args.'); + } + } + await getPrompt(promptName, promptArgs); + } + break; + + case 'list-resources': + await listResources(); + break; + + case 'help': + printHelp(); + break; + + case 'quit': + case 'exit': + await cleanup(); + return; + + default: + if (command) { + console.log(`Unknown command: ${command}`); + } + break; + } + } catch (error) { + console.error(`Error executing command: ${error}`); + } + + // Continue the command loop + commandLoop(); + }); +} + +async function connect(url?: string): Promise { + if (client) { + console.log('Already connected. Disconnect first.'); + return; } - // Get a prompt + if (url) { + serverUrl = url; + } + + console.log(`Connecting to ${serverUrl}...`); + try { - const promptRequest: GetPromptRequest = { - method: 'prompts/get', - params: { - name: 'greeting-template', - arguments: { name: 'MCP User' } + // Create a new client + client = new Client({ + name: 'example-client', + version: '1.0.0' + }); + client.onerror = (error) => { + console.error('\x1b[31mClient error:', error, '\x1b[0m'); + } + + transport = new StreamableHTTPClientTransport( + new URL(serverUrl) + ); + + // Set up notification handlers + client.setNotificationHandler(LoggingMessageNotificationSchema, (notification) => { + notificationCount++; + console.log(`\nNotification #${notificationCount}: ${notification.params.level} - ${notification.params.data}`); + // Re-display the prompt + process.stdout.write('> '); + }); + + client.setNotificationHandler(ResourceListChangedNotificationSchema, async (_) => { + console.log(`\nResource list changed notification received!`); + try { + if (!client) { + console.log('Client disconnected, cannot fetch resources'); + return; + } + const resourcesResult = await client.request({ + method: 'resources/list', + params: {} + }, ListResourcesResultSchema); + console.log('Available resources count:', resourcesResult.resources.length); + } catch { + console.log('Failed to list resources after change notification'); } - }; - const promptResult = await client.request(promptRequest, GetPromptResultSchema); - console.log('Prompt template:', promptResult.messages[0].content.text); + // Re-display the prompt + process.stdout.write('> '); + }); + + // Connect the client + await client.connect(transport); + console.log('Connected to MCP server'); } catch (error) { - console.log(`Prompt retrieval not supported by this server (${error})`); + console.error('Failed to connect:', error); + client = null; + transport = null; + } +} + +async function disconnect(): Promise { + if (!client || !transport) { + console.log('Not connected.'); + return; } - // List available resources try { - const resourcesRequest: ListResourcesRequest = { - method: 'resources/list', - params: {} - }; - const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); - console.log('Available resources:', resourcesResult.resources); + await transport.close(); + console.log('Disconnected from MCP server'); + client = null; + transport = null; } catch (error) { - console.log(`Resources not supported by this server (${error})`); + console.error('Error disconnecting:', error); + } +} + +async function reconnect(): Promise { + if (client) { + await disconnect(); } - // Keep the connection open to receive notifications - console.log('\nKeeping connection open to receive notifications. Press Ctrl+C to exit.'); + await connect(); } -async function listTools(client: Client): Promise { +async function listTools(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + try { const toolsRequest: ListToolsRequest = { method: 'tools/list', params: {} }; const toolsResult = await client.request(toolsRequest, ListToolsResultSchema); - console.log('Available tools:', toolsResult.tools); + + console.log('Available tools:'); if (toolsResult.tools.length === 0) { - console.log('No tools available from the server'); + console.log(' No tools available'); + } else { + for (const tool of toolsResult.tools) { + console.log(` - ${tool.name}: ${tool.description}`); + } } } catch (error) { console.log(`Tools not supported by this server (${error})`); - return } } -async function callGreetTool(client: Client): Promise { +async function callTool(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + try { - const greetRequest: CallToolRequest = { + const request: CallToolRequest = { method: 'tools/call', params: { - name: 'greet', - arguments: { name: 'MCP User' } + name, + arguments: args + } + }; + + console.log(`Calling tool '${name}' with args:`, args); + const result = await client.request(request, CallToolResultSchema); + + console.log('Tool result:'); + result.content.forEach(item => { + if (item.type === 'text') { + console.log(` ${item.text}`); + } else { + console.log(` ${item.type} content:`, item); } + }); + } catch (error) { + console.log(`Error calling tool ${name}: ${error}`); + } +} + +async function callGreetTool(name: string): Promise { + await callTool('greet', { name }); +} + +async function callMultiGreetTool(name: string): Promise { + console.log('Calling multi-greet tool with notifications...'); + await callTool('multi-greet', { name }); +} + +async function startNotifications(interval: number, count: number): Promise { + console.log(`Starting notification stream: interval=${interval}ms, count=${count || 'unlimited'}`); + await callTool('start-notification-stream', { interval, count }); +} + +async function listPrompts(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const promptsRequest: ListPromptsRequest = { + method: 'prompts/list', + params: {} }; - const greetResult = await client.request(greetRequest, CallToolResultSchema); - console.log('Greeting result:', greetResult.content[0].text); + const promptsResult = await client.request(promptsRequest, ListPromptsResultSchema); + console.log('Available prompts:'); + if (promptsResult.prompts.length === 0) { + console.log(' No prompts available'); + } else { + for (const prompt of promptsResult.prompts) { + console.log(` - ${prompt.name}: ${prompt.description}`); + } + } } catch (error) { - console.log(`Error calling greet tool: ${error}`); + console.log(`Prompts not supported by this server (${error})`); } } -async function callMultiGreetTool(client: Client): Promise { +async function getPrompt(name: string, args: Record): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + try { - console.log('\nCalling multi-greet tool (with notifications)...'); - const multiGreetRequest: CallToolRequest = { - method: 'tools/call', + const promptRequest: GetPromptRequest = { + method: 'prompts/get', params: { - name: 'multi-greet', - arguments: { name: 'MCP User' } + name, + arguments: args as Record } }; - const multiGreetResult = await client.request(multiGreetRequest, CallToolResultSchema); - console.log('Multi-greet results:'); - multiGreetResult.content.forEach(item => { - if (item.type === 'text') { - console.log(`- ${item.text}`); - } + + const promptResult = await client.request(promptRequest, GetPromptResultSchema); + console.log('Prompt template:'); + promptResult.messages.forEach((msg, index) => { + console.log(` [${index + 1}] ${msg.role}: ${msg.content.text}`); }); } catch (error) { - console.log(`Error calling multi-greet tool: ${error}`); + console.log(`Error getting prompt ${name}: ${error}`); } } +async function listResources(): Promise { + if (!client) { + console.log('Not connected to server.'); + return; + } + + try { + const resourcesRequest: ListResourcesRequest = { + method: 'resources/list', + params: {} + }; + const resourcesResult = await client.request(resourcesRequest, ListResourcesResultSchema); + + console.log('Available resources:'); + if (resourcesResult.resources.length === 0) { + console.log(' No resources available'); + } else { + for (const resource of resourcesResult.resources) { + console.log(` - ${resource.name}: ${resource.uri}`); + } + } + } catch (error) { + console.log(`Resources not supported by this server (${error})`); + } +} + +async function cleanup(): Promise { + if (client && transport) { + try { + await transport.close(); + } catch (error) { + console.error('Error closing transport:', error); + } + } + + + process.stdin.setRawMode(false); + readline.close(); + console.log('\nGoodbye!'); + process.exit(0); +} + +// Set up raw mode for keyboard input to capture Escape key +process.stdin.setRawMode(true); +process.stdin.on('data', async (data) => { + // Check for Escape key (27) + if (data.length === 1 && data[0] === 27) { + console.log('\nESC key pressed. Disconnecting from server...'); + + // Abort current operation and disconnect from server + if (client && transport) { + await disconnect(); + console.log('Disconnected. Press Enter to continue.'); + } else { + console.log('Not connected to server.'); + } + + // Re-display the prompt + process.stdout.write('> '); + } +}); + +// Handle Ctrl+C +process.on('SIGINT', async () => { + console.log('\nReceived SIGINT. Cleaning up...'); + await cleanup(); +}); + +// Start the interactive client main().catch((error: unknown) => { console.error('Error running MCP client:', error); process.exit(1); diff --git a/src/examples/server/simpleStreamableHttp.ts b/src/examples/server/simpleStreamableHttp.ts index f0f74439..153e35b7 100644 --- a/src/examples/server/simpleStreamableHttp.ts +++ b/src/examples/server/simpleStreamableHttp.ts @@ -1,9 +1,84 @@ import express, { Request, Response } from 'express'; import { randomUUID } from 'node:crypto'; import { McpServer } from '../../server/mcp.js'; -import { StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; +import { EventStore, StreamableHTTPServerTransport } from '../../server/streamableHttp.js'; import { z } from 'zod'; -import { CallToolResult, GetPromptResult, ReadResourceResult } from '../../types.js'; +import { CallToolResult, GetPromptResult, JSONRPCMessage, ReadResourceResult } from '../../types.js'; + +// Create a simple in-memory EventStore for resumability +class InMemoryEventStore implements EventStore { + private events: Map = new Map(); + + /** + * Generates a unique event ID for a given stream ID + */ + private generateEventId(streamId: string): string { + return `${streamId}_${Date.now()}_${Math.random().toString(36).substring(2, 10)}`; + } + + private getStreamIdFromEventId(eventId: string): string { + const parts = eventId.split('_'); + return parts.length > 0 ? parts[0] : ''; + } + + /** + * Stores an event with a generated event ID + * Implements EventStore.storeEvent + */ + async storeEvent(streamId: string, message: JSONRPCMessage): Promise { + const eventId = this.generateEventId(streamId); + console.log(`Storing event ${eventId} for stream ${streamId}`); + this.events.set(eventId, { streamId, message }); + return eventId; + } + + /** + * Replays events that occurred after a specific event ID + * Implements EventStore.replayEventsAfter + */ + async replayEventsAfter(lastEventId: string, + { send }: { send: (eventId: string, message: JSONRPCMessage) => Promise } + ): Promise { + if (!lastEventId || !this.events.has(lastEventId)) { + console.log(`No events found for lastEventId: ${lastEventId}`); + return ''; + } + + // Extract the stream ID from the event ID + const streamId = this.getStreamIdFromEventId(lastEventId); + if (!streamId) { + console.log(`Could not extract streamId from lastEventId: ${lastEventId}`); + return ''; + } + + let foundLastEvent = false; + let eventCount = 0; + + // Sort events by eventId for chronological ordering + const sortedEvents = [...this.events.entries()].sort((a, b) => a[0].localeCompare(b[0])); + + for (const [eventId, { streamId: eventStreamId, message }] of sortedEvents) { + // Only include events from the same stream + if (eventStreamId !== streamId) { + continue; + } + + // Start sending events after we find the lastEventId + if (eventId === lastEventId) { + foundLastEvent = true; + continue; + } + + if (foundLastEvent) { + await send(eventId, message); + eventCount++; + } + } + + console.log(`Replayed ${eventCount} events after ${lastEventId} for stream ${streamId}`); + return streamId; + } +} // Create an MCP server with implementation details const server = new McpServer({ @@ -92,6 +167,43 @@ server.prompt( } ); +// Register a tool specifically for testing resumability +server.tool( + 'start-notification-stream', + 'Starts sending periodic notifications for testing resumability', + { + interval: z.number().describe('Interval in milliseconds between notifications').default(100), + count: z.number().describe('Number of notifications to send (0 for 100)').default(50), + }, + async ({ interval, count }, { sendNotification }): Promise => { + const sleep = (ms: number) => new Promise(resolve => setTimeout(resolve, ms)); + let counter = 0; + + while (count === 0 || counter < count) { + counter++; + await sendNotification({ + method: "notifications/message", + params: { + level: "info", + data: `Periodic notification #${counter} at ${new Date().toISOString()}` + } + }); + + // Wait for the specified interval + await sleep(interval); + } + + return { + content: [ + { + type: 'text', + text: `Started sending periodic notifications every ${interval}ms`, + } + ], + }; + } +); + // Create a simple resource at a fixed URI server.resource( 'greeting-resource', @@ -127,8 +239,10 @@ app.post('/mcp', async (req: Request, res: Response) => { transport = transports[sessionId]; } else if (!sessionId && isInitializeRequest(req.body)) { // New initialization request + const eventStore = new InMemoryEventStore(); transport = new StreamableHTTPServerTransport({ sessionIdGenerator: () => randomUUID(), + eventStore, // Enable resumability }); // Connect the transport to the MCP server BEFORE handling the request @@ -182,7 +296,14 @@ app.get('/mcp', async (req: Request, res: Response) => { return; } - console.log(`Establishing SSE stream for session ${sessionId}`); + // Check for Last-Event-ID header for resumability + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + console.log(`Client reconnecting with Last-Event-ID: ${lastEventId}`); + } else { + console.log(`Establishing new SSE stream for session ${sessionId}`); + } + const transport = transports[sessionId]; await transport.handleRequest(req, res); }); diff --git a/src/server/streamableHttp.test.ts b/src/server/streamableHttp.test.ts index 949cafc7..efd5de1c 100644 --- a/src/server/streamableHttp.test.ts +++ b/src/server/streamableHttp.test.ts @@ -1,7 +1,7 @@ import { createServer, type Server, IncomingMessage, ServerResponse } from "node:http"; import { AddressInfo } from "node:net"; import { randomUUID } from "node:crypto"; -import { StreamableHTTPServerTransport } from "./streamableHttp.js"; +import { EventStore, StreamableHTTPServerTransport, EventId, StreamId } from "./streamableHttp.js"; import { McpServer } from "./mcp.js"; import { CallToolResult, JSONRPCMessage } from "../types.js"; import { z } from "zod"; @@ -13,6 +13,7 @@ interface TestServerConfig { sessionIdGenerator?: () => string | undefined; enableJsonResponse?: boolean; customRequestHandler?: (req: IncomingMessage, res: ServerResponse, parsedBody?: unknown) => Promise; + eventStore?: EventStore; } /** @@ -26,7 +27,7 @@ async function createTestServer(config: TestServerConfig = {}): Promise<{ }> { const mcpServer = new McpServer( { name: "test-server", version: "1.0.0" }, - { capabilities: {} } + { capabilities: { logging: {} } } ); mcpServer.tool( @@ -40,7 +41,8 @@ async function createTestServer(config: TestServerConfig = {}): Promise<{ const transport = new StreamableHTTPServerTransport({ sessionIdGenerator: config.sessionIdGenerator ?? (() => randomUUID()), - enableJsonResponse: config.enableJsonResponse ?? false + enableJsonResponse: config.enableJsonResponse ?? false, + eventStore: config.eventStore }); await mcpServer.connect(transport); @@ -89,7 +91,10 @@ const TEST_MESSAGES = { params: { clientInfo: { name: "test-client", version: "1.0" }, protocolVersion: "2025-03-26", + capabilities: { + }, }, + id: "init-1", } as JSONRPCMessage, @@ -896,6 +901,165 @@ describe("StreamableHTTPServerTransport with pre-parsed body", () => { }); }); +// Test resumability support +describe("StreamableHTTPServerTransport with resumability", () => { + let server: Server; + let transport: StreamableHTTPServerTransport; + let baseUrl: URL; + let sessionId: string; + let mcpServer: McpServer; + const storedEvents: Map = new Map(); + + // Simple implementation of EventStore + const eventStore: EventStore = { + + async storeEvent(streamId: string, message: JSONRPCMessage): Promise { + const eventId = `${streamId}_${randomUUID()}`; + storedEvents.set(eventId, { eventId, message }); + return eventId; + }, + + async replayEventsAfter(lastEventId: EventId, { send }: { + send: (eventId: EventId, message: JSONRPCMessage) => Promise + }): Promise { + const streamId = lastEventId.split('_')[0]; + // Extract stream ID from the event ID + // For test simplicity, just return all events with matching streamId that aren't the lastEventId + for (const [eventId, { message }] of storedEvents.entries()) { + if (eventId.startsWith(streamId) && eventId !== lastEventId) { + await send(eventId, message); + } + } + return streamId; + }, + }; + + beforeEach(async () => { + storedEvents.clear(); + const result = await createTestServer({ + sessionIdGenerator: () => randomUUID(), + eventStore + }); + + server = result.server; + transport = result.transport; + baseUrl = result.baseUrl; + mcpServer = result.mcpServer; + + // Verify resumability is enabled on the transport + expect((transport)['_eventStore']).toBeDefined(); + + // Initialize the server + const initResponse = await sendPostRequest(baseUrl, TEST_MESSAGES.initialize); + sessionId = initResponse.headers.get("mcp-session-id") as string; + expect(sessionId).toBeDefined(); + }); + + afterEach(async () => { + await stopTestServer({ server, transport }); + storedEvents.clear(); + }); + + it("should store and include event IDs in server SSE messages", async () => { + // Open a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + "mcp-session-id": sessionId, + }, + }); + + expect(sseResponse.status).toBe(200); + expect(sseResponse.headers.get("content-type")).toBe("text/event-stream"); + + // Send a notification that should be stored with an event ID + const notification: JSONRPCMessage = { + jsonrpc: "2.0", + method: "notifications/message", + params: { level: "info", data: "Test notification with event ID" }, + }; + + // Send the notification via transport + await transport.send(notification); + + // Read from the stream and verify we got the notification with an event ID + const reader = sseResponse.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // The response should contain an event ID + expect(text).toContain('id: '); + expect(text).toContain('"method":"notifications/message"'); + + // Extract the event ID + const idMatch = text.match(/id: ([^\n]+)/); + expect(idMatch).toBeTruthy(); + + // Verify the event was stored + const eventId = idMatch![1]; + expect(storedEvents.has(eventId)).toBe(true); + const storedEvent = storedEvents.get(eventId); + expect(eventId.startsWith('_GET_stream')).toBe(true); + expect(storedEvent?.message).toMatchObject(notification); + }); + + + it("should store and replay MCP server tool notifications", async () => { + // Establish a standalone SSE stream + const sseResponse = await fetch(baseUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + "mcp-session-id": sessionId, + }, + }); + expect(sseResponse.status).toBe(200); // Send a server notification through the MCP server + await mcpServer.server.sendLoggingMessage({ level: "info", data: "First notification from MCP server" }); + + // Read the notification from the SSE stream + const reader = sseResponse.body?.getReader(); + const { value } = await reader!.read(); + const text = new TextDecoder().decode(value); + + // Verify the notification was sent with an event ID + expect(text).toContain('id: '); + expect(text).toContain('First notification from MCP server'); + + // Extract the event ID + const idMatch = text.match(/id: ([^\n]+)/); + expect(idMatch).toBeTruthy(); + const firstEventId = idMatch![1]; + + // Send a second notification + await mcpServer.server.sendLoggingMessage({ level: "info", data: "Second notification from MCP server" }); + + // Close the first SSE stream to simulate a disconnect + await reader!.cancel(); + + // Reconnect with the Last-Event-ID to get missed messages + const reconnectResponse = await fetch(baseUrl, { + method: "GET", + headers: { + Accept: "text/event-stream", + "mcp-session-id": sessionId, + "last-event-id": firstEventId + }, + }); + + expect(reconnectResponse.status).toBe(200); + + // Read the replayed notification + const reconnectReader = reconnectResponse.body?.getReader(); + const reconnectData = await reconnectReader!.read(); + const reconnectText = new TextDecoder().decode(reconnectData.value); + + // Verify we received the second notification that was sent after our stored eventId + expect(reconnectText).toContain('Second notification from MCP server'); + expect(reconnectText).toContain('id: '); + }); +}); + // Test stateless mode describe("StreamableHTTPServerTransport in stateless mode", () => { let server: Server; diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 0eaaa673..7ddfa3ab 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -1,11 +1,32 @@ import { IncomingMessage, ServerResponse } from "node:http"; import { Transport } from "../shared/transport.js"; -import { JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; +import { isJSONRPCRequest, isJSONRPCResponse, JSONRPCMessage, JSONRPCMessageSchema, RequestId } from "../types.js"; import getRawBody from "raw-body"; import contentType from "content-type"; +import { randomUUID } from "node:crypto"; const MAXIMUM_MESSAGE_SIZE = "4mb"; +export type StreamId = string; +export type EventId = string; + +/** + * Interface for resumability support via event storage + */ +export interface EventStore { + /** + * Stores an event for later retrieval + * @param streamId ID of the stream the event belongs to + * @param message The JSON-RPC message to store + * @returns The generated event ID for the stored event + */ + storeEvent(streamId: StreamId, message: JSONRPCMessage): Promise; + + replayEventsAfter(lastEventId: EventId, { send }: { + send: (eventId: EventId, message: JSONRPCMessage) => Promise + }): Promise; +} + /** * Configuration options for StreamableHTTPServerTransport */ @@ -24,6 +45,12 @@ export interface StreamableHTTPServerTransportOptions { * Default is false (SSE streams are preferred). */ enableJsonResponse?: boolean; + + /** + * Event store for resumability support + * If provided, resumability will be enabled, allowing clients to reconnect and resume messages + */ + eventStore?: EventStore; } /** @@ -64,12 +91,13 @@ export class StreamableHTTPServerTransport implements Transport { // when sessionId is not set (undefined), it means the transport is in stateless mode private sessionIdGenerator: () => string | undefined; private _started: boolean = false; - private _responseMapping: Map = new Map(); + private _streamMapping: Map = new Map(); + private _requestToStreamMapping: Map = new Map(); private _requestResponseMap: Map = new Map(); private _initialized: boolean = false; private _enableJsonResponse: boolean = false; - private _standaloneSSE: ServerResponse | undefined; - + private _standaloneSseStreamId: string = '_GET_stream'; + private _eventStore?: EventStore; sessionId?: string | undefined; onclose?: () => void; @@ -79,6 +107,7 @@ export class StreamableHTTPServerTransport implements Transport { constructor(options: StreamableHTTPServerTransportOptions) { this.sessionIdGenerator = options.sessionIdGenerator; this._enableJsonResponse = options.enableJsonResponse ?? false; + this._eventStore = options.eventStore; } /** @@ -131,6 +160,14 @@ export class StreamableHTTPServerTransport implements Transport { if (!this.validateSession(req, res)) { return; } + // Handle resumability: check for Last-Event-ID header + if (this._eventStore) { + const lastEventId = req.headers['last-event-id'] as string | undefined; + if (lastEventId) { + await this.replayEvents(lastEventId, res); + return; + } + } // The server MUST either return Content-Type: text/event-stream in response to this HTTP GET, // or else return HTTP 405 Method Not Allowed @@ -144,12 +181,9 @@ export class StreamableHTTPServerTransport implements Transport { if (this.sessionId !== undefined) { headers["mcp-session-id"] = this.sessionId; } - // The server MAY include a Last-Event-ID header in the response to this HTTP GET. - // Resumability will be supported in the future // Check if there's already an active standalone SSE stream for this session - - if (this._standaloneSSE !== undefined) { + if (this._streamMapping.get(this._standaloneSseStreamId) !== undefined) { // Only one GET SSE stream is allowed per session res.writeHead(409).end(JSON.stringify({ jsonrpc: "2.0", @@ -161,19 +195,68 @@ export class StreamableHTTPServerTransport implements Transport { })); return; } - // We need to send headers immediately as message will arrive much later, + + // We need to send headers immediately as messages will arrive much later, // otherwise the client will just wait for the first message res.writeHead(200, headers).flushHeaders(); - // Assing the response to the standalone SSE stream - this._standaloneSSE = res; + // Assign the response to the standalone SSE stream + this._streamMapping.set(this._standaloneSseStreamId, res); // Set up close handler for client disconnects res.on("close", () => { - this._standaloneSSE = undefined; + this._streamMapping.delete(this._standaloneSseStreamId); }); } + /** + * Replays events that would have been sent after the specified event ID + * Only used when resumability is enabled + */ + private async replayEvents(lastEventId: string, res: ServerResponse): Promise { + if (!this._eventStore) { + return; + } + try { + const headers: Record = { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + }; + + if (this.sessionId !== undefined) { + headers["mcp-session-id"] = this.sessionId; + } + res.writeHead(200, headers).flushHeaders(); + + const streamId = await this._eventStore?.replayEventsAfter(lastEventId, { + send: async (eventId: string, message: JSONRPCMessage) => { + if (!this.writeSSEEvent(res, message, eventId)) { + this.onerror?.(new Error("Failed replay events")); + res.end(); + } + } + }); + this._streamMapping.set(streamId, res); + } catch (error) { + this.onerror?.(error as Error); + } + } + + /** + * Writes an event to the SSE stream with proper formatting + */ + private writeSSEEvent(res: ServerResponse, message: JSONRPCMessage, eventId?: string): boolean { + let eventData = `event: message\n`; + // Include event ID if provided - this is important for resumability + if (eventId) { + eventData += `id: ${eventId}\n`; + } + eventData += `data: ${JSON.stringify(message)}\n\n`; + + return res.write(eventData); + } + /** * Handles unsupported requests (PUT, PATCH, etc.) */ @@ -285,11 +368,9 @@ export class StreamableHTTPServerTransport implements Transport { // check if it contains requests - const hasRequests = messages.some(msg => 'method' in msg && 'id' in msg); - const hasOnlyNotificationsOrResponses = messages.every(msg => - ('method' in msg && !('id' in msg)) || ('result' in msg || 'error' in msg)); + const hasRequests = messages.some(isJSONRPCRequest); - if (hasOnlyNotificationsOrResponses) { + if (!hasRequests) { // if it only contains notifications or responses, return 202 res.writeHead(202).end(); @@ -300,6 +381,7 @@ export class StreamableHTTPServerTransport implements Transport { } else if (hasRequests) { // The default behavior is to use SSE streaming // but in some cases server will return JSON responses + const streamId = randomUUID(); if (!this._enableJsonResponse) { const headers: Record = { "Content-Type": "text/event-stream", @@ -318,19 +400,22 @@ export class StreamableHTTPServerTransport implements Transport { // We need to track by request ID to maintain the connection for (const message of messages) { if ('method' in message && 'id' in message) { - this._responseMapping.set(message.id, res); + this._streamMapping.set(streamId, res); + this._requestToStreamMapping.set(message.id, streamId); } } // Set up close handler for client disconnects res.on("close", () => { + // find a stream ID for this response // Remove all entries that reference this response - for (const [id, storedRes] of this._responseMapping.entries()) { - if (storedRes === res) { - this._responseMapping.delete(id); + for (const [id, stream] of this._requestToStreamMapping.entries()) { + if (streamId === stream) { + this._requestToStreamMapping.delete(id); this._requestResponseMap.delete(id); } } + this._streamMapping.delete(streamId); }); // handle each message @@ -431,16 +516,13 @@ export class StreamableHTTPServerTransport implements Transport { async close(): Promise { // Close all SSE connections - this._responseMapping.forEach((response) => { + this._streamMapping.forEach((response) => { response.end(); }); - this._responseMapping.clear(); + this._streamMapping.clear(); // Clear any pending responses this._requestResponseMap.clear(); - this._standaloneSSE?.end(); - this._standaloneSSE = undefined; - this.onclose?.(); } @@ -459,32 +541,47 @@ export class StreamableHTTPServerTransport implements Transport { if ('result' in message || 'error' in message) { throw new Error("Cannot send a response on a standalone SSE stream unless resuming a previous client request"); } - - if (this._standaloneSSE === undefined) { + const standaloneSse = this._streamMapping.get(this._standaloneSseStreamId) + if (standaloneSse === undefined) { // The spec says the server MAY send messages on the stream, so it's ok to discard if no stream return; } + // Generate and store event ID if event store is provided + let eventId: string | undefined; + if (this._eventStore) { + // Stores the event and gets the generated event ID + eventId = await this._eventStore.storeEvent(this._standaloneSseStreamId, message); + } + // Send the message to the standalone SSE stream - this._standaloneSSE.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`); + this.writeSSEEvent(standaloneSse, message, eventId); return; } // Get the response for this request - const response = this._responseMapping.get(requestId); - if (!response) { + const streamId = this._requestToStreamMapping.get(requestId); + const response = this._streamMapping.get(streamId!); + if (!streamId || !response) { throw new Error(`No connection established for request ID: ${String(requestId)}`); } if (!this._enableJsonResponse) { - response.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`); + // For SSE responses, generate event ID if event store is provided + let eventId: string | undefined; + + if (this._eventStore) { + eventId = await this._eventStore.storeEvent(streamId, message); + } + + // Write the event to the response stream + this.writeSSEEvent(response, message, eventId); } - if ('result' in message || 'error' in message) { - this._requestResponseMap.set(requestId, message); - // Get all request IDs that share the same request response object - const relatedIds = Array.from(this._responseMapping.entries()) - .filter(([_, res]) => res === response) + if (isJSONRPCResponse(message)) { + this._requestResponseMap.set(requestId, message); + const relatedIds = Array.from(this._requestToStreamMapping.entries()) + .filter(([_, streamId]) => this._streamMapping.get(streamId) === response) .map(([id]) => id); // Check if we have responses for all requests using this connection @@ -516,7 +613,7 @@ export class StreamableHTTPServerTransport implements Transport { // Clean up for (const id of relatedIds) { this._requestResponseMap.delete(id); - this._responseMapping.delete(id); + this._requestToStreamMapping.delete(id); } } }