-
Notifications
You must be signed in to change notification settings - Fork 628
/
Copy pathsse.ts
251 lines (218 loc) · 7.96 KB
/
sse.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import { EventSource, type ErrorEvent, type EventSourceInit } from "eventsource";
import { Transport } from "../shared/transport.js";
import { JSONRPCMessage, JSONRPCMessageSchema } from "../types.js";
import { auth, AuthResult, OAuthClientProvider, UnauthorizedError } from "./auth.js";
export class SseError extends Error {
constructor(
public readonly code: number | undefined,
message: string | undefined,
public readonly event: ErrorEvent,
) {
super(`SSE error: ${message}`);
}
}
/**
* Configuration options for the `SSEClientTransport`.
*/
export type SSEClientTransportOptions = {
/**
* An OAuth client provider to use for authentication.
*
* When an `authProvider` is specified and the SSE connection is started:
* 1. The connection is attempted with any existing access token from the `authProvider`.
* 2. If the access token has expired, the `authProvider` is used to refresh the token.
* 3. If token refresh fails or no access token exists, and auth is required, `OAuthClientProvider.redirectToAuthorization` is called, and an `UnauthorizedError` will be thrown from `connect`/`start`.
*
* After the user has finished authorizing via their user agent, and is redirected back to the MCP client application, call `SSEClientTransport.finishAuth` with the authorization code before retrying the connection.
*
* If an `authProvider` is not provided, and auth is required, an `UnauthorizedError` will be thrown.
*
* `UnauthorizedError` might also be thrown when sending any message over the SSE transport, indicating that the session has expired, and needs to be re-authed and reconnected.
*/
authProvider?: OAuthClientProvider;
/**
* Customizes the initial SSE request to the server (the request that begins the stream).
*
* NOTE: Setting this property will prevent an `Authorization` header from
* being automatically attached to the SSE request, if an `authProvider` is
* also given. This can be worked around by setting the `Authorization` header
* manually.
*/
eventSourceInit?: EventSourceInit;
/**
* Customizes recurring POST requests to the server.
*/
requestInit?: RequestInit;
};
/**
* Client transport for SSE: this will connect to a server using Server-Sent Events for receiving
* messages and make separate POST requests for sending messages.
*/
export class SSEClientTransport implements Transport {
private _eventSource?: EventSource;
private _endpoint?: URL;
private _abortController?: AbortController;
private _url: URL;
private _eventSourceInit?: EventSourceInit;
private _requestInit?: RequestInit;
private _authProvider?: OAuthClientProvider;
onclose?: () => void;
onerror?: (error: Error) => void;
onmessage?: (message: JSONRPCMessage) => void;
constructor(
url: URL,
opts?: SSEClientTransportOptions,
) {
this._url = url;
this._eventSourceInit = opts?.eventSourceInit;
this._requestInit = opts?.requestInit;
this._authProvider = opts?.authProvider;
}
private async _authThenStart(): Promise<void> {
if (!this._authProvider) {
throw new UnauthorizedError("No auth provider");
}
let result: AuthResult;
try {
result = await auth(this._authProvider, { serverUrl: this._url });
} catch (error) {
this.onerror?.(error as Error);
throw error;
}
if (result !== "AUTHORIZED") {
throw new UnauthorizedError();
}
return await this._startOrAuth();
}
private async _commonHeaders(): Promise<HeadersInit> {
const headers: HeadersInit = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
headers["Authorization"] = `Bearer ${tokens.access_token}`;
}
}
return headers;
}
private _startOrAuth(): Promise<void> {
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
this._eventSourceInit ?? {
fetch: async (url, init) => {
const commonHeaders = await this._commonHeaders();
const allHeaders = { ...commonHeaders, ...this._requestInit?.headers};
return fetch(url, {
...init,
headers: {
...allHeaders,
Accept: "text/event-stream"
}
})
}
},
);
this._abortController = new AbortController();
this._eventSource.onerror = (event) => {
if (event.code === 401 && this._authProvider) {
this._authThenStart().then(resolve, reject);
return;
}
const error = new SseError(event.code, event.message, event);
reject(error);
this.onerror?.(error);
};
this._eventSource.onopen = () => {
// The connection is open, but we need to wait for the endpoint to be received.
};
this._eventSource.addEventListener("endpoint", (event: Event) => {
const messageEvent = event as MessageEvent;
try {
this._endpoint = new URL(messageEvent.data, this._url);
if (this._endpoint.origin !== this._url.origin) {
throw new Error(
`Endpoint origin does not match connection origin: ${this._endpoint.origin}`,
);
}
} catch (error) {
reject(error);
this.onerror?.(error as Error);
void this.close();
return;
}
resolve();
});
this._eventSource.onmessage = (event: Event) => {
const messageEvent = event as MessageEvent;
let message: JSONRPCMessage;
try {
message = JSONRPCMessageSchema.parse(JSON.parse(messageEvent.data));
} catch (error) {
this.onerror?.(error as Error);
return;
}
this.onmessage?.(message);
};
});
}
async start() {
if (this._eventSource) {
throw new Error(
"SSEClientTransport already started! If using Client class, note that connect() calls start() automatically.",
);
}
return await this._startOrAuth();
}
/**
* Call this method after the user has finished authorizing via their user agent and is redirected back to the MCP client application. This will exchange the authorization code for an access token, enabling the next connection attempt to successfully auth.
*/
async finishAuth(authorizationCode: string): Promise<void> {
if (!this._authProvider) {
throw new UnauthorizedError("No auth provider");
}
const result = await auth(this._authProvider, { serverUrl: this._url, authorizationCode });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError("Failed to authorize");
}
}
async close(): Promise<void> {
this._abortController?.abort();
this._eventSource?.close();
this.onclose?.();
}
async send(message: JSONRPCMessage): Promise<void> {
if (!this._endpoint) {
throw new Error("Not connected");
}
try {
const commonHeaders = await this._commonHeaders();
const headers = new Headers({ ...commonHeaders, ...this._requestInit?.headers });
headers.set("content-type", "application/json");
const init = {
...this._requestInit,
method: "POST",
headers,
body: JSON.stringify(message),
signal: this._abortController?.signal,
};
const response = await fetch(this._endpoint, init);
if (!response.ok) {
if (response.status === 401 && this._authProvider) {
const result = await auth(this._authProvider, { serverUrl: this._url });
if (result !== "AUTHORIZED") {
throw new UnauthorizedError();
}
// Purposely _not_ awaited, so we don't call onerror twice
return this.send(message);
}
const text = await response.text().catch(() => null);
throw new Error(
`Error POSTing to endpoint (HTTP ${response.status}): ${text}`,
);
}
} catch (error) {
this.onerror?.(error as Error);
throw error;
}
}
}