From 0576a88a0b1cb096b9085d84be3e2e6ed2fde100 Mon Sep 17 00:00:00 2001 From: Zachary German Date: Mon, 2 Jun 2025 20:39:12 +0000 Subject: [PATCH] Adding StreamableHttpServerTransportProvider class and unit tests --- ...StreamableHttpServerTransportProvider.java | 765 ++++++++++++++++++ ...mableHttpServerTransportProviderTests.java | 397 +++++++++ 2 files changed, 1162 insertions(+) create mode 100644 mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java create mode 100644 mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java new file mode 100644 index 00000000..44d6ee2a --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProvider.java @@ -0,0 +1,765 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Enumeration; +import java.util.function.Supplier; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.util.Assert; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.annotation.WebServlet; +import jakarta.servlet.http.HttpServlet; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +/** + * Implementation of the MCP Streamable HTTP transport provider for servers. This + * implementation follows the Streamable HTTP transport specification from protocol + * version 2025-03-26. + * + *

+ * The transport handles a single HTTP endpoint that supports POST, GET, & DELETE methods: + *

+ * + *

+ * Features: + *

+ * + */ +@WebServlet(asyncSupported = true) +public class StreamableHttpServerTransportProvider extends HttpServlet implements McpServerTransportProvider { + + private static final Logger logger = LoggerFactory.getLogger(StreamableHttpServerTransportProvider.class); + + public static final String UTF_8 = "UTF-8"; + + public static final String APPLICATION_JSON = "application/json"; + + public static final String TEXT_EVENT_STREAM = "text/event-stream"; + + public static final String SESSION_ID_HEADER = "Mcp-Session-Id"; + + public static final String LAST_EVENT_ID_HEADER = "Last-Event-Id"; + + public static final String MESSAGE_EVENT_TYPE = "message"; + + public static final String ACCEPT_HEADER = "Accept"; + + public static final String ORIGIN_HEADER = "Origin"; + + public static final String CACHE_CONTROL_HEADER = "Cache-Control"; + + public static final String CONNECTION_HEADER = "Connection"; + + public static final String CACHE_CONTROL_NO_CACHE = "no-cache"; + + public static final String CONNECTION_KEEP_ALIVE = "keep-alive"; + + /** JSON object mapper for serialization/deserialization */ + private final ObjectMapper objectMapper; + + /** The endpoint path for handling MCP requests */ + private final String mcpEndpoint; + + /** Supplier for generating unique session IDs */ + private final Supplier sessionIdProvider; + + /** UUID.randomUUID().toString() */ + private static final Supplier DEFAULT_SESSION_ID_PROVIDER = () -> UUID.randomUUID().toString(); + + /** Map of active client sessions, keyed by session ID */ + private final Map sessions = new ConcurrentHashMap<>(); + + /** Map of active SSE streams, keyed by session ID */ + private final Map sseStreams = new ConcurrentHashMap<>(); + + /** Flag indicating if the transport is in the process of shutting down */ + private final AtomicBoolean isClosing = new AtomicBoolean(false); + + /** Session factory for creating new sessions */ + private McpServerSession.Factory sessionFactory; + + /** + * Creates a new StreamableHttpServerTransportProvider instance. + * @param objectMapper The JSON object mapper to use for message + * serialization/deserialization + * @param mcpEndpoint The endpoint path for handling MCP requests + * @param sessionIdProvider optional Supplier for providing unique session IDs + */ + public StreamableHttpServerTransportProvider(ObjectMapper objectMapper, String mcpEndpoint, + Supplier sessionIdProvider) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + Assert.hasText(mcpEndpoint, "MCP endpoint must not be empty"); + + this.objectMapper = objectMapper; + this.mcpEndpoint = mcpEndpoint; + this.sessionIdProvider = Objects.requireNonNullElse(sessionIdProvider, DEFAULT_SESSION_ID_PROVIDER); + } + + @Override + public void setSessionFactory(McpServerSession.Factory sessionFactory) { + this.sessionFactory = sessionFactory; + } + + @Override + public Mono notifyClients(String method, Object params) { + if (sessions.isEmpty()) { + logger.debug("No active sessions to broadcast message to"); + return Mono.empty(); + } + + logger.debug("Attempting to broadcast message to {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()) + .flatMap(session -> session.sendNotification(method, params) + .doOnError( + e -> logger.error("Failed to send message to session {}: {}", session.getId(), e.getMessage())) + .onErrorComplete()) + .then(); + } + + @Override + public Mono closeGracefully() { + isClosing.set(true); + logger.debug("Initiating graceful shutdown with {} active sessions", sessions.size()); + + return Flux.fromIterable(sessions.values()).flatMap(McpServerSession::closeGracefully).then(); + } + + /** + * Handles HTTP GET requests to establish SSE connections. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doGet(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.debug("GET request received for URI: {}", requestURI); + + // Log all headers for debugging + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + logger.debug("Header: {} = {}", headerName, request.getHeader(headerName)); + } + + if (!requestURI.endsWith(mcpEndpoint)) { + logger.debug("URI does not match mcpEndpoint: {}", mcpEndpoint); + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing.get()) { + logger.debug("Server is shutting down, rejecting request"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + String acceptHeader = request.getHeader(ACCEPT_HEADER); + logger.debug("Accept header: {}", acceptHeader); + if (acceptHeader == null || !acceptHeader.contains(TEXT_EVENT_STREAM)) { + logger.debug("Accept header missing or does not include {}", TEXT_EVENT_STREAM); + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson("Accept header must include text/event-stream")); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson("Session ID missing in request header")); + return; + } + + McpServerSession session = sessions.get(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + response.getWriter().write(createErrorJson("Session not found: " + sessionId)); + return; + } + + // Set up SSE connection + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + response.setHeader(SESSION_ID_HEADER, sessionId); + + // Start async processing + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); // No timeout + + // Check for Last-Event-ID header for resumable streams + String lastEventId = request.getHeader(LAST_EVENT_ID_HEADER); + + // Create or get SSE stream for this session + StreamableHttpSseStream sseStream = getOrCreateSseStream(sessionId); + if (lastEventId != null) { + sseStream.replayEventsAfter(lastEventId); + } + + PrintWriter writer = response.getWriter(); + + // Subscribe to the SSE stream and write events to the response + sseStream.getEventFlux().doOnNext(event -> { + try { + if (event.id() != null) { + writer.write("id: " + event.id() + "\n"); + } + if (event.event() != null) { + writer.write("event: " + event.event() + "\n"); + } + writer.write("data: " + event.data() + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + catch (IOException e) { + logger.debug("Error writing to SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnComplete(() -> { + try { + writer.close(); + } + finally { + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in SSE stream: {}", e.getMessage()); + asyncContext.complete(); + }).subscribe(); + } + + /** + * Handles HTTP POST requests for client messages. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doPost(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + logger.debug("POST request received for URI: {}", requestURI); + + // Log all headers for debugging + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + logger.debug("Header: {} = {}", headerName, request.getHeader(headerName)); + } + + if (!requestURI.endsWith(mcpEndpoint)) { + logger.debug("URI does not match mcpEndpoint: {}", mcpEndpoint); + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + if (isClosing.get()) { + logger.debug("Server is shutting down, rejecting request"); + response.sendError(HttpServletResponse.SC_SERVICE_UNAVAILABLE, "Server is shutting down"); + return; + } + + // According to spec, client MUST include an Accept header listing both + // application/json and text/event-stream + String acceptHeader = request.getHeader(ACCEPT_HEADER); + logger.debug("Accept header: {}", acceptHeader); + if (acceptHeader == null + || (!acceptHeader.contains(APPLICATION_JSON) || !acceptHeader.contains(TEXT_EVENT_STREAM))) { + logger.debug("Accept header validation failed. Header: {}", acceptHeader); + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter() + .write(createErrorJson("Accept header must include both application/json and text/event-stream")); + return; + } + + // Client accepts SSE since we've validated the Accept header contains + // text/event-stream + boolean acceptsEventStream = true; + + // Get session ID from header + String sessionId = request.getHeader(SESSION_ID_HEADER); + boolean isInitializeRequest = false; + + try { + // Read request body + StringBuilder body = new StringBuilder(); + try (BufferedReader reader = request.getReader()) { + String line; + while ((line = reader.readLine()) != null) { + body.append(line); + } + } + + // Parse the JSON-RPC message + JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body.toString()); + + // Check if this is an initialize request + if (message instanceof McpSchema.JSONRPCRequest req && McpSchema.METHOD_INITIALIZE.equals(req.method())) { + isInitializeRequest = true; + // For initialize requests, create a new session if one doesn't exist + if (sessionId == null) { + sessionId = sessionIdProvider.get(); + logger.debug("Created new session ID for initialize request: {}", sessionId); + } + } + + // Validate session ID for non-initialize requests + if (!isInitializeRequest && sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson("Session ID missing in request header")); + return; + } + + // Get or create session + McpServerSession session = getOrCreateSession(sessionId, isInitializeRequest); + if (session == null && !isInitializeRequest) { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + response.getWriter().write(createErrorJson("Session not found: " + sessionId)); + return; + } + + // Handle the message + session.handle(message).block(); // Block for servlet compatibility + + // Set session ID header in response + response.setHeader(SESSION_ID_HEADER, sessionId); + + // For requests that expect responses, we need to set up an SSE stream + if (message instanceof McpSchema.JSONRPCRequest && acceptsEventStream) { + // Set up SSE connection + response.setContentType(TEXT_EVENT_STREAM); + response.setCharacterEncoding(UTF_8); + response.setHeader(CACHE_CONTROL_HEADER, CACHE_CONTROL_NO_CACHE); + response.setHeader(CONNECTION_HEADER, CONNECTION_KEEP_ALIVE); + + // Start async processing + AsyncContext asyncContext = request.startAsync(); + asyncContext.setTimeout(0); // No timeout + + StreamableHttpSseStream sseStream = getOrCreateSseStream(sessionId); + PrintWriter writer = response.getWriter(); + + // For initialize requests, include the session ID in the response + if (isInitializeRequest) { + response.setHeader(SESSION_ID_HEADER, sessionId); + } + + // Subscribe to the SSE stream and write events to the response + sseStream.getEventFlux().doOnNext(event -> { + try { + if (event.id() != null) { + writer.write("id: " + event.id() + "\n"); + } + if (event.event() != null) { + writer.write("event: " + event.event() + "\n"); + } + writer.write("data: " + event.data() + "\n\n"); + writer.flush(); + + if (writer.checkError()) { + throw new IOException("Client disconnected"); + } + } + catch (IOException e) { + logger.debug("Error writing to SSE stream: {}", e.getMessage()); + asyncContext.complete(); + } + }).doOnComplete(() -> { + try { + writer.close(); + } + finally { + asyncContext.complete(); + } + }).doOnError(e -> { + logger.error("Error in SSE stream: {}", e.getMessage()); + asyncContext.complete(); + }).subscribe(); + } + else if (message instanceof McpSchema.JSONRPCRequest) { + // Client doesn't accept SSE, we'll return a regular JSON response + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_OK); + // The actual response would be sent later through another channel + } + else { + // For notifications and responses, return 202 Accepted + response.setStatus(HttpServletResponse.SC_ACCEPTED); + } + } + catch (Exception e) { + logger.error("Error processing message: {}", e.getMessage()); + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson("Invalid JSON-RPC message: " + e.getMessage())); + } + } + + /** + * Handles HTTP DELETE requests to terminate sessions. + * @param request The HTTP servlet request + * @param response The HTTP servlet response + * @throws ServletException If a servlet-specific error occurs + * @throws IOException If an I/O error occurs + */ + @Override + protected void doDelete(HttpServletRequest request, HttpServletResponse response) + throws ServletException, IOException { + + String requestURI = request.getRequestURI(); + if (!requestURI.endsWith(mcpEndpoint)) { + response.sendError(HttpServletResponse.SC_NOT_FOUND); + return; + } + + String sessionId = request.getHeader(SESSION_ID_HEADER); + if (sessionId == null) { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_BAD_REQUEST); + response.getWriter().write(createErrorJson("Session ID missing in request header")); + return; + } + + McpServerSession session = sessions.remove(sessionId); + if (session == null) { + response.setContentType(APPLICATION_JSON); + response.setStatus(HttpServletResponse.SC_NOT_FOUND); + response.getWriter().write(createErrorJson("Session not found: " + sessionId)); + return; + } + + // Close the session and any associated SSE stream + StreamableHttpSseStream sseStream = sseStreams.remove(sessionId); + if (sseStream != null) { + sseStream.complete(); + } + + session.close(); + logger.debug("Session terminated: {}", sessionId); + + response.setStatus(HttpServletResponse.SC_OK); + } + + /** + * Gets or creates a session for the given session ID. + * @param sessionId The session ID + * @param createIfMissing Whether to create a new session if one doesn't exist + * @return The session, or null if it doesn't exist and createIfMissing is false + */ + private McpServerSession getOrCreateSession(String sessionId, boolean createIfMissing) { + McpServerSession session = sessions.get(sessionId); + if (session == null && createIfMissing) { + StreamableHttpServerTransport transport = new StreamableHttpServerTransport(sessionId); + session = sessionFactory.create(transport); + sessions.put(sessionId, session); + logger.debug("Created new session: {}", sessionId); + } + return session; + } + + /** + * Gets or creates an SSE stream for the given session ID. + * @param sessionId The session ID + * @return The SSE stream + */ + private StreamableHttpSseStream getOrCreateSseStream(String sessionId) { + return sseStreams.computeIfAbsent(sessionId, id -> { + StreamableHttpSseStream stream = new StreamableHttpSseStream(); + logger.debug("Created new SSE stream for session: {}", id); + return stream; + }); + } + + /** + * Creates a JSON error response. + * @param message The error message + * @return The JSON error string + */ + private String createErrorJson(String message) { + try { + return objectMapper.writeValueAsString(new McpError(message)); + } + catch (IOException e) { + logger.error("Failed to serialize error message", e); + return "{\"error\":\"" + message + "\"}"; + } + } + + /** + * Implementation of McpServerTransport for Streamable HTTP sessions. + */ + private class StreamableHttpServerTransport implements McpServerTransport { + + private final String sessionId; + + /** + * Creates a new session transport with the specified ID. + * @param sessionId The unique identifier for this session + */ + StreamableHttpServerTransport(String sessionId) { + this.sessionId = sessionId; + logger.debug("Session transport {} initialized", sessionId); + } + + @Override + public Mono sendMessage(JSONRPCMessage message) { + StreamableHttpSseStream sseStream = sseStreams.get(sessionId); + if (sseStream == null) { + logger.debug("No SSE stream available for session {}, message will be queued for next connection", + sessionId); + // Create a stream that will hold messages until a client connects + sseStream = getOrCreateSseStream(sessionId); + } + + try { + String jsonText = objectMapper.writeValueAsString(message); + sseStream.sendEvent(MESSAGE_EVENT_TYPE, jsonText); + logger.debug("Message sent to session {}", sessionId); + + // For responses to requests, we need to complete the stream to avoid + // hanging + if (message instanceof McpSchema.JSONRPCResponse) { + logger.debug("Completing SSE stream after sending response for session {}", sessionId); + sseStream.complete(); + } + + return Mono.empty(); + } + catch (Exception e) { + logger.error("Failed to send message to session {}: {}", sessionId, e.getMessage()); + return Mono.error(e); + } + } + + @Override + public T unmarshalFrom(Object data, TypeReference typeRef) { + return objectMapper.convertValue(data, typeRef); + } + + @Override + public Mono closeGracefully() { + return Mono.fromRunnable(() -> { + logger.debug("Closing session transport: {}", sessionId); + sessions.remove(sessionId); + StreamableHttpSseStream sseStream = sseStreams.remove(sessionId); + if (sseStream != null) { + sseStream.complete(); + } + }); + } + + } + + /** + * Represents an SSE stream for a client connection. + */ + public class StreamableHttpSseStream { + + private final Sinks.Many eventSink = Sinks.many().multicast().onBackpressureBuffer(); + + private final Map eventHistory = new ConcurrentHashMap<>(); + + private long eventCounter = 0; + + /** + * Sends an event on this SSE stream. + * @param eventType The event type + * @param data The event data + */ + public void sendEvent(String eventType, String data) { + String eventId = String.valueOf(++eventCounter); + SseEvent event = new SseEvent(eventId, eventType, data); + eventHistory.put(eventId, event); + eventSink.tryEmitNext(event); + } + + /** + * Gets the Flux of SSE events for this stream. + * @return The Flux of SSE events + */ + public Flux getEventFlux() { + return eventSink.asFlux(); + } + + /** + * Replays events that occurred after the specified event ID. + * @param lastEventId The last event ID received by the client + */ + public void replayEventsAfter(String lastEventId) { + try { + long lastId = Long.parseLong(lastEventId); + for (long i = lastId + 1; i <= eventCounter; i++) { + SseEvent event = eventHistory.get(String.valueOf(i)); + if (event != null) { + eventSink.tryEmitNext(event); + } + } + } + catch (NumberFormatException e) { + logger.warn("Invalid last event ID: {}", lastEventId); + } + } + + /** + * Completes this SSE stream. + */ + public void complete() { + eventSink.tryEmitComplete(); + } + + } + + /** + * Represents an SSE event. + */ + public record SseEvent(String id, String event, String data) { + } + + /** + * Cleans up resources when the servlet is being destroyed. + */ + @Override + public void destroy() { + closeGracefully().block(); + super.destroy(); + } + + /** + * Helper method to extract headers from an HTTP request. + * @param request The HTTP servlet request + * @return A map of header names to values + */ + private Map extractHeaders(HttpServletRequest request) { + Map headers = new HashMap<>(); + Enumeration headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String name = headerNames.nextElement(); + headers.put(name, request.getHeader(name)); + } + return headers; + } + + /** + * Creates a new Builder instance for configuring and creating instances of + * StreamableHttpServerTransportProvider. + * @return A new Builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating instances of StreamableHttpServerTransportProvider. + */ + public static class Builder { + + private ObjectMapper objectMapper = new ObjectMapper(); + + private String mcpEndpoint; + + private Supplier sessionIdProvider = () -> UUID.randomUUID().toString(); + + /** + * Sets the JSON object mapper to use for message serialization/deserialization. + * @param objectMapper The object mapper to use + * @return This builder instance for method chaining + */ + public Builder withObjectMapper(ObjectMapper objectMapper) { + Assert.notNull(objectMapper, "ObjectMapper must not be null"); + this.objectMapper = objectMapper; + return this; + } + + /** + * Sets the MCP endpoint path. + * @param mcpEndpoint The MCP endpoint path + * @return This builder instance for method chaining + */ + public Builder withMcpEndpoint(String mcpEndpoint) { + Assert.hasText(mcpEndpoint, "MCP endpoint must not be empty"); + this.mcpEndpoint = mcpEndpoint; + return this; + } + + /** + * Sets the session ID provider. + * @param sessionIdProvider The supplier for generating session IDs + * @return This builder instance for method chaining + */ + public Builder withSessionIdProvider(Supplier sessionIdProvider) { + Assert.notNull(sessionIdProvider, "SessionIdProvider must not be null"); + this.sessionIdProvider = sessionIdProvider; + return this; + } + + /** + * Builds a new instance of StreamableHttpServerTransportProvider with the + * configured settings. + * @return A new StreamableHttpServerTransportProvider instance + * @throws IllegalStateException if objectMapper or mcpEndpoint is not set + */ + public StreamableHttpServerTransportProvider build() { + if (objectMapper == null) { + throw new IllegalStateException("ObjectMapper must be set"); + } + if (mcpEndpoint == null) { + throw new IllegalStateException("MCP endpoint must be set"); + } + return new StreamableHttpServerTransportProvider(objectMapper, mcpEndpoint, sessionIdProvider); + } + + } + +} \ No newline at end of file diff --git a/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java new file mode 100644 index 00000000..b0cd3f17 --- /dev/null +++ b/mcp/src/test/java/io/modelcontextprotocol/server/transport/StreamableHttpServerTransportProviderTests.java @@ -0,0 +1,397 @@ +/* + * Copyright 2024-2024 the original author or authors. + */ + +package io.modelcontextprotocol.server.transport; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import com.fasterxml.jackson.databind.ObjectMapper; +import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; +import io.modelcontextprotocol.spec.McpSchema.JSONRPCRequest; +import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.McpServerTransport; +import jakarta.servlet.AsyncContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link StreamableHttpServerTransportProvider}. + */ +class StreamableHttpServerTransportProviderTests { + + private StreamableHttpServerTransportProvider transportProvider; + + private ObjectMapper objectMapper; + + private McpServerSession.Factory sessionFactory; + + private McpServerSession mockSession; + + private McpServerTransport capturedTransport; + + @BeforeEach + void setUp() { + objectMapper = new ObjectMapper(); + + mockSession = mock(McpServerSession.class); + sessionFactory = mock(McpServerSession.Factory.class); + + when(sessionFactory.create(any(McpServerTransport.class))).thenAnswer(invocation -> { + capturedTransport = invocation.getArgument(0); + return mockSession; + }); + when(mockSession.closeGracefully()).thenReturn(Mono.empty()); + when(mockSession.sendNotification(any(), any())).thenReturn(Mono.empty()); + when(mockSession.handle(any(JSONRPCMessage.class))).thenReturn(Mono.empty()); + when(mockSession.getId()).thenReturn("test-session-id"); + + transportProvider = new StreamableHttpServerTransportProvider(objectMapper, "/mcp", null); + transportProvider.setSessionFactory(sessionFactory); + } + + @Test + void shouldNotifyClients() { + String sessionId = UUID.randomUUID().toString(); + Map sessions = new ConcurrentHashMap<>(); + sessions.put(sessionId, mockSession); + + // Use reflection to set the sessions map in the transport provider + try { + java.lang.reflect.Field sessionsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sessions"); + sessionsField.setAccessible(true); + sessionsField.set(transportProvider, sessions); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sessions field", e); + } + + String method = "testNotification"; + Map params = Map.of("key", "value"); + StepVerifier.create(transportProvider.notifyClients(method, params)).verifyComplete(); + + verify(mockSession).sendNotification(eq(method), eq(params)); + } + + @Test + void shouldCloseGracefully() { + String sessionId = UUID.randomUUID().toString(); + Map sessions = new ConcurrentHashMap<>(); + sessions.put(sessionId, mockSession); + + // Use reflection to set the sessions map in the transport provider + try { + java.lang.reflect.Field sessionsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sessions"); + sessionsField.setAccessible(true); + sessionsField.set(transportProvider, sessions); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sessions field", e); + } + + StepVerifier.create(transportProvider.closeGracefully()).verifyComplete(); + + verify(mockSession).closeGracefully(); + } + + @Test + void shouldHandlePostRequestForInitialize() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter writer = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("application/json, text/event-stream"); + when(request.getHeader(StreamableHttpServerTransportProvider.SESSION_ID_HEADER)).thenReturn(null); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + String initializeRequest = "{\"jsonrpc\":\"2.0\",\"method\":\"initialize\",\"params\":{\"protocolVersion\":\"2024-11-05\",\"capabilities\":{},\"clientInfo\":{\"name\":\"test-client\",\"version\":\"1.0.0\"}},\"id\":1}"; + when(request.getReader()).thenReturn(new java.io.BufferedReader(new java.io.StringReader(initializeRequest))); + when(response.getWriter()).thenReturn(writer); + AsyncContext asyncContext = mock(AsyncContext.class); + when(request.startAsync()).thenReturn(asyncContext); + + transportProvider.doPost(request, response); + + verify(sessionFactory).create(any(McpServerTransport.class)); + ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(JSONRPCMessage.class); + verify(mockSession).handle(messageCaptor.capture()); + JSONRPCMessage capturedMessage = messageCaptor.getValue(); + assertThat(capturedMessage).isInstanceOf(JSONRPCRequest.class); + JSONRPCRequest capturedRequest = (JSONRPCRequest) capturedMessage; + assertThat(capturedRequest.method()).isEqualTo(McpSchema.METHOD_INITIALIZE); + verify(response, atLeastOnce()).setHeader(eq(StreamableHttpServerTransportProvider.SESSION_ID_HEADER), + anyString()); + } + + @Test + void shouldHandlePostRequestWithExistingSession() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + String sessionId = UUID.randomUUID().toString(); + PrintWriter writer = new PrintWriter(stringWriter); + Map sessions = new HashMap<>(); + sessions.put(sessionId, mockSession); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("application/json, text/event-stream"); + when(request.getHeader(StreamableHttpServerTransportProvider.SESSION_ID_HEADER)).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + String toolCallRequest = "{\"jsonrpc\":\"2.0\",\"method\":\"tools/call\",\"params\":{\"name\":\"test-tool\",\"arguments\":{}},\"id\":2}"; + when(request.getReader()).thenReturn(new java.io.BufferedReader(new java.io.StringReader(toolCallRequest))); + when(response.getWriter()).thenReturn(writer); + + // Use reflection to set the sessions map in the transport provider + try { + java.lang.reflect.Field sessionsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sessions"); + sessionsField.setAccessible(true); + sessionsField.set(transportProvider, sessions); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sessions field", e); + } + + transportProvider.doPost(request, response); + + ArgumentCaptor messageCaptor = ArgumentCaptor.forClass(JSONRPCMessage.class); + verify(mockSession).handle(messageCaptor.capture()); + JSONRPCMessage capturedMessage = messageCaptor.getValue(); + assertThat(capturedMessage).isInstanceOf(JSONRPCRequest.class); + JSONRPCRequest capturedRequest = (JSONRPCRequest) capturedMessage; + assertThat(capturedRequest.method()).isEqualTo(McpSchema.METHOD_TOOLS_CALL); + verify(response).setHeader(eq(StreamableHttpServerTransportProvider.SESSION_ID_HEADER), eq(sessionId)); + } + + @Test + void shouldHandleGetRequest() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + String sessionId = UUID.randomUUID().toString(); + AsyncContext asyncContext = mock(AsyncContext.class); + PrintWriter writer = new PrintWriter(stringWriter); + Map sessions = new HashMap<>(); + sessions.put(sessionId, mockSession); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader(StreamableHttpServerTransportProvider.SESSION_ID_HEADER)).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(request.startAsync()).thenReturn(asyncContext); + when(response.getWriter()).thenReturn(writer); + + // Use reflection to set the sessions map in the transport provider + try { + java.lang.reflect.Field sessionsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sessions"); + sessionsField.setAccessible(true); + sessionsField.set(transportProvider, sessions); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sessions field", e); + } + + transportProvider.doGet(request, response); + + verify(response).setContentType(eq(StreamableHttpServerTransportProvider.TEXT_EVENT_STREAM)); + verify(response).setCharacterEncoding(eq(StreamableHttpServerTransportProvider.UTF_8)); + verify(response).setHeader(eq("Cache-Control"), eq("no-cache")); + verify(response).setHeader(eq("Connection"), eq("keep-alive")); + verify(response).setHeader(eq(StreamableHttpServerTransportProvider.SESSION_ID_HEADER), eq(sessionId)); + verify(request).startAsync(); + verify(asyncContext).setTimeout(0); + } + + @Test + void shouldHandleDeleteRequest() throws IOException, ServletException { + // Mock HTTP request and response + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter writer = new PrintWriter(stringWriter); + String sessionId = UUID.randomUUID().toString(); + Map sessions = new HashMap<>(); + sessions.put(sessionId, mockSession); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader(StreamableHttpServerTransportProvider.SESSION_ID_HEADER)).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(writer); + + // Use reflection to set the sessions map in the transport provider + try { + java.lang.reflect.Field sessionsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sessions"); + sessionsField.setAccessible(true); + sessionsField.set(transportProvider, sessions); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sessions field", e); + } + + transportProvider.doDelete(request, response); + + verify(mockSession).close(); + verify(response).setStatus(HttpServletResponse.SC_OK); + assertThat(sessions).isEmpty(); + } + + @Test + void shouldSendMessageThroughTransport() throws Exception { + String sessionId = UUID.randomUUID().toString(); + Map sessions = new HashMap<>(); + sessions.put(sessionId, mockSession); + + // Use reflection to set the sessions map in the transport provider + try { + java.lang.reflect.Field sessionsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sessions"); + sessionsField.setAccessible(true); + sessionsField.set(transportProvider, sessions); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sessions field", e); + } + + // Create a message to send through a mocked SSE stream + JSONRPCMessage message = new McpSchema.JSONRPCResponse("2.0", 1, Map.of("protocolVersion", + McpSchema.LATEST_PROTOCOL_VERSION, "serverInfo", Map.of("name", "test-server", "version", "1.0.0")), + null); + + AtomicReference capturedEventData = new AtomicReference<>(); + + StreamableHttpServerTransportProvider.StreamableHttpSseStream mockSseStream = mock( + StreamableHttpServerTransportProvider.StreamableHttpSseStream.class); + doAnswer(invocation -> { + String eventType = invocation.getArgument(0); + String data = invocation.getArgument(1); + assertThat(eventType).isEqualTo(StreamableHttpServerTransportProvider.MESSAGE_EVENT_TYPE); + capturedEventData.set(data); + return null; + }).when(mockSseStream).sendEvent(anyString(), anyString()); + + Map sseStreams = new HashMap<>(); + sseStreams.put(sessionId, mockSseStream); + try { + java.lang.reflect.Field sseStreamsField = StreamableHttpServerTransportProvider.class + .getDeclaredField("sseStreams"); + sseStreamsField.setAccessible(true); + sseStreamsField.set(transportProvider, sseStreams); + } + catch (Exception e) { + throw new RuntimeException("Failed to set sseStreams field", e); + } + + // Using reflection to access the private constructor + McpServerTransport transport; + try { + Class transportClass = Class.forName( + "io.modelcontextprotocol.server.transport.StreamableHttpServerTransportProvider$StreamableHttpServerTransport"); + java.lang.reflect.Constructor constructor = transportClass + .getDeclaredConstructor(StreamableHttpServerTransportProvider.class, String.class); + constructor.setAccessible(true); + transport = (McpServerTransport) constructor.newInstance(transportProvider, sessionId); + } + catch (Exception e) { + throw new RuntimeException("Failed to create transport", e); + } + + StepVerifier.create(transport.sendMessage(message)).verifyComplete(); + verify(mockSseStream, times(1)).sendEvent(eq(StreamableHttpServerTransportProvider.MESSAGE_EVENT_TYPE), + anyString()); + + String eventData = capturedEventData.get(); + assertThat(eventData).isNotNull(); + } + + @Test + void shouldHandleInvalidRequestURI() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + + when(request.getRequestURI()).thenReturn("/wrong-path"); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + + transportProvider.doGet(request, response); + transportProvider.doPost(request, response); + transportProvider.doDelete(request, response); + + verify(response, times(3)).sendError(HttpServletResponse.SC_NOT_FOUND); + } + + @Test + void shouldHandleMissingSessionId() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter writer = new PrintWriter(stringWriter); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader(StreamableHttpServerTransportProvider.SESSION_ID_HEADER)).thenReturn(null); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(writer); + + // Execute GET request without Session ID (required) + transportProvider.doGet(request, response); + + verify(response).setStatus(HttpServletResponse.SC_BAD_REQUEST); + verify(response).setContentType(eq(StreamableHttpServerTransportProvider.APPLICATION_JSON)); + assertThat(stringWriter.toString()).contains("Session ID missing"); + } + + @Test + void shouldHandleSessionNotFound() throws IOException, ServletException { + HttpServletRequest request = mock(HttpServletRequest.class); + HttpServletResponse response = mock(HttpServletResponse.class); + StringWriter stringWriter = new StringWriter(); + PrintWriter writer = new PrintWriter(stringWriter); + String sessionId = UUID.randomUUID().toString(); + + when(request.getRequestURI()).thenReturn("/mcp"); + when(request.getHeader("Accept")).thenReturn("text/event-stream"); + when(request.getHeader(StreamableHttpServerTransportProvider.SESSION_ID_HEADER)).thenReturn(sessionId); + when(request.getHeaderNames()).thenReturn(Collections.enumeration(Collections.emptyList())); + when(response.getWriter()).thenReturn(writer); + + // Execute GET request with non-existent session ID + transportProvider.doGet(request, response); + + verify(response).setStatus(HttpServletResponse.SC_NOT_FOUND); + verify(response).setContentType(eq(StreamableHttpServerTransportProvider.APPLICATION_JSON)); + assertThat(stringWriter.toString()).contains("Session not found"); + } + +} \ No newline at end of file