Skip to content

feat: supported McpServerSessionListener on SseSession events. #233

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
package io.modelcontextprotocol.server.transport;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.*;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -100,6 +95,8 @@ public class WebFluxSseServerTransportProvider implements McpServerTransportProv

private McpServerSession.Factory sessionFactory;

private McpServerSessionListener<ServerRequest> mcpServerSessionListener;

/**
* Map of active client sessions, keyed by session ID.
*/
Expand Down Expand Up @@ -149,6 +146,22 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String messa
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebFlux SSE server transport provider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of MCP messages. Must not be null.
* @param baseUrl webflux message base path
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages. This endpoint will be communicated to clients during SSE connection
* setup. Must not be null.
* @param mcpServerSessionListener The listener for handling server session events.
* @throws IllegalArgumentException if either parameter is null
*/
public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, McpServerSessionListener<ServerRequest> mcpServerSessionListener) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base path must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -162,6 +175,7 @@ public WebFluxSseServerTransportProvider(ObjectMapper objectMapper, String baseU
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.build();
this.mcpServerSessionListener = mcpServerSessionListener;
}

@Override
Expand Down Expand Up @@ -229,6 +243,10 @@ public Mono<Void> closeGracefully() {
.then();
}

public void setMcpServerSessionListener(McpServerSessionListener<ServerRequest> mcpServerSessionListener) {
this.mcpServerSessionListener = mcpServerSessionListener;
}

/**
* Returns the WebFlux router function that defines the transport's HTTP endpoints.
* This router function should be integrated into the application's web configuration.
Expand Down Expand Up @@ -277,6 +295,10 @@ private Mono<ServerResponse> handleSseConnection(ServerRequest request) {
logger.debug("Session {} cancelled", sessionId);
sessions.remove(sessionId);
});

if (null != mcpServerSessionListener) {
mcpServerSessionListener.onConnection(session, request);
}
}), ServerSentEvent.class);
}

Expand Down Expand Up @@ -310,6 +332,9 @@ private Mono<ServerResponse> handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.NOT_FOUND)
.bodyValue(new McpError("Session not found: " + request.queryParam("sessionId").get()));
}
if (null != mcpServerSessionListener) {
mcpServerSessionListener.onMessage(session, request);
}

return request.bodyToMono(String.class).flatMap(body -> {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,7 @@

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerTransport;
import io.modelcontextprotocol.spec.McpServerTransportProvider;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.*;
import io.modelcontextprotocol.util.Assert;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -97,6 +93,8 @@ public class WebMvcSseServerTransportProvider implements McpServerTransportProvi

private McpServerSession.Factory sessionFactory;

private McpServerSessionListener<ServerRequest> mcpServerSessionListener;

/**
* Map of active client sessions, keyed by session ID.
*/
Expand Down Expand Up @@ -149,6 +147,24 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String messag
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint) {
this(objectMapper, baseUrl, messageEndpoint, sseEndpoint, null);
}

