diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java index 62264d9a..d6677482 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxSseServerTransportProvider.java @@ -1,16 +1,11 @@ package io.modelcontextprotocol.server.transport; import java.io.IOException; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerSession; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -100,6 +95,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv private McpServerSession.Factory sessionFactory; + private McpServerSessionListener mcpServerSessionListener; + /** * Map of active client sessions, keyed by session ID. */ @@ -149,6 +146,22 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa */ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebFlux SSE server transport provider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of MCP messages. Must not be null. + * @param baseUrl webflux message base path + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages. This endpoint will be communicated to clients during SSE connection + * setup. Must not be null. + * @param mcpServerSessionListener The listener for handling server session events. + * @throws IllegalArgumentException if either parameter is null + */ + public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, McpServerSessionListener mcpServerSessionListener) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base path must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -162,6 +175,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + this.mcpServerSessionListener = mcpServerSessionListener; } @Override @@ -229,6 +243,10 @@ public Mono closeGracefully() { .then(); } + public void setMcpServerSessionListener(McpServerSessionListener mcpServerSessionListener) { + this.mcpServerSessionListener = mcpServerSessionListener; + } + /** * Returns the WebFlux router function that defines the transport's HTTP endpoints. * This router function should be integrated into the application's web configuration. @@ -277,6 +295,10 @@ private Mono handleSseConnection(ServerRequest request) { logger.debug("Session {} cancelled", sessionId); sessions.remove(sessionId); }); + + if (null != mcpServerSessionListener) { + mcpServerSessionListener.onConnection(session, request); + } }), ServerSentEvent.class); } @@ -310,6 +332,9 @@ private Mono handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.NOT_FOUND) .bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get())); } + if (null != mcpServerSessionListener) { + mcpServerSessionListener.onMessage(session, request); + } return request.bodyToMono(String.class).flatMap(body -> { try { diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java index fc86cfaa..3ccfb9be 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcSseServerTransportProvider.java @@ -12,11 +12,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpServerTransport; -import io.modelcontextprotocol.spec.McpServerTransportProvider; -import io.modelcontextprotocol.spec.McpServerSession; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.util.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -97,6 +93,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi private McpServerSession.Factory sessionFactory; + private McpServerSessionListener mcpServerSessionListener; + /** * Map of active client sessions, keyed by session ID. */ @@ -149,6 +147,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag */ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, String sseEndpoint) { + this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null); + } + + /** + * Constructs a new WebMvcSseServerTransportProvider instance. + * @param objectMapper The ObjectMapper to use for JSON serialization/deserialization + * of messages. + * @param baseUrl The base URL for the message endpoint, used to construct the full + * endpoint URL for clients. + * @param messageEndpoint The endpoint URI where clients should send their JSON-RPC + * messages via HTTP POST. This endpoint will be communicated to clients through the + * SSE connection's initial endpoint event. + * @param mcpServerSessionListener The listener for handling server session events. + * @param sseEndpoint The endpoint URI where clients establish their SSE connections. + * @throws IllegalArgumentException if any parameter is null + */ + public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint, + String sseEndpoint, McpServerSessionListener mcpServerSessionListener) { Assert.notNull(objectMapper, "ObjectMapper must not be null"); Assert.notNull(baseUrl, "Message base URL must not be null"); Assert.notNull(messageEndpoint, "Message endpoint must not be null"); @@ -162,6 +178,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr .GET(this.sseEndpoint, this::handleSseConnection) .POST(this.messageEndpoint, this::handleMessage) .build(); + this.mcpServerSessionListener = mcpServerSessionListener; } @Override @@ -215,6 +232,10 @@ public Mono closeGracefully() { .doOnSuccess(v -> logger.debug("Graceful shutdown completed")); } + public void setMcpServerSessionListener(McpServerSessionListener mcpServerSessionListener) { + this.mcpServerSessionListener = mcpServerSessionListener; + } + /** * Returns the RouterFunction that defines the HTTP endpoints for this transport. The * router function handles two endpoints: @@ -275,6 +296,10 @@ private ServerResponse handleSseConnection(ServerRequest request) { logger.error("Failed to send initial endpoint event: {}", e.getMessage()); sseBuilder.error(e); } + + if (null != mcpServerSessionListener) { + mcpServerSessionListener.onConnection(session, request); + } }, Duration.ZERO); } catch (Exception e) { @@ -311,6 +336,10 @@ private ServerResponse handleMessage(ServerRequest request) { return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId)); } + if (null != mcpServerSessionListener) { + mcpServerSessionListener.onConnection(session, request); + } + try { String body = request.body(String.class); McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body); @@ -415,6 +444,11 @@ public void close() { } } + @Override + public String getSessionId() { + return sessionId; + } + } } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index 1efa13de..210f113d 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -183,9 +183,12 @@ public class McpAsyncServer { notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED, asyncRootsListChangedNotificationHandler(rootsChangeConsumers)); - mcpTransportProvider.setSessionFactory( - transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport, - this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers)); + mcpTransportProvider.setSessionFactory(transport -> { + // If the sessionId is not provided, generate a random one. + String sessionId = Optional.ofNullable(transport.getSessionId()).orElse(UUID.randomUUID().toString()); + return new McpServerSession(sessionId, requestTimeout, transport, this::asyncInitializeRequestHandler, + Mono::empty, requestHandlers, notificationHandlers); + }); } // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java index 889dc66d..6f30dad7 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServerExchange.java @@ -50,6 +50,14 @@ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabili this.clientInfo = clientInfo; } + /** + * Get the session id. + * @return The session id + */ + public String getSessionId() { + return this.session.getId(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java index 52360e54..8136e320 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpSyncServerExchange.java @@ -28,6 +28,14 @@ public McpSyncServerExchange(McpAsyncServerExchange exchange) { this.exchange = exchange; } + /** + * Get the session id. + * @return The session id + */ + public String getSessionId() { + return this.exchange.getSessionId(); + } + /** * Get the client capabilities that define the supported features and functionality. * @return The client capabilities diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionListener.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionListener.java new file mode 100644 index 00000000..b7164942 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSessionListener.java @@ -0,0 +1,25 @@ +package io.modelcontextprotocol.spec; + +/** + * Listener for McpServerSession events. + * + * @param The type of request object used in the session. + * @author Allen Hu + */ +public interface McpServerSessionListener { + + /** + * Called when a new session is connected. + * @param session The session that was connected. + * @param request The request object used in the session. + */ + void onConnection(McpServerSession session, T request); + + /** + * Called when a message is received. + * @param session The session that received the message. + * @param request The request object used in the session. + */ + void onMessage(McpServerSession session, T request); + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java index 632b8cee..52a96cef 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerTransport.java @@ -8,4 +8,12 @@ */ public interface McpServerTransport extends McpTransport { + /** + * Returns the session id. + * @return the session id. default is null. + */ + default String getSessionId() { + return null; + } + }