Skip to content

Commit a403cc2

Browse files
committed
Implement legacy SSE client transport
1 parent 305cf7e commit a403cc2

File tree

4 files changed

+963
-92
lines changed

4 files changed

+963
-92
lines changed
Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,354 @@
1+
import Foundation
2+
import Logging
3+
4+
#if !os(Linux)
5+
import EventSource
6+
7+
/// An implementation of the MCP HTTP with SSE transport protocol.
8+
///
9+
/// This transport implements the [HTTP with SSE transport](https://modelcontextprotocol.io/specification/2024-11-05/basic/transports#http-with-sse)
10+
/// specification from the Model Context Protocol.
11+
///
12+
/// It supports:
13+
/// - Sending JSON-RPC messages via HTTP POST requests
14+
/// - Receiving responses via SSE events
15+
/// - Automatic handling of endpoint discovery
16+
///
17+
/// ## Example Usage
18+
///
19+
/// ```swift
20+
/// import MCP
21+
///
22+
/// // Create an SSE transport with the server endpoint
23+
/// let transport = SSETransport(
24+
/// endpoint: URL(string: "http://localhost:8080")!,
25+
/// token: "your-auth-token" // Optional
26+
/// )
27+
///
28+
/// // Initialize the client with the transport
29+
/// let client = Client(name: "MyApp", version: "1.0.0")
30+
/// try await client.connect(transport: transport)
31+
///
32+
/// // The transport will automatically handle SSE events
33+
/// // and deliver them through the client's notification handlers
34+
/// ```
35+
public actor SSEClientTransport: Transport {
36+
/// The server endpoint URL to connect to
37+
public let endpoint: URL
38+
39+
/// Logger instance for transport-related events
40+
public nonisolated let logger: Logger
41+
42+
/// Whether the transport is currently connected
43+
public private(set) var isConnected: Bool = false
44+
45+
/// The URL to send messages to, provided by the server in the 'endpoint' event
46+
private var messageURL: URL?
47+
48+
/// Authentication token for requests (if required)
49+
private let token: String?
50+
51+
/// The URLSession for network requests
52+
private let session: URLSession
53+
54+
/// Task for SSE streaming connection
55+
private var streamingTask: Task<Void, Never>?
56+
57+
/// Used for async/await in connect()
58+
private var connectionContinuation: CheckedContinuation<Void, Swift.Error>?
59+
60+
/// Stream for receiving messages
61+
private let messageStream: AsyncThrowingStream<Data, Swift.Error>
62+
private let messageContinuation: AsyncThrowingStream<Data, Swift.Error>.Continuation
63+
64+
/// Creates a new SSE transport with the specified endpoint
65+
///
66+
/// - Parameters:
67+
/// - endpoint: The server URL to connect to
68+
/// - token: Optional authentication token
69+
/// - configuration: URLSession configuration to use (default: .default)
70+
/// - logger: Optional logger instance for transport events
71+
public init(
72+
endpoint: URL,
73+
token: String? = nil,
74+
configuration: URLSessionConfiguration = .default,
75+
logger: Logger? = nil
76+
) {
77+
self.endpoint = endpoint
78+
self.token = token
79+
self.session = URLSession(configuration: configuration)
80+
81+
// Create message stream
82+
var continuation: AsyncThrowingStream<Data, Swift.Error>.Continuation!
83+
self.messageStream = AsyncThrowingStream<Data, Swift.Error> { continuation = $0 }
84+
self.messageContinuation = continuation
85+
86+
self.logger =
87+
logger
88+
?? Logger(
89+
label: "mcp.transport.sse",
90+
factory: { _ in SwiftLogNoOpLogHandler() }
91+
)
92+
}
93+
94+
/// Establishes connection with the transport
95+
///
96+
/// This creates an SSE connection to the server and waits for the 'endpoint'
97+
/// event to receive the URL for sending messages.
98+
public func connect() async throws {
99+
guard !isConnected else { return }
100+
101+
logger.info("Connecting to SSE endpoint: \(endpoint)")
102+
103+
// Start listening for server events
104+
streamingTask = Task { await listenForServerEvents() }
105+
106+
// Wait for the endpoint URL to be received with a timeout
107+
return try await withThrowingTaskGroup(of: Void.self) { group in
108+
// Add the connection task
109+
group.addTask {
110+
try await self.waitForConnection()
111+
}
112+
113+
// Add the timeout task
114+
group.addTask {
115+
try await Task.sleep(for: .seconds(5)) // 5 second timeout
116+
throw MCPError.internalError("Connection timeout waiting for endpoint URL")
117+
}
118+
119+
// Take the first result and cancel the other task
120+
if let result = try await group.next() {
121+
group.cancelAll()
122+
return result
123+
}
124+
throw MCPError.internalError("Connection failed")
125+
}
126+
}
127+
128+
/// Waits for the connection to be established
129+
private func waitForConnection() async throws {
130+
try await withCheckedThrowingContinuation { continuation in
131+
self.connectionContinuation = continuation
132+
}
133+
}
134+
135+
/// Disconnects from the transport
136+
///
137+
/// This terminates the SSE connection and releases resources.
138+
public func disconnect() async {
139+
guard isConnected else { return }
140+
141+
logger.info("Disconnecting from SSE endpoint")
142+
143+
// Cancel the streaming task
144+
streamingTask?.cancel()
145+
streamingTask = nil
146+
147+
// Clean up
148+
isConnected = false
149+
messageContinuation.finish()
150+
151+
// If there's a pending connection continuation, fail it
152+
if let continuation = connectionContinuation {
153+
continuation.resume(throwing: MCPError.internalError("Connection closed"))
154+
connectionContinuation = nil
155+
}
156+
157+
// Cancel any in-progress requests
158+
session.invalidateAndCancel()
159+
}
160+
161+
/// Sends a JSON-RPC message to the server
162+
///
163+
/// This sends data to the message endpoint provided by the server
164+
/// during connection setup.
165+
///
166+
/// - Parameter data: The JSON-RPC message to send
167+
/// - Throws: MCPError if there's no message URL or if the request fails
168+
public func send(_ data: Data) async throws {
169+
guard isConnected else {
170+
throw MCPError.internalError("Transport not connected")
171+
}
172+
173+
guard let messageURL = messageURL else {
174+
throw MCPError.internalError("No message URL provided by server")
175+
}
176+
177+
logger.debug("Sending message", metadata: ["size": "\(data.count)"])
178+
179+
var request = URLRequest(url: messageURL)
180+
request.httpMethod = "POST"
181+
request.httpBody = data
182+
request.setValue("application/json", forHTTPHeaderField: "Content-Type")
183+
184+
// Add authorization if token is provided
185+
if let token = token {
186+
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
187+
}
188+
189+
let (_, response) = try await session.data(for: request)
190+
191+
guard let httpResponse = response as? HTTPURLResponse else {
192+
throw MCPError.internalError("Invalid HTTP response")
193+
}
194+
195+
guard (200..<300).contains(httpResponse.statusCode) else {
196+
throw MCPError.internalError("HTTP error: \(httpResponse.statusCode)")
197+
}
198+
}
199+
200+
/// Receives data in an async sequence
201+
///
202+
/// This returns an AsyncThrowingStream that emits Data objects representing
203+
/// each JSON-RPC message received from the server via SSE.
204+
///
205+
/// - Returns: An AsyncThrowingStream of Data objects
206+
public func receive() -> AsyncThrowingStream<Data, Swift.Error> {
207+
return messageStream
208+
}
209+
210+
// MARK: - Private Methods
211+
212+
/// Main task that listens for server-sent events
213+
private func listenForServerEvents() async {
214+
var retryCount = 0
215+
let maxRetries = 3
216+
217+
// Retry loop for dropped connections
218+
while !Task.isCancelled {
219+
do {
220+
try await connectToSSEStream()
221+
// Reset retry count on successful connection
222+
retryCount = 0
223+
} catch {
224+
if !Task.isCancelled {
225+
logger.error("SSE connection error: \(error.localizedDescription)")
226+
retryCount += 1
227+
228+
if retryCount >= maxRetries {
229+
logger.error("Max retries reached, giving up")
230+
break
231+
}
232+
233+
// Wait before retrying with exponential backoff
234+
try? await Task.sleep(for: .seconds(pow(2.0, Double(retryCount))))
235+
}
236+
}
237+
}
238+
}
239+
240+
/// Establishes the SSE stream connection
241+
private func connectToSSEStream() async throws {
242+
logger.debug("Starting SSE connection")
243+
244+
var request = URLRequest(url: endpoint)
245+
request.httpMethod = "GET"
246+
request.setValue("text/event-stream", forHTTPHeaderField: "Accept")
247+
request.setValue("no-cache", forHTTPHeaderField: "Cache-Control")
248+
249+
// Add authorization if token is provided
250+
if let token = token {
251+
request.setValue("Bearer \(token)", forHTTPHeaderField: "Authorization")
252+
}
253+
254+
// On supported platforms, we use the EventSource implementation
255+
let (byteStream, response) = try await session.bytes(for: request)
256+
257+
guard let httpResponse = response as? HTTPURLResponse else {
258+
throw MCPError.internalError("Invalid HTTP response")
259+
}
260+
261+
guard httpResponse.statusCode == 200 else {
262+
throw MCPError.internalError("HTTP error: \(httpResponse.statusCode)")
263+
}
264+
265+
guard let contentType = httpResponse.value(forHTTPHeaderField: "Content-Type"),
266+
contentType.contains("text/event-stream")
267+
else {
268+
throw MCPError.internalError("Invalid content type for SSE stream")
269+
}
270+
271+
logger.debug("SSE connection established")
272+
273+
// Process the SSE stream
274+
for try await event in byteStream.events {
275+
// Check if task has been cancelled
276+
if Task.isCancelled { break }
277+
278+
processServerSentEvent(event)
279+
}
280+
}
281+
282+
/// Processes a server-sent event
283+
private func processServerSentEvent(_ event: SSE) {
284+
// Process event based on type
285+
switch event.event {
286+
case "endpoint":
287+
if !event.data.isEmpty {
288+
processEndpointURL(event.data)
289+
} else {
290+
logger.error("Received empty endpoint data")
291+
}
292+
293+
case "message", nil: // Default event type is "message" per SSE spec
294+
if !event.data.isEmpty,
295+
let messageData = event.data.data(using: .utf8)
296+
{
297+
messageContinuation.yield(messageData)
298+
} else {
299+
logger.warning("Received empty message data")
300+
}
301+
302+
default:
303+
logger.warning("Received unknown event type: \(event.event ?? "nil")")
304+
}
305+
}
306+
307+
/// Processes an endpoint URL string received from the server
308+
private func processEndpointURL(_ endpoint: String) {
309+
logger.debug("Received endpoint path: \(endpoint)")
310+
311+
// Construct the full URL for sending messages
312+
if let url = constructMessageURL(from: endpoint) {
313+
messageURL = url
314+
logger.info("Message URL set to: \(url)")
315+
316+
// Mark as connected
317+
isConnected = true
318+
319+
// Resume the connection continuation if it exists
320+
if let continuation = connectionContinuation {
321+
continuation.resume()
322+
connectionContinuation = nil
323+
}
324+
} else {
325+
logger.error("Failed to construct message URL from path: \(endpoint)")
326+
327+
// Fail the connection if we have a continuation
328+
if let continuation = connectionContinuation {
329+
continuation.resume(throwing: MCPError.internalError("Invalid endpoint URL"))
330+
connectionContinuation = nil
331+
}
332+
}
333+
}
334+
335+
/// Constructs a message URL from a path or absolute URL
336+
private func constructMessageURL(from path: String) -> URL? {
337+
// Handle absolute URLs
338+
if path.starts(with: "http://") || path.starts(with: "https://") {
339+
return URL(string: path)
340+
}
341+
342+
// Handle relative paths
343+
guard var components = URLComponents(url: endpoint, resolvingAgainstBaseURL: true)
344+
else {
345+
return nil
346+
}
347+
348+
// For relative paths, preserve the scheme, host, and port
349+
let pathToUse = path.starts(with: "/") ? path : "/\(path)"
350+
components.path = pathToUse
351+
return components.url
352+
}
353+
}
354+
#endif

0 commit comments

Comments
 (0)