/**
* Constructs a new WebMvcSseServerTransportProvider instance.
* @param objectMapper The ObjectMapper to use for JSON serialization/deserialization
* of messages.
* @param baseUrl The base URL for the message endpoint, used to construct the full
* endpoint URL for clients.
* @param messageEndpoint The endpoint URI where clients should send their JSON-RPC
* messages via HTTP POST. This endpoint will be communicated to clients through the
* SSE connection's initial endpoint event.
* @param mcpServerSessionListener The listener for handling server session events.
* @param sseEndpoint The endpoint URI where clients establish their SSE connections.
* @throws IllegalArgumentException if any parameter is null
*/
public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUrl, String messageEndpoint,
String sseEndpoint, McpServerSessionListener<ServerRequest> mcpServerSessionListener) {
Assert.notNull(objectMapper, "ObjectMapper must not be null");
Assert.notNull(baseUrl, "Message base URL must not be null");
Assert.notNull(messageEndpoint, "Message endpoint must not be null");
Expand All @@ -162,6 +178,7 @@ public WebMvcSseServerTransportProvider(ObjectMapper objectMapper, String baseUr
.GET(this.sseEndpoint, this::handleSseConnection)
.POST(this.messageEndpoint, this::handleMessage)
.build();
this.mcpServerSessionListener = mcpServerSessionListener;
}

@Override
Expand Down Expand Up @@ -215,6 +232,10 @@ public Mono<Void> closeGracefully() {
.doOnSuccess(v -> logger.debug("Graceful shutdown completed"));
}

public void setMcpServerSessionListener(McpServerSessionListener<ServerRequest> mcpServerSessionListener) {
this.mcpServerSessionListener = mcpServerSessionListener;
}

/**
* Returns the RouterFunction that defines the HTTP endpoints for this transport. The
* router function handles two endpoints:
Expand Down Expand Up @@ -275,6 +296,10 @@ private ServerResponse handleSseConnection(ServerRequest request) {
logger.error("Failed to send initial endpoint event: {}", e.getMessage());
sseBuilder.error(e);
}

if (null != mcpServerSessionListener) {
mcpServerSessionListener.onConnection(session, request);
}
}, Duration.ZERO);
}
catch (Exception e) {
Expand Down Expand Up @@ -311,6 +336,10 @@ private ServerResponse handleMessage(ServerRequest request) {
return ServerResponse.status(HttpStatus.NOT_FOUND).body(new McpError("Session not found: " + sessionId));
}

if (null != mcpServerSessionListener) {
mcpServerSessionListener.onConnection(session, request);
}

try {
String body = request.body(String.class);
McpSchema.JSONRPCMessage message = McpSchema.deserializeJsonRpcMessage(objectMapper, body);
Expand Down Expand Up @@ -415,6 +444,11 @@ public void close() {
}
}

@Override
public String getSessionId() {
return sessionId;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,12 @@ public class McpAsyncServer {
notificationHandlers.put(McpSchema.METHOD_NOTIFICATION_ROOTS_LIST_CHANGED,
asyncRootsListChangedNotificationHandler(rootsChangeConsumers));

mcpTransportProvider.setSessionFactory(
transport -> new McpServerSession(UUID.randomUUID().toString(), requestTimeout, transport,
this::asyncInitializeRequestHandler, Mono::empty, requestHandlers, notificationHandlers));
mcpTransportProvider.setSessionFactory(transport -> {
// If the sessionId is not provided, generate a random one.
String sessionId = Optional.ofNullable(transport.getSessionId()).orElse(UUID.randomUUID().toString());
return new McpServerSession(sessionId, requestTimeout, transport, this::asyncInitializeRequestHandler,
Mono::empty, requestHandlers, notificationHandlers);
});
}

// ---------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ public McpAsyncServerExchange(McpServerSession session, McpSchema.ClientCapabili
this.clientInfo = clientInfo;
}

/**
* Get the session id.
* @return The session id
*/
public String getSessionId() {
return this.session.getId();
}

/**
* Get the client capabilities that define the supported features and functionality.
* @return The client capabilities
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,14 @@ public McpSyncServerExchange(McpAsyncServerExchange exchange) {
this.exchange = exchange;
}

/**
* Get the session id.
* @return The session id
*/
public String getSessionId() {
return this.exchange.getSessionId();
}

/**
* Get the client capabilities that define the supported features and functionality.
* @return The client capabilities
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package io.modelcontextprotocol.spec;

/**
* Listener for McpServerSession events.
*
* @param <T> The type of request object used in the session.
* @author Allen Hu
*/
public interface McpServerSessionListener<T> {

/**
* Called when a new session is connected.
* @param session The session that was connected.
* @param request The request object used in the session.
*/
void onConnection(McpServerSession session, T request);

/**
* Called when a message is received.
* @param session The session that received the message.
* @param request The request object used in the session.
*/
void onMessage(McpServerSession session, T request);

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,12 @@
*/
public interface McpServerTransport extends McpTransport {

/**
* Returns the session id.
* @return the session id. default is null.
*/
default String getSessionId() {
return null;
}

}