diff --git a/src/client/sse.test.ts b/src/client/sse.test.ts index 77b28508..5976633a 100644 --- a/src/client/sse.test.ts +++ b/src/client/sse.test.ts @@ -329,6 +329,29 @@ describe("SSEClientTransport", () => { expect(mockAuthProvider.tokens).toHaveBeenCalled(); }); + it("attaches custom header from provider on initial SSE connection", async () => { + mockAuthProvider.tokens.mockResolvedValue({ + access_token: "test-token", + token_type: "Bearer" + }); + const customHeaders = { + "X-Custom-Header": "custom-value", + }; + + transport = new SSEClientTransport(baseUrl, { + authProvider: mockAuthProvider, + requestInit: { + headers: customHeaders, + }, + }); + + await transport.start(); + + expect(lastServerRequest.headers.authorization).toBe("Bearer test-token"); + expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value"); + expect(mockAuthProvider.tokens).toHaveBeenCalled(); + }); + it("attaches auth header from provider on POST requests", async () => { mockAuthProvider.tokens.mockResolvedValue({ access_token: "test-token", diff --git a/src/client/sse.ts b/src/client/sse.ts index 5e9f0cf0..7e2c2d81 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -113,13 +113,17 @@ export class SSEClientTransport implements Transport { this._eventSource = new EventSource( this._url.href, this._eventSourceInit ?? { - fetch: (url, init) => this._commonHeaders().then((headers) => fetch(url, { - ...init, - headers: { - ...headers, - Accept: "text/event-stream" - } - })), + fetch: async (url, init) => { + const commonHeaders = await this._commonHeaders(); + const allHeaders = { ...commonHeaders, ...this._requestInit?.headers}; + return fetch(url, { + ...init, + headers: { + ...allHeaders, + Accept: "text/event-stream" + } + }) + } }, ); this._abortController = new AbortController();