Skip to content

Commit 8d48b93

Browse files
committed
feat: Add Pagination for requesting list of prompts
Adds the Pagination feature to the `prompts/list` feature as described in the specification. To make this possible mainly two changes are made: 1. The logic for cursor handling is added. 2. Handling for invalid parameters (MCP error code `-32602 (Invalid params)`) is added to the `McpServerSession`. For now the cursor is the base64 encoded start index of the next page. The page size is set to 10. When parameters are found to be invalid the newly introduced `McpParamsValidationError` is returned to handle it properly in the `McpServerSession`.
1 parent 734d173 commit 8d48b93

File tree

4 files changed

+184
-13
lines changed

4 files changed

+184
-13
lines changed

mcp-spring/mcp-spring-webflux/src/test/java/io/modelcontextprotocol/WebFluxSseIntegrationTests.java

Lines changed: 101 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.util.function.BiFunction;
1414
import java.util.function.Function;
1515
import java.util.stream.Collectors;
16+
import java.util.stream.Stream;
1617

1718
import com.fasterxml.jackson.databind.ObjectMapper;
1819
import io.modelcontextprotocol.client.McpClient;
@@ -30,6 +31,8 @@
3031
import org.junit.jupiter.api.AfterEach;
3132
import org.junit.jupiter.api.BeforeEach;
3233
import org.junit.jupiter.params.ParameterizedTest;
34+
import org.junit.jupiter.params.provider.Arguments;
35+
import org.junit.jupiter.params.provider.MethodSource;
3336
import org.junit.jupiter.params.provider.ValueSource;
3437
import reactor.netty.DisposableServer;
3538
import reactor.netty.http.server.HttpServer;
@@ -40,9 +43,8 @@
4043
import org.springframework.web.reactive.function.client.WebClient;
4144
import org.springframework.web.reactive.function.server.RouterFunctions;
4245

43-
import static org.assertj.core.api.Assertions.assertThat;
44-
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
45-
import static org.assertj.core.api.Assertions.assertWith;
46+
import static io.modelcontextprotocol.spec.McpSchema.ErrorCodes.INVALID_PARAMS;
47+
import static org.assertj.core.api.Assertions.*;
4648
import static org.awaitility.Awaitility.await;
4749
import static org.mockito.Mockito.mock;
4850

@@ -802,4 +804,99 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) {
802804
mcpServer.close();
803805
}
804806

