Make Apply Streaming Cancelable #5694

Open · wants to merge 16 commits into base: main

2 changes: 1 addition & 1 deletion core/commands/slash/http.ts
@@ -1,6 +1,6 @@
import { streamResponse } from "@continuedev/fetch";
import { SlashCommand } from "../../index.js";
import { removeQuotesAndEscapes } from "../../util/index.js";
import { streamResponse } from "../../llm/stream.js";

const HttpSlashCommand: SlashCommand = {
name: "http",
135 changes: 65 additions & 70 deletions core/core.ts
@@ -35,7 +35,6 @@ import { TTS } from "./util/tts";

import {
ContextItemWithId,
DiffLine,
IdeSettings,
ModelDescription,
RangeInFile,
@@ -62,61 +61,7 @@ import { LLMLogger } from "./llm/logger";
import { llmStreamChat } from "./llm/streamChat";
import type { FromCoreProtocol, ToCoreProtocol } from "./protocol";
import type { IMessenger, Message } from "./protocol/messenger";

// This function is used for jetbrains inline edit and apply
Collaborator Author commented:

Removed streamDiffLinesGenerator and moved its logic inline: its point was to account for aborted message IDs, which wasn't being used; it made roughly half of its lines duplicates of the handler; and it is now defunct.

async function* streamDiffLinesGenerator(
configHandler: ConfigHandler,
abortedMessageIds: Set<string>,
msg: Message<ToCoreProtocol["streamDiffLines"][0]>,
): AsyncGenerator<DiffLine> {
const {
highlighted,
prefix,
suffix,
input,
language,
modelTitle,
includeRulesInSystemMessage,
} = msg.data;

const { config } = await configHandler.loadConfig();
if (!config) {
throw new Error("Failed to load config");
}

// Title can be an edit, chat, or apply model
// Fall back to chat
const llm =
config.modelsByRole.edit.find((m) => m.title === modelTitle) ??
config.modelsByRole.apply.find((m) => m.title === modelTitle) ??
config.modelsByRole.chat.find((m) => m.title === modelTitle) ??
config.selectedModelByRole.chat;

if (!llm) {
throw new Error("No model selected");
}

// rules included for edit, NOT apply
const rules = includeRulesInSystemMessage ? config.rules : undefined;

for await (const diffLine of streamDiffLines({
highlighted,
prefix,
suffix,
llm,
rulesToInclude: rules,
input,
language,
onlyOneInsertion: false,
overridePrompt: undefined,
})) {
if (abortedMessageIds.has(msg.messageId)) {
abortedMessageIds.delete(msg.messageId);
break;
}
yield diffLine;
}
}
import { StreamAbortManager } from "./util/abortManager";

export class Core {
configHandler: ConfigHandler;
@@ -132,7 +77,18 @@ export class Core {
this.globalContext.get("indexingPaused") === true,
);

private abortedMessageIds: Set<string> = new Set();
private messageAbortControllers = new Map<string, AbortController>();
private addMessageAbortController(id: string): AbortController {
const controller = new AbortController();
this.messageAbortControllers.set(id, controller);
controller.signal.addEventListener("abort", () => {
this.messageAbortControllers.delete(id);
});
return controller;
}
private abortById(messageId: string) {
this.messageAbortControllers.get(messageId)?.abort();
}

invoke<T extends keyof ToCoreProtocol>(
messageType: T,
@@ -301,7 +257,7 @@
});

on("abort", (msg) => {
this.abortedMessageIds.add(msg.messageId);
this.abortById(msg.data ?? msg.messageId);
});

on("ping", (msg) => {
@@ -470,27 +426,28 @@
}
});

on("llm/streamChat", (msg) =>
llmStreamChat(
on("llm/streamChat", (msg) => {
const abortController = this.addMessageAbortController(msg.messageId);
return llmStreamChat(
this.configHandler,
this.abortedMessageIds,
abortController,
msg,
this.ide,
this.messenger,
),
);
);
});

on("llm/complete", async (msg) => {
const model = (await this.configHandler.loadConfig()).config
?.selectedModelByRole.chat;

const { config } = await this.configHandler.loadConfig();
const model = config?.selectedModelByRole.chat;
if (!model) {
throw new Error("No chat model selected");
}
const abortController = this.addMessageAbortController(msg.messageId);

const completion = await model.complete(
msg.data.prompt,
new AbortController().signal,
abortController.signal,
msg.data.completionOptions,
);
return completion;
@@ -532,9 +489,47 @@
this.completionProvider.cancel();
});

on("streamDiffLines", (msg) =>
streamDiffLinesGenerator(this.configHandler, this.abortedMessageIds, msg),
);
on("streamDiffLines", async (msg) => {
const { config } = await this.configHandler.loadConfig();
if (!config) {
throw new Error("Failed to load config");
}

const { data } = msg;

// Title can be an edit, chat, or apply model
// Fall back to chat
const llm =
config.modelsByRole.edit.find((m) => m.title === data.modelTitle) ??
config.modelsByRole.apply.find((m) => m.title === data.modelTitle) ??
config.modelsByRole.chat.find((m) => m.title === data.modelTitle) ??
config.selectedModelByRole.chat;

if (!llm) {
throw new Error("No model selected");
}

return streamDiffLines({
highlighted: data.highlighted,
prefix: data.prefix,
suffix: data.suffix,
llm,
// rules included for edit, NOT apply
rulesToInclude: data.includeRulesInSystemMessage
? config.rules
: undefined,
input: data.input,
language: data.language,
onlyOneInsertion: false,
overridePrompt: undefined,
abortControllerId: data.fileUri ?? "current-file-stream", // Not critical: cancelling apply currently aborts all streams, and apply runs one file at a time
});
});

on("cancelApply", async (msg) => {
const abortManager = StreamAbortManager.getInstance();
abortManager.clear();
});

on("completeOnboarding", this.handleCompleteOnboarding.bind(this));

7 changes: 5 additions & 2 deletions core/edit/recursiveStream.ts
@@ -16,6 +16,7 @@ const RECURSIVE_PROMPT = `Continue EXACTLY where you left`;

export async function* recursiveStream(
llm: ILLM,
abortController: AbortController,
prompt: ChatMessage[] | string,
prediction: Prediction | undefined,
currentBuffer = "",
@@ -28,7 +29,7 @@
// let whiteSpaceAtEndOfBuffer = buffer.match(/\s*$/)?.[0] ?? ""; // attempts at fixing whitespace bug with recursive boundaries

if (typeof prompt === "string") {
const generator = llm.streamComplete(prompt, new AbortController().signal, {
const generator = llm.streamComplete(prompt, abortController.signal, {
raw: true,
prediction: undefined,
reasoning: false,
@@ -50,6 +51,7 @@
// // TODO - Prediction capabilities lost because of partial input
// yield* recursiveStream(
// llm,
// abortController,
// continuationPrompt,
// undefined,
// buffer,
@@ -60,7 +62,7 @@
}
}
} else {
const generator = llm.streamChat(prompt, new AbortController().signal, {
const generator = llm.streamChat(prompt, abortController.signal, {
prediction,
reasoning: false,
});
@@ -90,6 +92,7 @@
// await generator.return(DUD_PROMPT_LOG);
// yield* recursiveStream(
// llm,
// abortController,
// continuationPrompt,
// undefined,
// buffer,
7 changes: 6 additions & 1 deletion core/edit/streamDiffLines.ts
@@ -20,6 +20,7 @@ import { streamDiff } from "../diff/streamDiff";
import { streamLines } from "../diff/util";
import { getSystemMessageWithRules } from "../llm/rules/getSystemMessageWithRules";
import { gptEditPrompt } from "../llm/templates/edit";
import { StreamAbortManager } from "../util/abortManager";
import { findLast } from "../util/findLast";
import { Telemetry } from "../util/posthog";
import { recursiveStream } from "./recursiveStream";
@@ -63,6 +64,7 @@ export async function* streamDiffLines({
highlighted,
suffix,
llm,
abortControllerId,
input,
language,
onlyOneInsertion,
@@ -73,12 +75,15 @@
highlighted: string;
suffix: string;
llm: ILLM;
abortControllerId: string;
input: string;
language: string | undefined;
onlyOneInsertion: boolean;
overridePrompt: ChatMessage[] | undefined;
rulesToInclude: RuleWithSource[] | undefined;
}): AsyncGenerator<DiffLine> {
const abortManager = StreamAbortManager.getInstance();
Collaborator Author commented:

Took the singleton approach so that the standalone streamDiffLines function (also used in VS Code-only code) can have persisted abort controllers that can be cancelled via core messaging.

const abortController = abortManager.get(abortControllerId);
void Telemetry.capture(
"inlineEdit",
{
@@ -157,7 +162,7 @@
content: highlighted,
};

const completion = recursiveStream(llm, prompt, prediction);
const completion = recursiveStream(llm, abortController, prompt, prediction);

let lines = streamLines(completion);

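Note: core/util/abortManager.ts itself does not appear in this diff. Based only on how it is used above (StreamAbortManager.getInstance(), abortManager.get(abortControllerId), and abortManager.clear()), a minimal sketch of such a singleton could look like the following; the class body is an assumption for illustration, not the PR's actual implementation.

// Hypothetical sketch of core/util/abortManager.ts, inferred from its usage in this PR.
// Only getInstance(), get(id), and clear() are visible in the diff; the rest is assumed.
export class StreamAbortManager {
  private static instance: StreamAbortManager;
  private controllers = new Map<string, AbortController>();

  // Singleton accessor so core messaging and VS Code-only callers share one registry
  static getInstance(): StreamAbortManager {
    if (!StreamAbortManager.instance) {
      StreamAbortManager.instance = new StreamAbortManager();
    }
    return StreamAbortManager.instance;
  }

  // Return the controller registered under this id, creating one on first use
  get(id: string): AbortController {
    let controller = this.controllers.get(id);
    if (!controller) {
      controller = new AbortController();
      this.controllers.set(id, controller);
    }
    return controller;
  }

  // Abort and drop every tracked stream (what the "cancelApply" handler relies on)
  clear(): void {
    for (const controller of this.controllers.values()) {
      controller.abort();
    }
    this.controllers.clear();
  }
}
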
1 change: 1 addition & 0 deletions core/index.d.ts
@@ -1240,6 +1240,7 @@ export interface StreamDiffLinesPayload {
language: string | undefined;
modelTitle: string | undefined;
includeRulesInSystemMessage: boolean;
fileUri?: string;
}

export interface HighlightedCodePayload {
3 changes: 3 additions & 0 deletions core/llm/index.ts
@@ -382,6 +382,9 @@ export abstract class BaseLLM implements ILLM {

// Error mapping to be more helpful
if (!resp.ok) {
if (resp.status === 499) {
return resp; // client side cancellation
}
let text = await resp.text();
if (resp.status === 404 && !resp.url.includes("/v1")) {
const error = JSON.parse(text)?.error?.replace(/"/g, "'");
6 changes: 5 additions & 1 deletion core/llm/llms/Anthropic.ts
@@ -1,7 +1,7 @@
import { streamSse } from "@continuedev/fetch";
import { ChatMessage, CompletionOptions, LLMOptions } from "../../index.js";
import { renderChatMessage, stripImages } from "../../util/messageContent.js";
import { BaseLLM } from "../index.js";
import { streamSse } from "../stream.js";

class Anthropic extends BaseLLM {
static providerName = "anthropic";
@@ -213,6 +213,10 @@ class Anthropic extends BaseLLM {
signal,
});

if (response.status === 499) {
return; // Aborted by user
}

if (!response.ok) {
const json = await response.json();
if (json.type === "error") {
12 changes: 4 additions & 8 deletions core/llm/llms/Asksage.ts
@@ -134,10 +134,8 @@ class Asksage extends BaseLLM {
body: JSON.stringify(args),
});

if (!response.ok) {
throw new Error(
`API request failed with status ${response.status}: ${response.statusText}`,
);
if (response.status === 499) {
return ""; // Aborted by user
}

const data = await response.json();
@@ -167,10 +165,8 @@
signal,
});

if (!response.ok) {
throw new Error(
`API request failed with status ${response.status}: ${response.statusText}`,
);
if (response.status === 499) {
return; // Aborted by user
}

const data = await response.json();
2 changes: 1 addition & 1 deletion core/llm/llms/Cloudflare.ts
@@ -1,7 +1,7 @@
import { streamSse } from "@continuedev/fetch";
import { ChatMessage, CompletionOptions } from "../../index.js";
import { renderChatMessage } from "../../util/messageContent.js";
import { BaseLLM } from "../index.js";
import { streamSse } from "../stream.js";

export default class Cloudflare extends BaseLLM {
static providerName = "cloudflare";
6 changes: 5 additions & 1 deletion core/llm/llms/Cohere.ts
@@ -1,3 +1,4 @@
import { streamJSON } from "@continuedev/fetch";
import {
ChatMessage,
Chunk,
@@ -6,7 +7,6 @@ } from "../../index.js";
} from "../../index.js";
import { renderChatMessage, stripImages } from "../../util/messageContent.js";
import { BaseLLM } from "../index.js";
import { streamJSON } from "../stream.js";

class Cohere extends BaseLLM {
static providerName = "cohere";
@@ -84,6 +84,10 @@ class Cohere extends BaseLLM {
signal,
});

if (resp.status === 499) {
return; // Aborted by user
}

if (options.stream === false) {
const data = await resp.json();
yield { role: "assistant", content: data.text };
2 changes: 1 addition & 1 deletion core/llm/llms/Deepseek.ts
@@ -1,5 +1,5 @@
import { streamSse } from "@continuedev/fetch";
import { CompletionOptions, LLMOptions } from "../../index.js";
import { streamSse } from "../stream.js";
import { osModelsEditPrompt } from "../templates/edit.js";

import OpenAI from "./OpenAI.js";