diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index f3f6c2c33..64c20e238 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -8,13 +8,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import io.modelcontextprotocol.server.DefaultMcpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpStreamableServerSession; -import io.modelcontextprotocol.spec.McpStreamableServerTransport; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.ProtocolVersions; +import io.modelcontextprotocol.spec.*; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.KeepAliveScheduler; @@ -278,6 +272,19 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { WebFluxStreamableMcpSessionTransport st = new WebFluxStreamableMcpSessionTransport(sink); Mono stream = session.responseStream(jsonrpcRequest, st); Disposable streamSubscription = stream.onErrorComplete(err -> { + if (err instanceof McpParamsValidationError) { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + jsonrpcRequest.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INVALID_PARAMS, err.getMessage(), null)); + + var event = ServerSentEvent.builder() + .event(MESSAGE_EVENT_TYPE) + .data(errorResponse) + .build(); + + sink.next(event); + return true; + } sink.error(err); return true; }).contextWrite(sink.contextView()).subscribe(); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index fa51a0130..e28aaef4c 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -26,13 +27,6 @@ import io.modelcontextprotocol.server.DefaultMcpTransportContext; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; -import io.modelcontextprotocol.spec.HttpHeaders; -import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpSchema; -import io.modelcontextprotocol.spec.McpStreamableServerSession; -import io.modelcontextprotocol.spec.McpStreamableServerTransport; -import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; -import io.modelcontextprotocol.spec.ProtocolVersions; import io.modelcontextprotocol.util.Assert; import io.modelcontextprotocol.util.KeepAliveScheduler; import reactor.core.publisher.Flux; @@ -396,6 +390,12 @@ else if (message instanceof McpSchema.JSONRPCRequest jsonrpcRequest) { .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) .block(); } + catch (McpParamsValidationError e) { + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, + jsonrpcRequest.id(), null, new McpSchema.JSONRPCResponse.JSONRPCError( + McpSchema.ErrorCodes.INVALID_PARAMS, e.getMessage(), null)); + sessionTransport.sendMessage(errorResponse).block(); + } catch (Exception e) { logger.error("Failed to handle request stream: {}", e.getMessage()); sseBuilder.error(e); diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 8e041d91e..32f209b20 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,11 +4,11 @@ package io.modelcontextprotocol; +import static io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.assertThatJson; import static net.javacrumbs.jsonunit.assertj.JsonAssertions.json; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; -import static org.assertj.core.api.Assertions.assertWith; +import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; @@ -17,6 +17,8 @@ import java.net.http.HttpRequest; import java.net.http.HttpResponse; import java.time.Duration; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -27,8 +29,11 @@ import java.util.function.BiFunction; import java.util.function.Function; import java.util.stream.Collectors; +import java.util.stream.Stream; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.ValueSource; import io.modelcontextprotocol.client.McpClient; @@ -921,6 +926,135 @@ void testToolListChangeHandlingSuccess(String clientType) { mcpServer.close(); } + // --------------------------------------- + // Tests for Paginated Tool List Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListToolsSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List tools = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + var mock = new McpSchema.Tool("test-tool-" + i, "Test Tool Description", emptyJsonSchema); + var spec = new McpServerFeatures.SyncToolSpecification(mock, null); + + tools.add(spec); + } + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tools) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listTools(nextCursor); + + res.tools().forEach(e -> returnedElements.add(e.name())); // store unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListToolsCursorInvalidListChanged(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + var pageSize = 10; + List tools = new ArrayList<>(); + + for (int i = 0; i <= pageSize; i++) { + var mock = new McpSchema.Tool("test-tool-" + i, "Test Tool Description", emptyJsonSchema); + var spec = new McpServerFeatures.SyncToolSpecification(mock, null); + + tools.add(spec); + } + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(tools) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var res = mcpClient.listTools(null); + + // Change list + var mock = new McpSchema.Tool("test-tool-xyz", "Test Tool Description", emptyJsonSchema); + mcpServer.addTool(new McpServerFeatures.SyncToolSpecification(mock, null)); + + assertThatThrownBy(() -> mcpClient.listTools(res.nextCursor())).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListToolsInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.Tool("test-tool", "Test Tool Description", emptyJsonSchema); + var spec = new McpServerFeatures.SyncToolSpecification(mock, null); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().tools(true).build()) + .tools(spec) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listTools("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) void testInitialize(String clientType) { @@ -964,36 +1098,36 @@ void testLoggingNotification(String clientType) throws InterruptedException { //@formatter:off return exchange // This should be filtered out (DEBUG < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.DEBUG) - .logger("test-logger") - .data("Debug message") - .build()) - .then(exchange // This should be sent (NOTICE >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.NOTICE) - .logger("test-logger") - .data("Notice message") - .build())) - .then(exchange // This should be sent (ERROR > NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Error message") - .build())) - .then(exchange // This should be filtered out (INFO < NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.INFO) - .logger("test-logger") - .data("Another info message") - .build())) - .then(exchange // This should be sent (ERROR >= NOTICE) - .loggingNotification(McpSchema.LoggingMessageNotification.builder() - .level(McpSchema.LoggingLevel.ERROR) - .logger("test-logger") - .data("Another error message") - .build())) - .thenReturn(new CallToolResult("Logging test completed", false)); + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.DEBUG) + .logger("test-logger") + .data("Debug message") + .build()) + .then(exchange // This should be sent (NOTICE >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.NOTICE) + .logger("test-logger") + .data("Notice message") + .build())) + .then(exchange // This should be sent (ERROR > NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Error message") + .build())) + .then(exchange // This should be filtered out (INFO < NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.INFO) + .logger("test-logger") + .data("Another info message") + .build())) + .then(exchange // This should be sent (ERROR >= NOTICE) + .loggingNotification(McpSchema.LoggingMessageNotification.builder() + .level(McpSchema.LoggingLevel.ERROR) + .logger("test-logger") + .data("Another error message") + .build())) + .thenReturn(new CallToolResult("Logging test completed", false)); //@formatter:on }) .build(); @@ -1056,7 +1190,7 @@ void testLoggingNotification(String clientType) throws InterruptedException { @ValueSource(strings = { "httpclient", "webflux" }) void testProgressNotification(String clientType) throws InterruptedException { int expectedNotificationsCount = 4; // 3 notifications + 1 for another progress - // token + // token CountDownLatch latch = new CountDownLatch(expectedNotificationsCount); // Create a list to store received logging notifications List receivedNotifications = new CopyOnWriteArrayList<>(); @@ -1524,6 +1658,380 @@ void testStructuredOutputRuntimeToolAddition(String clientType) { mcpServer.close(); } + // --------------------------------------- + // Tests for Paginated Prompt List Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListPromptsSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List prompts = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + var mock = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + var spec = new McpServerFeatures.SyncPromptSpecification(mock, null); + + prompts.add(spec); + } + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(prompts) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listPrompts(nextCursor); + + res.prompts().forEach(e -> returnedElements.add(e.name())); // store + // unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListPromptsCursorInvalidListChanged(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + var pageSize = 10; + List prompts = new ArrayList<>(); + + for (int i = 0; i <= pageSize; i++) { + var mock = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + var spec = new McpServerFeatures.SyncPromptSpecification(mock, null); + + prompts.add(spec); + } + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(prompts) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var res = mcpClient.listPrompts(null); + + // Change list + var mock = new McpSchema.Prompt("test-prompt-xyz", "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + + mcpServer.addPrompt(new McpServerFeatures.SyncPromptSpecification(mock, null)); + + assertThatThrownBy(() -> mcpClient.listPrompts(res.nextCursor())).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListPromptsInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.Prompt("test-prompt", "Test Prompt Description", + List.of(new McpSchema.PromptArgument("arg1", "Test argument", true))); + + var spec = new McpServerFeatures.SyncPromptSpecification(mock, null); + + var mcpServer = prepareSyncServerBuilder().capabilities(ServerCapabilities.builder().prompts(true).build()) + .prompts(spec) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listPrompts("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tests for Paginated Resources List Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListResourcesSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List resources = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + var mock = new McpSchema.Resource("file://example-" + i + ".txt", "test-resource", + "Test Resource Description", "application/octet-stream", null); + var spec = new McpServerFeatures.SyncResourceSpecification(mock, null); + + resources.add(spec); + } + + var mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resources(resources) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listResources(nextCursor); + + res.resources().forEach(e -> returnedElements.add(e.uri())); // store + // unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListResourcesCursorInvalidListChanged(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + var pageSize = 10; + List resources = new ArrayList<>(); + + for (int i = 0; i <= pageSize; i++) { + var mock = new McpSchema.Resource("file://example-" + i + ".txt", "test-resource", + "Test Resource Description", "application/octet-stream", null); + var spec = new McpServerFeatures.SyncResourceSpecification(mock, null); + + resources.add(spec); + } + + var mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resources(resources) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var res = mcpClient.listResources(null); + + // Change list + var mock = new McpSchema.Resource("file://example-xyz.txt", "test-resource", "Test Resource Description", + "application/octet-stream", null); + mcpServer.addResource(new McpServerFeatures.SyncResourceSpecification(mock, null)); + + assertThatThrownBy(() -> mcpClient.listResources(res.nextCursor())).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListResourcesInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.Resource("file://example.txt", "test-resource", "Test Resource Description", + "application/octet-stream", null); + var spec = new McpServerFeatures.SyncResourceSpecification(mock, null); + + var mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resources(spec) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listResources("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + // --------------------------------------- + // Tests for Paginated Resource Templates Results + // --------------------------------------- + + @ParameterizedTest(name = "{0} ({1}) : {displayName} ") + @MethodSource("providePaginationTestParams") + void testListResourceTemplatesSuccess(String clientType, int availableElements) { + + var clientBuilder = clientBuilders.get(clientType); + + // Setup list of prompts + List resourceTemplates = new ArrayList<>(); + + for (int i = 0; i < availableElements; i++) { + resourceTemplates.add(new McpSchema.ResourceTemplate("file://{path}-" + i + ".txt", "test-resource", + "Test Resource Description", "application/octet-stream", null)); + } + + var mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resourceTemplates(resourceTemplates) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + var returnedElements = new HashSet(); + + var hasEntries = true; + String nextCursor = null; + + while (hasEntries) { + var res = mcpClient.listResourceTemplates(nextCursor); + + res.resourceTemplates().forEach(e -> returnedElements.add(e.uriTemplate())); // store + // unique + // attribute + + nextCursor = res.nextCursor(); + + if (nextCursor == null) { + hasEntries = false; + } + } + + assertThat(returnedElements.size()).isEqualTo(availableElements); + } + + mcpServer.close(); + } + + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListResourceTemplatesInvalidCursor(String clientType) { + + var clientBuilder = clientBuilders.get(clientType); + + var mock = new McpSchema.ResourceTemplate("file://{path}.txt", "test-resource", "Test Resource Description", + "application/octet-stream", null); + + var mcpServer = prepareSyncServerBuilder() + .capabilities(ServerCapabilities.builder().resources(true, true).build()) + .resourceTemplates(mock) + .build(); + + try (var mcpClient = clientBuilder.build()) { + + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + assertThatThrownBy(() -> mcpClient.listResourceTemplates("INVALID")).isInstanceOf(McpError.class) + .hasMessage("Invalid cursor") + .satisfies(exception -> { + var error = (McpError) exception; + assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS); + assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor"); + }); + + } + + mcpServer.close(); + } + + // --------------------------------------- + // Helpers for Tests of Paginated Lists + // --------------------------------------- + + /** + * Helper function for pagination tests. This provides a stream of the following + * parameters: 1. Client type (e.g. httpclient, webflux) 2. Number of available + * elements in the list + * @return a stream of arguments with test parameters + */ + static Stream providePaginationTestParams() { + return Stream.of(Arguments.of("httpclient", 0), Arguments.of("httpclient", 1), Arguments.of("httpclient", 21), + Arguments.of("webflux", 0), Arguments.of("webflux", 1), Arguments.of("webflux", 21)); + } + private double evaluateExpression(String expression) { // Simple expression evaluator for testing return switch (expression) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index a51c2e36c..4939192d8 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -6,6 +6,7 @@ import java.time.Duration; import java.util.ArrayList; +import java.util.Base64; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -27,13 +28,13 @@ import io.modelcontextprotocol.spec.JsonSchemaValidator; import io.modelcontextprotocol.spec.McpClientSession; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpParamsValidationError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; import io.modelcontextprotocol.spec.McpSchema.LoggingLevel; import io.modelcontextprotocol.spec.McpSchema.LoggingMessageNotification; import io.modelcontextprotocol.spec.McpSchema.ResourceTemplate; import io.modelcontextprotocol.spec.McpSchema.SetLevelRequest; -import io.modelcontextprotocol.spec.McpSchema.Tool; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransportProvider; import io.modelcontextprotocol.util.Assert; @@ -109,6 +110,8 @@ public class McpAsyncServer { private final ConcurrentHashMap prompts = new ConcurrentHashMap<>(); + private static final int PAGE_SIZE = 10; + // FIXME: this field is deprecated and should be remvoed together with the // broadcasting loggingNotification. private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG; @@ -490,9 +493,25 @@ public Mono notifyToolsListChanged() { private McpRequestHandler toolsListRequestHandler() { return (exchange, params) -> { - List tools = this.tools.stream().map(McpServerFeatures.AsyncToolSpecification::tool).toList(); + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + + int mapSize = this.tools.size(); + int mapHash = this.tools.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); - return Mono.just(new McpSchema.ListToolsResult(tools, null)); + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = this.tools.stream() + .skip(requestedStartIndex) + .limit(endIndex - requestedStartIndex) + .map(McpServerFeatures.AsyncToolSpecification::tool) + .toList(); + + return Mono.just(new McpSchema.ListToolsResult(resultList, nextCursor)); }; } @@ -591,18 +610,49 @@ public Mono notifyResourcesUpdated(McpSchema.ResourcesUpdatedNotification private McpRequestHandler resourcesListRequestHandler() { return (exchange, params) -> { - var resourceList = this.resources.values() + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + + int mapSize = this.resources.size(); + int mapHash = this.resources.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); + + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = this.resources.values() .stream() + .skip(requestedStartIndex) + .limit(endIndex - requestedStartIndex) .map(McpServerFeatures.AsyncResourceSpecification::resource) .toList(); - return Mono.just(new McpSchema.ListResourcesResult(resourceList, null)); + + return Mono.just(new McpSchema.ListResourcesResult(resultList, nextCursor)); }; } private McpRequestHandler resourceTemplateListRequestHandler() { - return (exchange, params) -> Mono - .just(new McpSchema.ListResourceTemplatesResult(this.getResourceTemplates(), null)); + return (exchange, params) -> { + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + var all = this.getResourceTemplates(); + + int mapSize = all.size(); + int mapHash = all.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); + + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = all.stream().skip(requestedStartIndex).limit(endIndex - requestedStartIndex).toList(); + + return Mono.just(new McpSchema.ListResourceTemplatesResult(resultList, nextCursor)); + }; } private List getResourceTemplates() { @@ -718,17 +768,27 @@ public Mono notifyPromptsListChanged() { private McpRequestHandler promptsListRequestHandler() { return (exchange, params) -> { - // TODO: Implement pagination - // McpSchema.PaginatedRequest request = objectMapper.convertValue(params, - // new TypeReference() { - // }); - var promptList = this.prompts.values() + McpSchema.PaginatedRequest request = objectMapper.convertValue(params, + new TypeReference() { + }); + + int mapSize = this.prompts.size(); + int mapHash = this.prompts.hashCode(); + + int requestedStartIndex = handleCursor(request.cursor(), mapSize, mapHash).block(); + int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, mapSize); + + var nextCursor = getCursor(endIndex, mapSize, mapHash); + + var resultList = this.prompts.values() .stream() + .skip(requestedStartIndex) + .limit(endIndex - requestedStartIndex) .map(McpServerFeatures.AsyncPromptSpecification::prompt) .toList(); - return Mono.just(new McpSchema.ListPromptsResult(promptList, null)); + return Mono.just(new McpSchema.ListPromptsResult(resultList, nextCursor)); }; } @@ -906,4 +966,79 @@ void setProtocolVersions(List protocolVersions) { this.protocolVersions = protocolVersions; } + // --------------------------------------- + // Cursor Handling for paginated requests + // --------------------------------------- + + /** + * Handles the cursor by decoding, validating and reading the index of it. + * @param cursor the base64 representation of the cursor. + * @param mapSize the size of the map from which the values should be read. + * @param mapHash the hash of the map to compare the cursor value to. + * @return a {@link Mono} which contains the index to which the cursor points. + */ + private Mono handleCursor(String cursor, int mapSize, int mapHash) { + if (cursor == null) { + return Mono.just(0); + } + + var decodedCursor = decodeCursor(cursor); + + if (!isCursorValid(decodedCursor, mapSize, mapHash)) { + return Mono.error(new McpParamsValidationError("Invalid cursor")); + } + + return Mono.just(getCursorIndex(decodedCursor)); + } + + private String getCursor(int endIndex, int mapSize, int mapHash) { + if (endIndex >= mapSize) { + return null; + } + return encodeCursor(endIndex, mapHash); + } + + private int getCursorIndex(String cursor) { + return Integer.parseInt(cursor.split(":")[0]); + } + + private boolean isCursorValid(String cursor, int maxPageSize, int currentHash) { + var cursorElements = cursor.split(":"); + + if (cursorElements.length != 2) { + logger.debug("Length of elements in cursor doesn't match expected number. Cursor: {} Actual number: {}", + cursor, cursorElements.length); + return false; + } + + int index; + int hash; + + try { + index = Integer.parseInt(cursorElements[0]); + hash = Integer.parseInt(cursorElements[1]); + } + catch (NumberFormatException e) { + logger.debug("Failed to parse cursor elements."); + return false; + } + + if (index < 0 || index > maxPageSize || hash != currentHash) { + logger.debug("Cursor boundaries are invalid."); + return false; + } + + return true; + } + + private String encodeCursor(int index, int hash) { + var cursor = index + ":" + hash; + + return Base64.getEncoder().encodeToString(cursor.getBytes()); + } + + private String decodeCursor(String base64Cursor) { + return new String(Base64.getDecoder().decode(base64Cursor)); + } + } diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpParamsValidationError.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpParamsValidationError.java new file mode 100644 index 000000000..e7ecb0058 --- /dev/null +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpParamsValidationError.java @@ -0,0 +1,9 @@ +package io.modelcontextprotocol.spec; + +public class McpParamsValidationError extends McpError { + + public McpParamsValidationError(String error) { + super(error); + } + +} diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java index 62985dc17..1fc4a2a92 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java @@ -270,10 +270,20 @@ private Mono handleIncomingRequest(McpSchema.JSONRPCR } return resultMono .map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null)) - .onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), - null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR, - error.getMessage(), null)))); // TODO: add error message - // through the data field + .onErrorResume(error -> { + + var errorCode = McpSchema.ErrorCodes.INTERNAL_ERROR; + + if (error instanceof McpParamsValidationError) { + errorCode = McpSchema.ErrorCodes.INVALID_PARAMS; + } + + // TODO: add error message through the data field + var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null, + new McpSchema.JSONRPCResponse.JSONRPCError(errorCode, error.getMessage(), null)); + + return Mono.just(errorResponse); + }); }); }