805-
}
807+
// ---------------------------------------
808+
// Prompt List Tests
809+
// ---------------------------------------
810+
811+
static Stream<Arguments> providePaginationTestParams() {
812+
return Stream.of(Arguments.of("httpclient", 0), Arguments.of("httpclient", 1), Arguments.of("httpclient", 21),
813+
Arguments.of("webflux", 0), Arguments.of("webflux", 1), Arguments.of("webflux", 21));
814+
}
815+
816+
@ParameterizedTest(name = "{0} ({1}) : {displayName} ")
817+
@MethodSource("providePaginationTestParams")
818+
void testListPromptSuccess(String clientType, int availablePrompts) {
819+
820+
var clientBuilder = clientBuilders.get(clientType);
821+
822+
// Setup list of prompts
823+
List<McpServerFeatures.SyncPromptSpecification> prompts = new ArrayList<>();
824+
825+
for (int i = 0; i < availablePrompts; i++) {
826+
McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt-" + i, "Test Prompt Description",
827+
List.of(new McpSchema.PromptArgument("arg1", "Test argument", true)));
828+
829+
var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null);
830+
831+
prompts.add(promptSpec);
832+
}
833+
834+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
835+
.capabilities(ServerCapabilities.builder().prompts(true).build())
836+
.prompts(prompts)
837+
.build();
838+
839+
try (var mcpClient = clientBuilder.build()) {
840+
841+
InitializeResult initResult = mcpClient.initialize();
842+
assertThat(initResult).isNotNull();
843+
844+
// Iterate through list
845+
var returnedPromptsSum = 0;
846+
847+
var hasEntries = true;
848+
String nextCursor = null;
849+
850+
while (hasEntries) {
851+
var res = mcpClient.listPrompts(nextCursor);
852+
returnedPromptsSum += res.prompts().size();
853+
854+
nextCursor = res.nextCursor();
855+
856+
if (nextCursor == null) {
857+
hasEntries = false;
858+
}
859+
}
860+
861+
assertThat(returnedPromptsSum).isEqualTo(availablePrompts);
862+
863+
}
864+
865+
mcpServer.close();
866+
}
867+
868+
@ParameterizedTest(name = "{0} : {displayName} ")
869+
@ValueSource(strings = { "httpclient", "webflux" })
870+
void testListPromptInvalidCursor(String clientType) {
871+
872+
var clientBuilder = clientBuilders.get(clientType);
873+
874+
McpSchema.Prompt mockPrompt = new McpSchema.Prompt("test-prompt", "Test Prompt Description",
875+
List.of(new McpSchema.PromptArgument("arg1", "Test argument", true)));
876+
877+
var promptSpec = new McpServerFeatures.SyncPromptSpecification(mockPrompt, null);
878+
879+
var mcpServer = McpServer.sync(mcpServerTransportProvider)
880+
.capabilities(ServerCapabilities.builder().prompts(true).build())
881+
.prompts(promptSpec)
882+
.build();
883+
884+
try (var mcpClient = clientBuilder.build()) {
885+
886+
InitializeResult initResult = mcpClient.initialize();
887+
assertThat(initResult).isNotNull();
888+
889+
assertThatThrownBy(() -> mcpClient.listPrompts("INVALID")).isInstanceOf(McpError.class)
890+
.hasMessage("Invalid cursor")
891+
.satisfies(exception -> {
892+
var error = (McpError) exception;
893+
assertThat(error.getJsonRpcError().code()).isEqualTo(INVALID_PARAMS);
894+
assertThat(error.getJsonRpcError().message()).isEqualTo("Invalid cursor");
895+
});
896+
897+
}
898+
899+
mcpServer.close();
900+
}
901+
902+
}

