Skip to content

Commit 6abbae5

Browse files
committed
Refactor MockTransport
This replaces bespoke test transport with a reusable, configurable MockTransport in a shared testing package. The new MockTransport supports registering handlers for specific JSON-RPC methods (success and error), records sent/received messages, and provides an awaitMessage helper with polling and timeouts. Tests were added to validate handler registration, auto-responses, error responses, concurrency, and message awaiting behavior. Existing client tests were updated to configure MockTransport via lambdas instead of hardcoded logic. Additionally, coroutine test utilities were added to the core test dependencies.
1 parent 8f9f484 commit 6abbae5

File tree

5 files changed

+812
-95
lines changed

5 files changed

+812
-95
lines changed

kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/ClientMetaParameterTest.kt

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
package io.modelcontextprotocol.kotlin.sdk.client
22

3+
import io.modelcontextprotocol.kotlin.sdk.CallToolResult
34
import io.modelcontextprotocol.kotlin.sdk.Implementation
5+
import io.modelcontextprotocol.kotlin.sdk.InitializeResult
46
import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest
7+
import io.modelcontextprotocol.kotlin.sdk.Method
8+
import io.modelcontextprotocol.kotlin.sdk.ServerCapabilities
9+
import io.modelcontextprotocol.kotlin.sdk.testing.MockTransport
510
import kotlinx.coroutines.test.runTest
611
import kotlinx.serialization.json.JsonObject
712
import kotlinx.serialization.json.boolean
@@ -31,7 +36,24 @@ class ClientMetaParameterTest {
3136

3237
@BeforeTest
3338
fun setup() = runTest {
34-
mockTransport = MockTransport()
39+
mockTransport = MockTransport {
40+
// configure mock transport behavior
41+
onMessageReplyResult(Method.Defined.Initialize) {
42+
InitializeResult(
43+
protocolVersion = "2024-11-05",
44+
capabilities = ServerCapabilities(
45+
tools = ServerCapabilities.Tools(listChanged = null),
46+
),
47+
serverInfo = Implementation("mock-server", "1.0.0"),
48+
)
49+
}
50+
onMessageReplyResult(Method.Defined.ToolsCall) {
51+
CallToolResult(
52+
content = listOf(),
53+
isError = false,
54+
)
55+
}
56+
}
3557
client = Client(clientInfo = clientInfo)
3658
mockTransport.setupInitializationResponse()
3759
client.connect(mockTransport)

kotlin-sdk-client/src/commonTest/kotlin/io/modelcontextprotocol/kotlin/sdk/client/MockTransport.kt

Lines changed: 0 additions & 94 deletions
This file was deleted.

kotlin-sdk-core/build.gradle.kts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ kotlin {
124124
implementation(kotlin("test"))
125125
implementation(libs.kotest.assertions.core)
126126
implementation(libs.kotest.assertions.json)
127+
implementation(libs.kotlinx.coroutines.test)
127128
}
128129
}
129130

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
package io.modelcontextprotocol.kotlin.sdk.testing
2+
3+
import io.ktor.util.collections.ConcurrentSet
4+
import io.modelcontextprotocol.kotlin.sdk.ErrorCode
5+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCError
6+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCMessage
7+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCRequest
8+
import io.modelcontextprotocol.kotlin.sdk.JSONRPCResponse
9+
import io.modelcontextprotocol.kotlin.sdk.Method
10+
import io.modelcontextprotocol.kotlin.sdk.RequestResult
11+
import io.modelcontextprotocol.kotlin.sdk.shared.Transport
12+
import kotlinx.coroutines.delay
13+
import kotlinx.coroutines.sync.Mutex
14+
import kotlinx.coroutines.sync.withLock
15+
import kotlin.time.Clock
16+
import kotlin.time.Duration
17+
import kotlin.time.Duration.Companion.milliseconds
18+
import kotlin.time.Duration.Companion.seconds
19+
import kotlin.time.ExperimentalTime
20+
21+
private typealias RequestPredicate = (JSONRPCRequest) -> Boolean
22+
private typealias RequestHandler = suspend (JSONRPCRequest) -> JSONRPCResponse
23+
24+
/**
25+
* A mock transport implementation for testing JSON-RPC communication.
26+
*
27+
* This class simulates transport that can be used to test server and client interactions by
28+
* allowing the registration of handlers for incoming requests and the ability to record
29+
* messages sent and received.
30+
*
31+
* The mock transport supports:
32+
* - Recording all sent and received messages (via `getSentMessages` and `getReceivedMessages`)
33+
* - Registering request handlers that respond to specific message predicates (e.g., by method)
34+
* - Setting up responses that can be either successful or with errors
35+
* - Waiting for specific messages to be received
36+
*
37+
* Note: This class is designed to be used as a test helper and should not be used in production.
38+
*/
39+
@Suppress("TooManyFunctions")
40+
public open class MockTransport(configurer: MockTransport.() -> Unit = {}) : Transport {
41+
private val _sentMessages = mutableListOf<JSONRPCMessage>()
42+
private val _receivedMessages = mutableListOf<JSONRPCMessage>()
43+
44+
private val requestHandlers = ConcurrentSet<Pair<RequestPredicate, RequestHandler>>()
45+
private val mutex = Mutex()
46+
47+
public suspend fun getSentMessages(): List<JSONRPCMessage> = mutex.withLock { _sentMessages.toList() }
48+
49+
public suspend fun getReceivedMessages(): List<JSONRPCMessage> = mutex.withLock { _receivedMessages.toList() }
50+
51+
private var onMessageBlock: (suspend (JSONRPCMessage) -> Unit)? = null
52+
private var onCloseBlock: (() -> Unit)? = null
53+
private var onErrorBlock: ((Throwable) -> Unit)? = null
54+
55+
init {
56+
configurer.invoke(this)
57+
}
58+
59+
override suspend fun start(): Unit = Unit
60+
61+
override suspend fun send(message: JSONRPCMessage) {
62+
mutex.withLock {
63+
_sentMessages += message
64+
}
65+
66+
// Auto-respond to using preconfigured request handlers
67+
when (message) {
68+
is JSONRPCRequest -> {
69+
val response = requestHandlers.firstOrNull {
70+
it.first.invoke(message)
71+
}?.second?.invoke(message)
72+
73+
checkNotNull(response) {
74+
"No request handler found for $message."
75+
}
76+
onMessageBlock?.invoke(response)
77+
}
78+
79+
else -> {
80+
// TODO("Not implemented yet")
81+
}
82+
}
83+
}
84+
85+
override suspend fun close() {
86+
onCloseBlock?.invoke()
87+
}
88+
89+
override fun onMessage(block: suspend (JSONRPCMessage) -> Unit) {
90+
onMessageBlock = { message ->
91+
mutex.withLock {
92+
_receivedMessages += message
93+
}
94+
block(message)
95+
}
96+
}
97+
98+
override fun onClose(block: () -> Unit) {
99+
onCloseBlock = block
100+
}
101+
102+
override fun onError(block: (Throwable) -> Unit) {
103+
onErrorBlock = block
104+
}
105+
106+
public fun setupInitializationResponse() {
107+
// This method helps set up the mock for proper initialization
108+
}
109+
110+
/**
111+
* Registers a handler that will be called when a message matching the given predicate is received.
112+
*
113+
* The handler is expected to return a `RequestResult` which will be used as the response to the request.
114+
*
115+
* @param predicate A predicate that matches the incoming `JSONRPCMessage`
116+
* for which the handler should be triggered.
117+
* @param block A function that processes the incoming `JSONRPCMessage` and returns a `RequestResult`
118+
* to be used as the response.
119+
*/
120+
public fun onMessageReply(predicate: RequestPredicate, block: RequestHandler) {
121+
requestHandlers.add(Pair(predicate, block))
122+
}
123+
124+
/**
125+
* Registers a handler for responses to a specific method.
126+
*
127+
* This method allows registering a handler that will be called when a message with the specified method
128+
* is received. The handler is expected to return a `RequestResult` which is the response to the request.
129+
*
130+
* @param method The method (from the `Method` enum) that the handler should respond to.
131+
* @param block A function that processes the incoming `JSONRPCRequest` and returns a `RequestResult`.
132+
* The returned `RequestResult` will be used as the result of the response.
133+
*/
134+
public fun <T : RequestResult> onMessageReplyResult(method: Method, block: (JSONRPCRequest) -> T) {
135+
onMessageReply(
136+
predicate = {
137+
it.method == method.value
138+
},
139+
block = {
140+
JSONRPCResponse(
141+
id = it.id,
142+
result = block.invoke(it),
143+
)
144+
},
145+
)
146+
}
147+
148+
/**
149+
* Registers a handler that will be called when a request with the specified method is received
150+
* and an error response is to be generated.
151+
*
152+
* This handler is used to respond to requests with a specific method by returning an error response.
153+
* The handler is triggered when a request message with the given `method` is received.
154+
*
155+
* @param method The method (from the `Method` enum) that the handler should respond to with an error.
156+
* @param block A function that processes the incoming `JSONRPCRequest` and returns a `JSONRPCError`
157+
* to be used as the error response.
158+
* The default block returns an internal error with the message "Expected error".
159+
*/
160+
public fun onMessageReplyError(
161+
method: Method,
162+
block: (JSONRPCRequest) -> JSONRPCError = {
163+
JSONRPCError(
164+
code = ErrorCode.Defined.InternalError,
165+
message = "Expected error",
166+
)
167+
},
168+
) {
169+
onMessageReply(
170+
predicate = {
171+
it.method == method.value
172+
},
173+
block = {
174+
JSONRPCResponse(
175+
id = it.id,
176+
error = block.invoke(it),
177+
)
178+
},
179+
)
180+
}
181+
182+
/**
183+
* Waits for a JSON-RPC message that matches the given predicate in the received messages.
184+
*
185+
* @param poolInterval The interval at which the function polls the received messages. Default is 50 milliseconds.
186+
* @param timeout The maximum time to wait for a matching message. Default is 3 seconds.
187+
* @param timeoutMessage The error message to throw when the timeout is reached.
188+
* Default is "No message received matching predicate".
189+
* @param predicate A predicate function that returns true if the message matches the criteria.
190+
* @return The first JSON-RPC message that matches the predicate.
191+
*/
192+
@OptIn(ExperimentalTime::class)
193+
public suspend fun awaitMessage(
194+
poolInterval: Duration = 50.milliseconds,
195+
timeout: Duration = 3.seconds,
196+
timeoutMessage: String = "No message received matching predicate",
197+
predicate: (JSONRPCMessage) -> Boolean,
198+
): JSONRPCMessage {
199+
val clock = Clock.System
200+
val startTime = clock.now()
201+
val finishTime = startTime + timeout
202+
while (clock.now() < finishTime) {
203+
val found = mutex.withLock {
204+
_receivedMessages.firstOrNull { predicate(it) }
205+
}
206+
if (found != null) {
207+
return found
208+
}
209+
delay(poolInterval)
210+
}
211+
error(timeoutMessage)
212+
}
213+
}

0 commit comments

Comments
 (0)