diff --git a/src/services/mcp/McpHub.ts b/src/services/mcp/McpHub.ts index f5ddf3e57b..8486c6c697 100644 --- a/src/services/mcp/McpHub.ts +++ b/src/services/mcp/McpHub.ts @@ -743,7 +743,9 @@ export class McpHub { } } else if (configInjected.type === "streamable-http") { // Streamable HTTP connection - transport = new StreamableHTTPClientTransport(new URL(configInjected.url), { + // Normalize URL by removing trailing slashes to avoid 405 errors + const normalizedUrl = configInjected.url.replace(/\/+$/, "") + transport = new StreamableHTTPClientTransport(new URL(normalizedUrl), { requestInit: { headers: configInjected.headers, }, @@ -769,6 +771,8 @@ export class McpHub { } } else if (configInjected.type === "sse") { // SSE connection + // Normalize URL by removing trailing slashes for consistency + const normalizedUrl = configInjected.url.replace(/\/+$/, "") const sseOptions = { requestInit: { headers: configInjected.headers, @@ -787,7 +791,7 @@ export class McpHub { }, } global.EventSource = ReconnectingEventSource - transport = new SSEClientTransport(new URL(configInjected.url), { + transport = new SSEClientTransport(new URL(normalizedUrl), { ...sseOptions, eventSourceInit: reconnectingEventSourceOptions, }) @@ -1378,8 +1382,8 @@ export class McpHub { await this.deleteConnection(serverName, serverSource) // Re-add as a disabled connection // Re-read config from file to get updated disabled state - const updatedConfig = await this.readServerConfigFromFile(serverName, serverSource) - await this.connectToServer(serverName, updatedConfig, serverSource) + const updatedConfig = await this.readServerConfigFromFile(serverName, serverSource) + await this.connectToServer(serverName, updatedConfig, serverSource) } else if (!disabled && connection.server.status === "disconnected") { // If enabling a disabled server, connect it // Re-read config from file to get updated disabled state diff --git a/src/services/mcp/__tests__/McpHub.spec.ts b/src/services/mcp/__tests__/McpHub.spec.ts index 1db924ed6c..d04894778e 100644 --- a/src/services/mcp/__tests__/McpHub.spec.ts +++ b/src/services/mcp/__tests__/McpHub.spec.ts @@ -2147,4 +2147,227 @@ describe("McpHub", () => { ) }) }) + + describe("URL normalization", () => { + let StdioClientTransport: ReturnType + let StreamableHTTPClientTransport: ReturnType + let SSEClientTransport: ReturnType + let Client: ReturnType + + beforeEach(async () => { + // Reset mocks + vi.clearAllMocks() + + // Get references to the mocked constructors + const stdioModule = await import("@modelcontextprotocol/sdk/client/stdio.js") + const clientModule = await import("@modelcontextprotocol/sdk/client/index.js") + StdioClientTransport = stdioModule.StdioClientTransport as ReturnType + Client = clientModule.Client as ReturnType + + // Mock StreamableHTTPClientTransport + vi.mock("@modelcontextprotocol/sdk/client/streamableHttp.js", () => ({ + StreamableHTTPClientTransport: vi.fn(), + })) + const streamableModule = await import("@modelcontextprotocol/sdk/client/streamableHttp.js") + StreamableHTTPClientTransport = streamableModule.StreamableHTTPClientTransport as ReturnType + + // Mock SSEClientTransport + vi.mock("@modelcontextprotocol/sdk/client/sse.js", () => ({ + SSEClientTransport: vi.fn(), + })) + const sseModule = await import("@modelcontextprotocol/sdk/client/sse.js") + SSEClientTransport = sseModule.SSEClientTransport as ReturnType + }) + + it("should remove trailing slash from streamable-http URLs", async () => { + // Mock StreamableHTTPClientTransport + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + onerror: null, + onclose: null, + } + + StreamableHTTPClientTransport.mockImplementation((url: URL, options: any) => { + // Verify that the trailing slash is removed + expect(url.toString()).toBe("https://api.githubcopilot.com/mcp") + return mockTransport + }) + + // Mock Client + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Create a new McpHub instance + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Mock the config file read with URL containing trailing slash + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + github: { + type: "streamable-http", + url: "https://api.githubcopilot.com/mcp/", + headers: { + Authorization: "Bearer GITHUB_PAT", + }, + }, + }, + }), + ) + + // Initialize servers (this will trigger connectToServer) + await mcpHub["initializeGlobalMcpServers"]() + + // Verify StreamableHTTPClientTransport was called with normalized URL + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(expect.any(URL), expect.any(Object)) + const urlArg = StreamableHTTPClientTransport.mock.calls[0][0] + expect(urlArg.toString()).toBe("https://api.githubcopilot.com/mcp") + }) + + it("should remove multiple trailing slashes from streamable-http URLs", async () => { + // Mock StreamableHTTPClientTransport + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + onerror: null, + onclose: null, + } + + StreamableHTTPClientTransport.mockImplementation((url: URL, options: any) => { + // Verify that multiple trailing slashes are removed + expect(url.toString()).toBe("https://api.example.com/endpoint") + return mockTransport + }) + + // Mock Client + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Create a new McpHub instance + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Mock the config file read with URL containing multiple trailing slashes + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "test-server": { + type: "streamable-http", + url: "https://api.example.com/endpoint///", + }, + }, + }), + ) + + // Initialize servers + await mcpHub["initializeGlobalMcpServers"]() + + // Verify StreamableHTTPClientTransport was called with normalized URL + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(expect.any(URL), expect.any(Object)) + const urlArg = StreamableHTTPClientTransport.mock.calls[0][0] + expect(urlArg.toString()).toBe("https://api.example.com/endpoint") + }) + + it("should remove trailing slash from SSE URLs", async () => { + // Mock SSEClientTransport + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + onerror: null, + onclose: null, + } + + SSEClientTransport.mockImplementation((url: URL, options: any) => { + // Verify that the trailing slash is removed + expect(url.toString()).toBe("https://sse.example.com/events") + return mockTransport + }) + + // Mock Client + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Create a new McpHub instance + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Mock the config file read with SSE URL containing trailing slash + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "sse-server": { + type: "sse", + url: "https://sse.example.com/events/", + }, + }, + }), + ) + + // Initialize servers + await mcpHub["initializeGlobalMcpServers"]() + + // Verify SSEClientTransport was called with normalized URL + expect(SSEClientTransport).toHaveBeenCalledWith(expect.any(URL), expect.any(Object)) + const urlArg = SSEClientTransport.mock.calls[0][0] + expect(urlArg.toString()).toBe("https://sse.example.com/events") + }) + + it("should handle URLs without trailing slashes correctly", async () => { + // Mock StreamableHTTPClientTransport + const mockTransport = { + start: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + onerror: null, + onclose: null, + } + + StreamableHTTPClientTransport.mockImplementation((url: URL, options: any) => { + // Verify that URL without trailing slash remains unchanged + expect(url.toString()).toBe("https://api.example.com/path") + return mockTransport + }) + + // Mock Client + Client.mockImplementation(() => ({ + connect: vi.fn().mockResolvedValue(undefined), + close: vi.fn().mockResolvedValue(undefined), + getInstructions: vi.fn().mockReturnValue("test instructions"), + request: vi.fn().mockResolvedValue({ tools: [], resources: [], resourceTemplates: [] }), + })) + + // Create a new McpHub instance + const mcpHub = new McpHub(mockProvider as ClineProvider) + + // Mock the config file read with URL without trailing slash + vi.mocked(fs.readFile).mockResolvedValue( + JSON.stringify({ + mcpServers: { + "no-slash-server": { + type: "streamable-http", + url: "https://api.example.com/path", + }, + }, + }), + ) + + // Initialize servers + await mcpHub["initializeGlobalMcpServers"]() + + // Verify URL remains unchanged + expect(StreamableHTTPClientTransport).toHaveBeenCalledWith(expect.any(URL), expect.any(Object)) + const urlArg = StreamableHTTPClientTransport.mock.calls[0][0] + expect(urlArg.toString()).toBe("https://api.example.com/path") + }) + }) })