Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 92 additions & 14 deletions core/context/mcp/MCPConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ function is401Error(error: unknown) {
);
}

function createMcpClient() {
return new Client(
{
name: "continue-client",
version: "1.0.0",
},
{
capabilities: {},
},
);
}

export type MCPExtras = {
ide: IDE;
};
Expand Down Expand Up @@ -85,19 +97,70 @@ class MCPConnection {
// Don't construct transport in constructor to avoid blocking
this.transport = {} as Transport; // Will be set in connectClient

this.client = new Client(
{
name: "continue-client",
version: "1.0.0",
},
{
capabilities: {},
},
);
this.client = createMcpClient();

this.abortController = new AbortController();
}

private async resetClientAndTransport() {
try {
await this.client.close();
} catch {
// Ignore close errors while replacing stale clients/transports.
}

try {
await this.transport.close?.();
} catch {
// Ignore close errors while replacing stale clients/transports.
}

this.client = createMcpClient();
this.transport = {} as Transport;
}

private shouldReconnectAfterError(error: unknown) {
if (this.options.type !== "sse" || this.status === "disabled") {
return false;
}

const message = (
error instanceof Error ? error.message : String(error)
).toLowerCase();

const sessionError =
message.includes("session") &&
(message.includes("invalid") ||
message.includes("unknown") ||
message.includes("expired") ||
message.includes("not found") ||
message.includes("valid") ||
message.includes("missing"));

return (
sessionError ||
message.includes("connection closed") ||
message.includes("transport closed")
);
}

private async withSseReconnectRetry<T>(operation: () => Promise<T>) {
try {
return await operation();
} catch (error) {
if (!this.shouldReconnectAfterError(error)) {
throw error;
}

await this.connectClient(true, new AbortController().signal);
if (this.status !== "connected") {
throw error;
}

return await operation();
}
}

async disconnect(disable = false) {
this.abortController.abort();
await this.client.close();
Expand Down Expand Up @@ -147,6 +210,7 @@ class MCPConnection {

this.abortController.abort();
this.abortController = new AbortController();
await this.resetClientAndTransport();

// currently support oauth for sse transports only
if (this.options.type === "sse") {
Expand Down Expand Up @@ -613,11 +677,25 @@ Org-level secrets can only be used for MCP by Background Agents (https://docs.co
}

async getResource(uri: string) {
return await this.client.readResource(
{ uri },
{
timeout: this.options.timeout,
},
return await this.withSseReconnectRetry(() =>
this.client.readResource(
{ uri },
{
timeout: this.options.timeout,
},
),
);
}

async getPrompt(...args: Parameters<Client["getPrompt"]>) {
return await this.withSseReconnectRetry(() =>
this.client.getPrompt(...args),
);
}

async callTool(...args: Parameters<Client["callTool"]>) {
return await this.withSseReconnectRetry(() =>
this.client.callTool(...args),
);
}
}
Expand Down
43 changes: 43 additions & 0 deletions core/context/mcp/MCPConnection.vitest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,49 @@ describe("MCPConnection", () => {
expect(mockConnect).toHaveBeenCalled();
});

it("should reconnect and retry SSE tool calls after stale session errors", async () => {
const conn = new MCPConnection({
name: "test-mcp",
id: "test-id",
type: "sse",
url: "http://test.com/events",
});
conn.status = "connected";

const mockCallTool = vi
.spyOn(Client.prototype, "callTool")
.mockRejectedValueOnce(new Error("Invalid session ID"))
.mockResolvedValueOnce({ content: [], isError: false } as any);
const mockReconnect = vi
.spyOn(conn, "connectClient")
.mockImplementation(async () => {
conn.status = "connected";
});

const result = await conn.callTool({ name: "test-tool" } as any);

expect(result).toEqual({ content: [], isError: false });
expect(mockReconnect).toHaveBeenCalledWith(true, expect.any(AbortSignal));
expect(mockCallTool).toHaveBeenCalledTimes(2);
});

it("should not retry non-SSE tool calls after stale session errors", async () => {
const conn = new MCPConnection(options);
conn.status = "connected";

const mockCallTool = vi
.spyOn(Client.prototype, "callTool")
.mockRejectedValue(new Error("Invalid session ID"));
const mockReconnect = vi.spyOn(conn, "connectClient");

await expect(conn.callTool({ name: "test-tool" } as any)).rejects.toThrow(
"Invalid session ID",
);

expect(mockReconnect).not.toHaveBeenCalled();
expect(mockCallTool).toHaveBeenCalledTimes(1);
});

it.skip("should include stderr output in error message when stdio command fails", async () => {
// Clear any existing mocks to ensure we get real behavior
vi.restoreAllMocks();
Expand Down
2 changes: 1 addition & 1 deletion core/context/mcp/MCPManagerSingleton.ts
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ export class MCPManagerSingleton {
`Error getting prompt: MCP Connection ${serverName} not found`,
);
}
return await connection.client.getPrompt({
return await connection.getPrompt({
name: promptName,
arguments: args,
});
Expand Down
2 changes: 1 addition & 1 deletion core/tools/callTool.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async function callToolFromUri(
args,
extras.tool?.function?.parameters,
);
const response = await client.client.callTool(
const response = await client.callTool(
{
name: toolName,
arguments: coercedArgs,
Expand Down
Loading