mcp/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import java.util.Map;
1111
import java.util.Optional;
1212
import java.util.UUID;
13+
import java.util.Base64;
1314
import java.util.concurrent.ConcurrentHashMap;
1415
import java.util.concurrent.CopyOnWriteArrayList;
1516
import java.util.function.BiFunction;
@@ -18,6 +19,7 @@
1819
import com.fasterxml.jackson.databind.ObjectMapper;
1920
import io.modelcontextprotocol.spec.McpClientSession;
2021
import io.modelcontextprotocol.spec.McpError;
22+
import io.modelcontextprotocol.spec.McpParamsValidationError;
2123
import io.modelcontextprotocol.spec.McpSchema;
2224
import io.modelcontextprotocol.spec.McpSchema.CallToolResult;
2325
import io.modelcontextprotocol.spec.McpSchema.LoggingLevel;
@@ -266,6 +268,8 @@ private static class AsyncServerImpl extends McpAsyncServer {
266268

267269
private final ConcurrentHashMap<String, McpServerFeatures.AsyncPromptSpecification> prompts = new ConcurrentHashMap<>();
268270

271+
private static final int PAGE_SIZE = 10;
272+
269273
// FIXME: this field is deprecated and should be remvoed together with the
270274
// broadcasting loggingNotification.
271275
private LoggingLevel minLoggingLevel = LoggingLevel.DEBUG;
@@ -647,20 +651,67 @@ public Mono<Void> notifyPromptsListChanged() {
647651

648652
private McpServerSession.RequestHandler<McpSchema.ListPromptsResult> promptsListRequestHandler() {
649653
return (exchange, params) -> {
650-
// TODO: Implement pagination
651-
// McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
652-
// new TypeReference<McpSchema.PaginatedRequest>() {
653-
// });
654+
McpSchema.PaginatedRequest request = objectMapper.convertValue(params,
655+
new TypeReference<McpSchema.PaginatedRequest>() {
656+
});
657+
658+
if (!isCursorValid(request.cursor(), this.prompts.size())) {
659+
return Mono.error(new McpParamsValidationError("Invalid cursor"));
660+
}
661+
662+
int requestedStartIndex = 0;
663+
664+
if (request.cursor() != null) {
665+
requestedStartIndex = decodeCursor(request.cursor());
666+
}
667+
668+
int endIndex = Math.min(requestedStartIndex + PAGE_SIZE, this.prompts.size());
654669

655670
var promptList = this.prompts.values()
656671
.stream()
672+
.skip(requestedStartIndex)
673+
.limit(endIndex - requestedStartIndex)
657674
.map(McpServerFeatures.AsyncPromptSpecification::prompt)
658675
.toList();
659676

660-
return Mono.just(new McpSchema.ListPromptsResult(promptList, null));
677+
String nextCursor = null;
678+
679+
if (endIndex < this.prompts.size()) {
680+
nextCursor = encodeCursor(endIndex);
681+
}
682+
683+
return Mono.just(new McpSchema.ListPromptsResult(promptList, nextCursor));
661684
};
662685
}
663686

687+
private boolean isCursorValid(String cursor, int maxPageSize) {
688+
if (cursor == null) {
689+
return true;
690+
}
691+
692+
try {
693+
var decoded = decodeCursor(cursor);
694+
695+
if (decoded < 0 || decoded > maxPageSize) {
696+
return false;
697+
}
698+
699+
return true;
700+
}
701+
catch (NumberFormatException e) {
702+
return false;
703+
}
704+
}
705+
706+
private String encodeCursor(int index) {
707+
return Base64.getEncoder().encodeToString(String.valueOf(index).getBytes());
708+
}
709+
710+
private int decodeCursor(String cursor) {
711+
String decoded = new String(Base64.getDecoder().decode(cursor));
712+
return Integer.parseInt(decoded);
713+
}
714+
664715
private McpServerSession.RequestHandler<McpSchema.GetPromptResult> promptsGetRequestHandler() {
665716
return (exchange, params) -> {
666717
McpSchema.GetPromptRequest promptRequest = objectMapper.convertValue(params,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
package io.modelcontextprotocol.spec;
2+
3+
public class McpParamsValidationError extends McpError {
4+
5+
public McpParamsValidationError(McpSchema.JSONRPCResponse.JSONRPCError jsonRpcError) {
6+
super(jsonRpcError.message());
7+
}
8+
9+
public McpParamsValidationError(Object error) {
10+
super(error.toString());
11+
}
12+
13+
}

mcp/src/main/java/io/modelcontextprotocol/spec/McpServerSession.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,20 @@ private Mono<McpSchema.JSONRPCResponse> handleIncomingRequest(McpSchema.JSONRPCR
225225
}
226226
return resultMono
227227
.map(result -> new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), result, null))
228-
.onErrorResume(error -> Mono.just(new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(),
229-
null, new McpSchema.JSONRPCResponse.JSONRPCError(McpSchema.ErrorCodes.INTERNAL_ERROR,
230-
error.getMessage(), null)))); // TODO: add error message
231-
// through the data field
228+
.onErrorResume(error -> {
229+
230+
var errorCode = McpSchema.ErrorCodes.INTERNAL_ERROR;
231+
232+
if (error instanceof McpParamsValidationError) {
233+
errorCode = McpSchema.ErrorCodes.INVALID_PARAMS;
234+
}
235+
236+
// TODO: add error message through the data field
237+
var errorResponse = new McpSchema.JSONRPCResponse(McpSchema.JSONRPC_VERSION, request.id(), null,
238+
new McpSchema.JSONRPCResponse.JSONRPCError(errorCode, error.getMessage(), null));
239+
240+
return Mono.just(errorResponse);
241+
});
232242
});
233243
}
234244

0 commit comments

Comments
 (0)