Skip to content

Commit bdaa258

Browse files
authored
Delay SSE GET connection until after session ID is established (#97)
* Delay SSE GET connection until after session ID is established - Add a signaling mechanism to ensure the SSE streaming task waits for the initial session ID to be set before attempting the GET connection. - Use a TaskGroup to race between session ID signal and a 10-second timeout, logging the outcome. - Trigger the signal when a session ID is received for the first time from any response. - Add detailed logging for all code paths to aid debugging and clarify connection timing. - Clean up signal resources on disconnect to prevent leaks. This addresses the initialization timing issue where the SSE GET could be attempted before the session was established, as discussed in MCP PR #91. * Update tests for SSE timing * Use configurable timeout for acquiring session ID * Remove Actor protocol conformance for HTTPClientTransport * Configure shorter session ID timeout for tests * Add documentation comments to HTTPClientTransport * Rename sessionIDWaitTimeout to sseInitializationTimeout
1 parent 3ff1085 commit bdaa258

File tree

2 files changed

+237
-4
lines changed

2 files changed

+237
-4
lines changed

Sources/MCP/Base/Transports/HTTPClientTransport.swift

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,66 @@ import Logging
99
import FoundationNetworking
1010
#endif
1111

12-
public actor HTTPClientTransport: Actor, Transport {
12+
/// An implementation of the MCP Streamable HTTP transport protocol for clients.
13+
///
14+
/// This transport implements the [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http)
15+
/// specification from the Model Context Protocol.
16+
///
17+
/// It supports:
18+
/// - Sending JSON-RPC messages via HTTP POST requests
19+
/// - Receiving responses via both direct JSON responses and SSE streams
20+
/// - Session management using the `Mcp-Session-Id` header
21+
/// - Automatic reconnection for dropped SSE streams
22+
/// - Platform-specific optimizations for different operating systems
23+
///
24+
/// The transport supports two modes:
25+
/// - Regular HTTP (`streaming=false`): Simple request/response pattern
26+
/// - Streaming HTTP with SSE (`streaming=true`): Enables server-to-client push messages
27+
///
28+
/// - Important: Server-Sent Events (SSE) functionality is not supported on Linux platforms.
29+
public actor HTTPClientTransport: Transport {
30+
/// The server endpoint URL to connect to
1331
public let endpoint: URL
1432
private let session: URLSession
33+
34+
/// The session ID assigned by the server, used for maintaining state across requests
1535
public private(set) var sessionID: String?
1636
private let streaming: Bool
1737
private var streamingTask: Task<Void, Never>?
38+
39+
/// Logger instance for transport-related events
1840
public nonisolated let logger: Logger
1941

42+
/// Maximum time to wait for a session ID before proceeding with SSE connection
43+
public let sseInitializationTimeout: TimeInterval
44+
2045
private var isConnected = false
2146
private let messageStream: AsyncThrowingStream<Data, Swift.Error>
2247
private let messageContinuation: AsyncThrowingStream<Data, Swift.Error>.Continuation
2348

49+
private var initialSessionIDSignalTask: Task<Void, Never>?
50+
private var initialSessionIDContinuation: CheckedContinuation<Void, Never>?
51+
52+
/// Creates a new HTTP transport client with the specified endpoint
53+
///
54+
/// - Parameters:
55+
/// - endpoint: The server URL to connect to
56+
/// - configuration: URLSession configuration to use for HTTP requests
57+
/// - streaming: Whether to enable SSE streaming mode (default: true)
58+
/// - sseInitializationTimeout: Maximum time to wait for session ID before proceeding with SSE (default: 10 seconds)
59+
/// - logger: Optional logger instance for transport events
2460
public init(
2561
endpoint: URL,
2662
configuration: URLSessionConfiguration = .default,
2763
streaming: Bool = true,
64+
sseInitializationTimeout: TimeInterval = 10,
2865
logger: Logger? = nil
2966
) {
3067
self.init(
3168
endpoint: endpoint,
3269
session: URLSession(configuration: configuration),
3370
streaming: streaming,
71+
sseInitializationTimeout: sseInitializationTimeout,
3472
logger: logger
3573
)
3674
}
@@ -39,11 +77,13 @@ public actor HTTPClientTransport: Actor, Transport {
3977
endpoint: URL,
4078
session: URLSession,
4179
streaming: Bool = false,
80+
sseInitializationTimeout: TimeInterval = 10,
4281
logger: Logger? = nil
4382
) {
4483
self.endpoint = endpoint
4584
self.session = session
4685
self.streaming = streaming
86+
self.sseInitializationTimeout = sseInitializationTimeout
4787

4888
// Create message stream
4989
var continuation: AsyncThrowingStream<Data, Swift.Error>.Continuation!
@@ -58,11 +98,37 @@ public actor HTTPClientTransport: Actor, Transport {
5898
)
5999
}
60100

101+
// Setup the initial session ID signal
102+
private func setupInitialSessionIDSignal() {
103+
self.initialSessionIDSignalTask = Task {
104+
await withCheckedContinuation { continuation in
105+
self.initialSessionIDContinuation = continuation
106+
// This task will suspend here until continuation.resume() is called
107+
}
108+
}
109+
}
110+
111+
// Trigger the initial session ID signal when a session ID is established
112+
private func triggerInitialSessionIDSignal() {
113+
if let continuation = self.initialSessionIDContinuation {
114+
continuation.resume()
115+
self.initialSessionIDContinuation = nil // Consume the continuation
116+
logger.debug("Initial session ID signal triggered for SSE task.")
117+
}
118+
}
119+
61120
/// Establishes connection with the transport
121+
///
122+
/// This prepares the transport for communication and sets up SSE streaming
123+
/// if streaming mode is enabled. The actual HTTP connection happens with the
124+
/// first message sent.
62125
public func connect() async throws {
63126
guard !isConnected else { return }
64127
isConnected = true
65128

129+
// Setup initial session ID signal
130+
setupInitialSessionIDSignal()
131+
66132
if streaming {
67133
// Start listening to server events
68134
streamingTask = Task { await startListeningForServerEvents() }
@@ -72,6 +138,9 @@ public actor HTTPClientTransport: Actor, Transport {
72138
}
73139

74140
/// Disconnects from the transport
141+
///
142+
/// This terminates any active connections, cancels the streaming task,
143+
/// and releases any resources being used by the transport.
75144
public func disconnect() async {
76145
guard isConnected else { return }
77146
isConnected = false
@@ -86,10 +155,28 @@ public actor HTTPClientTransport: Actor, Transport {
86155
// Clean up message stream
87156
messageContinuation.finish()
88157

158+
// Cancel the initial session ID signal task if active
159+
initialSessionIDSignalTask?.cancel()
160+
initialSessionIDSignalTask = nil
161+
// Resume the continuation if it's still pending to avoid leaks
162+
initialSessionIDContinuation?.resume()
163+
initialSessionIDContinuation = nil
164+
89165
logger.info("HTTP clienttransport disconnected")
90166
}
91167

92168
/// Sends data through an HTTP POST request
169+
///
170+
/// This sends a JSON-RPC message to the server via HTTP POST and processes
171+
/// the response according to the MCP Streamable HTTP specification. It handles:
172+
///
173+
/// - Adding appropriate Accept headers for both JSON and SSE
174+
/// - Including the session ID in requests if one has been established
175+
/// - Processing different response types (JSON vs SSE)
176+
/// - Handling HTTP error codes according to the specification
177+
///
178+
/// - Parameter data: The JSON-RPC message to send
179+
/// - Throws: MCPError for transport failures or server errors
93180
public func send(_ data: Data) async throws {
94181
guard isConnected else {
95182
throw MCPError.internalError("Transport not connected")
@@ -129,7 +216,12 @@ public actor HTTPClientTransport: Actor, Transport {
129216

130217
// Extract session ID if present
131218
if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") {
219+
let wasSessionIDNil = (self.sessionID == nil)
132220
self.sessionID = newSessionID
221+
if wasSessionIDNil {
222+
// Trigger signal on first session ID
223+
triggerInitialSessionIDSignal()
224+
}
133225
logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"])
134226
}
135227

@@ -161,7 +253,12 @@ public actor HTTPClientTransport: Actor, Transport {
161253

162254
// Extract session ID if present
163255
if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") {
256+
let wasSessionIDNil = (self.sessionID == nil)
164257
self.sessionID = newSessionID
258+
if wasSessionIDNil {
259+
// Trigger signal on first session ID
260+
triggerInitialSessionIDSignal()
261+
}
165262
logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"])
166263
}
167264

@@ -238,13 +335,29 @@ public actor HTTPClientTransport: Actor, Transport {
238335
}
239336

240337
/// Receives data in an async sequence
338+
///
339+
/// This returns an AsyncThrowingStream that emits Data objects representing
340+
/// each JSON-RPC message received from the server. This includes:
341+
///
342+
/// - Direct responses to client requests
343+
/// - Server-initiated messages delivered via SSE streams
344+
///
345+
/// - Returns: An AsyncThrowingStream of Data objects
241346
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
242347
return messageStream
243348
}
244349

245350
// MARK: - SSE
246351

247352
/// Starts listening for server events using SSE
353+
///
354+
/// This establishes a long-lived HTTP connection using Server-Sent Events (SSE)
355+
/// to enable server-to-client push messaging. It handles:
356+
///
357+
/// - Waiting for session ID if needed
358+
/// - Opening the SSE connection
359+
/// - Automatic reconnection on connection drops
360+
/// - Processing received events
248361
private func startListeningForServerEvents() async {
249362
#if os(Linux)
250363
// SSE is not fully supported on Linux
@@ -257,6 +370,63 @@ public actor HTTPClientTransport: Actor, Transport {
257370
// This is the original code for platforms that support SSE
258371
guard isConnected else { return }
259372

373+
// Wait for the initial session ID signal, but only if sessionID isn't already set
374+
if self.sessionID == nil, let signalTask = self.initialSessionIDSignalTask {
375+
logger.debug("SSE streaming task waiting for initial sessionID signal...")
376+
377+
// Race the signalTask against a timeout
378+
let timeoutTask = Task {
379+
try? await Task.sleep(for: .seconds(self.sseInitializationTimeout))
380+
return false
381+
}
382+
383+
let signalCompletionTask = Task {
384+
await signalTask.value
385+
return true // Indicates signal received
386+
}
387+
388+
// Use TaskGroup to race the two tasks
389+
var signalReceived = false
390+
do {
391+
signalReceived = try await withThrowingTaskGroup(of: Bool.self) { group in
392+
group.addTask {
393+
await signalCompletionTask.value
394+
}
395+
group.addTask {
396+
await timeoutTask.value
397+
}
398+
399+
// Take the first result and cancel the other task
400+
if let firstResult = try await group.next() {
401+
group.cancelAll()
402+
return firstResult
403+
}
404+
return false
405+
}
406+
} catch {
407+
logger.error("Error while waiting for session ID signal: \(error)")
408+
}
409+
410+
// Clean up tasks
411+
timeoutTask.cancel()
412+
413+
if signalReceived {
414+
logger.debug("SSE streaming task proceeding after initial sessionID signal.")
415+
} else {
416+
logger.warning(
417+
"Timeout waiting for initial sessionID signal. SSE stream will proceed (sessionID might be nil)."
418+
)
419+
}
420+
} else if self.sessionID != nil {
421+
logger.debug(
422+
"Initial sessionID already available. Proceeding with SSE streaming task immediately."
423+
)
424+
} else {
425+
logger.info(
426+
"Proceeding with SSE connection attempt; sessionID is nil. This might be expected for stateless servers or if initialize hasn't provided one yet."
427+
)
428+
}
429+
260430
// Retry loop for connection drops
261431
while isConnected && !Task.isCancelled {
262432
do {
@@ -274,6 +444,11 @@ public actor HTTPClientTransport: Actor, Transport {
274444

275445
#if !os(Linux)
276446
/// Establishes an SSE connection to the server
447+
///
448+
/// This initiates a GET request to the server endpoint with appropriate
449+
/// headers to establish an SSE stream according to the MCP specification.
450+
///
451+
/// - Throws: MCPError for connection failures or server errors
277452
private func connectToEventStream() async throws {
278453
guard isConnected else { return }
279454

@@ -309,13 +484,23 @@ public actor HTTPClientTransport: Actor, Transport {
309484

310485
// Extract session ID if present
311486
if let newSessionID = httpResponse.value(forHTTPHeaderField: "Mcp-Session-Id") {
487+
let wasSessionIDNil = (self.sessionID == nil)
312488
self.sessionID = newSessionID
489+
if wasSessionIDNil {
490+
// Trigger signal on first session ID, though this is unlikely to happen here
491+
// as GET usually follows a POST that would have already set the session ID
492+
triggerInitialSessionIDSignal()
493+
}
313494
logger.debug("Session ID received", metadata: ["sessionID": "\(newSessionID)"])
314495
}
315496

316497
try await self.processSSE(stream)
317498
}
318499

500+
/// Processes an SSE byte stream, extracting events and delivering them
501+
///
502+
/// - Parameter stream: The URLSession.AsyncBytes stream to process
503+
/// - Throws: Error for stream processing failures
319504
private func processSSE(_ stream: URLSession.AsyncBytes) async throws {
320505
do {
321506
for try await event in stream.events {

0 commit comments

Comments
 (0)