Commit 2af2b16

Added cachedPrefixes to cache long system prompts when creating engine.
1 parent 632d347 commit 2af2b16

File tree

3 files changed: +153 −16 lines changed

  src/config.ts
  src/engine.ts
  src/llm_chat.ts


src/config.ts

+2
@@ -9,6 +9,7 @@ import {
   NonNegativeError,
   RangeError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
 
 /**
  * Conversation template config
@@ -114,6 +115,7 @@ export interface MLCEngineConfig {
   initProgressCallback?: InitProgressCallback;
   logitProcessorRegistry?: Map<string, LogitProcessor>;
   logLevel?: LogLevel;
+  cachedPrefixes?: ChatCompletionMessageParam[][];
 }
 
 /**
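For reference, a minimal usage sketch of the new option. It assumes web-llm's `CreateMLCEngine` factory and that `ChatCompletionMessageParam` is re-exported from the package root; the model id and prompt text are placeholders for illustration.

```ts
import { CreateMLCEngine, ChatCompletionMessageParam } from "@mlc-ai/web-llm";

// A long system prompt whose KV entries we want computed once, when the engine is created.
const longSystemPrompt =
  "You are a helpful assistant. Follow the (long) project-specific instructions below. ...";

// Each entry of cachedPrefixes is a conversation prefix; here a single system message.
const cachedPrefixes: ChatCompletionMessageParam[][] = [
  [{ role: "system", content: longSystemPrompt }],
];

const engine = await CreateMLCEngine(
  "Llama-3.1-8B-Instruct-q4f32_1-MLC", // placeholder model id
  { cachedPrefixes }, // MLCEngineConfig carrying the new field
);
```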

src/engine.ts

+13
@@ -131,6 +131,7 @@ export class MLCEngine implements MLCEngineInterface {
   private logitProcessorRegistry?: Map<string, LogitProcessor>;
   private initProgressCallback?: InitProgressCallback;
   private appConfig: AppConfig;
+  private cachedPrefixes: ChatCompletionMessageParam[][];
 
   // Signals and flags
   private interruptSignal = false;
@@ -149,6 +150,7 @@ export class MLCEngine implements MLCEngineInterface {
     this.setLogLevel(engineConfig?.logLevel || DefaultLogLevel);
     this.setInitProgressCallback(engineConfig?.initProgressCallback);
     this.setLogitProcessorRegistry(engineConfig?.logitProcessorRegistry);
+    this.cachedPrefixes = engineConfig?.cachedPrefixes || [];
 
     this.chat = new API.Chat(this);
     this.completions = new API.Completions(this);
@@ -392,6 +394,16 @@ export class MLCEngine implements MLCEngineInterface {
     this.loadedModelIdToPipeline.set(modelId, newPipeline);
     this.loadedModelIdToLock.set(modelId, new CustomLock());
 
+    // Call prefillConvSequence() if cachedPrefixes is specified
+    if (
+      newPipeline instanceof LLMChatPipeline &&
+      this.cachedPrefixes.length > 0
+    ) {
+      for (let i = 0; i < this.cachedPrefixes.length; i++) {
+        await newPipeline.prefillConvSequence(this.cachedPrefixes[i]);
+      }
+    }
+
     // Clean up
     const tend = performance.now();
     if (this.initProgressCallback !== undefined) {
@@ -444,6 +456,7 @@ export class MLCEngine implements MLCEngineInterface {
     if (genConfig !== undefined) {
       postInitAndCheckGenerationConfigValues(genConfig);
     }
+    console.log("prefill in _generate, input: ", input);
     await this.prefill(input, pipeline, chatConfig, genConfig);
 
     while (!pipeline.stopped()) {
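Once reload finishes, each prefix has been prefilled into its own KV-cache sequence, so a later request that begins with the same system prompt should only pay for the tokens it has not seen. A hedged follow-up to the sketch above, using web-llm's OpenAI-style chat API:

```ts
const reply = await engine.chat.completions.create({
  messages: [
    // Same system prompt as the cached prefix, so its KV entries can be reused.
    { role: "system", content: longSystemPrompt },
    { role: "user", content: "Summarize your instructions in one sentence." },
  ],
});
console.log(reply.choices[0]?.message.content);
```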

src/llm_chat.ts

+138 −16
@@ -34,6 +34,7 @@ import {
   PrefillChunkSizeSmallerThanImageError,
   CannotFindImageEmbedError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";
 
 type ImageURL = ChatCompletionContentPartImage.ImageURL;
 
@@ -128,6 +129,8 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  private seqIdToPrefix: Map<number, number[]>;
+  private nextSequenceId: number;
 
   constructor(
     tvm: tvmjs.Instance,
@@ -173,6 +176,8 @@ export class LLMChatPipeline {
     log.info("token_postproc_method: ", this.token_postproc_method);
     log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);
 
+    this.seqIdToPrefix = new Map<number, number[]>();
+    this.nextSequenceId = 0;
     this.device = this.tvm.webgpu();
 
     // 1. Create VM and get the core functions
@@ -344,7 +349,12 @@ export class LLMChatPipeline {
    * Reset KV Cache
    */
   resetKVCache() {
-    this.fclearKVCaches(this.kvCache);
+    // Check whether to keep prefixes in the KV cache
+    if (this.seqIdToPrefix.size === 0) {
+      this.fclearKVCaches(this.kvCache);
+    } else {
+      this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
     this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
     if (this.slidingWindowSize != -1) {
       this.fKVCacheEnableSlidingWindowForSeq(
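The effect of this hunk: sequence 0 holds the live conversation, while the prefix sequences created by prefillConvSequence() further down stay resident, so resetting the conversation no longer wipes cached system prompts; only when no prefixes exist is the whole KV cache cleared as before.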
@@ -483,6 +493,15 @@ export class LLMChatPipeline {
     await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
   }
 
+  matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
+    for (let i = 0; i < prefixTokens.length; i++) {
+      if (inputTokens[i] !== prefixTokens[i]) {
+        return i;
+      }
+    }
+    return prefixTokens.length;
+  }
+
   /**
    * Generate the first token given input prompt
    */
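matchPrefix() returns the length of the longest common prefix between the new input tokens and a cached prefix. A self-contained sketch of the same rule with made-up token ids:

```ts
function matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
  // Walk the cached prefix and stop at the first token that differs from the input.
  for (let i = 0; i < prefixTokens.length; i++) {
    if (inputTokens[i] !== prefixTokens[i]) {
      return i;
    }
  }
  // The entire cached prefix matches the start of the input.
  return prefixTokens.length;
}

console.log(matchPrefix([5, 7, 9, 11], [5, 7, 8]));     // 2: diverges at index 2
console.log(matchPrefix([5, 7, 9, 11], [5, 7, 9, 11])); // 4: whole prefix matches
```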
@@ -491,11 +510,17 @@
     msgRole: Role, // either user or tool
     inp_role_str?: string,
     genConfig?: GenerationConfig,
+    seqID = 0,
   ): Promise<void> {
-    if (msgRole !== Role.user && msgRole !== Role.tool) {
-      throw new MessageOrderError(
-        "The last message should be from `user` or `tool`.",
-      );
+    if (seqID === 0) {
+      if (msgRole !== Role.user && msgRole !== Role.tool) {
+        throw new MessageOrderError(
+          "The last message should be from `user` or `tool`.",
+        );
+      }
+    } else {
+      // Set the input as system prompt during prefix prefilling
+      this.conversation.override_system_message = inp;
     }
     if (this.resetStatsPerPrefill) {
       this.resetRuntimeStats();
@@ -583,11 +608,13 @@
     }
 
     // 0. Get inputData from conversation
-    if (conversation.isTextCompletion) {
-      conversation.prompt = inp;
-    } else {
-      conversation.appendMessage(msgRole, inp, inp_role_str);
-      conversation.appendReplyHeader(Role.assistant);
+    if (seqID === 0) {
+      if (conversation.isTextCompletion) {
+        conversation.prompt = inp;
+      } else {
+        conversation.appendMessage(msgRole, inp, inp_role_str);
+        conversation.appendReplyHeader(Role.assistant);
+      }
     }
     const retGetInputData = this.getInputData();
     const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
@@ -610,11 +637,68 @@
       throw new CannotFindImageEmbedError();
     }
 
+    let maxMatchedLen = -1;
+    let matchedSeqId = -1;
+
+    // Prefix matching and forking
+    const inputTokens = inputData.flat() as number[];
+    for (const [id, prefixTokens] of this.seqIdToPrefix) {
+      const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
+      if (matchedLen > maxMatchedLen) {
+        maxMatchedLen = matchedLen;
+        matchedSeqId = id;
+      }
+    }
+
+    // If a match is found, fork the sequence
+    if (matchedSeqId !== -1 && maxMatchedLen > 0) {
+      console.log(
+        "Forking sequence",
+        matchedSeqId,
+        "at position",
+        maxMatchedLen,
+      );
+      if (seqID === 0) {
+        this.fKVCacheRemoveSequence!(
+          this.kvCache,
+          new tvmjs.Scalar(seqID, "int64"),
+        );
+      }
+      this.tvm.beginScope();
+      this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
+        this.kvCache,
+        new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
+        new tvmjs.Scalar(seqID, "int64"), // fork_child_id
+        new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
+      );
+      this.tvm.endScope();
+    } else if (seqID !== 0) {
+      // If no match is found, add the new sequence to the KV cache
+      console.log("Adding new sequence to KV cache: ", seqID);
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
+    }
+
+    // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
+    if (seqID !== 0) {
+      this.seqIdToPrefix.set(seqID, inputTokens);
+    }
+
     // 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
-    const retGetChunks = getChunkedPrefillInputData(
-      inputData,
-      this.prefillChunkSize,
-    );
+    let retGetChunks;
+    if (maxMatchedLen === -1) {
+      retGetChunks = getChunkedPrefillInputData(
+        inputData,
+        this.prefillChunkSize,
+      );
+    } else {
+      // If a matched prefix exists, only forward the remaining tokens
+      retGetChunks = getChunkedPrefillInputData(
+        inputData.map((arr) =>
+          Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
+        ),
+        this.prefillChunkSize,
+      );
+    }
     const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
     const chunkLens: Array<number> = retGetChunks[1];
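A small illustration of the slicing above, with hypothetical numbers: when most of the input matches a cached prefix, the fork reuses those KV entries and only the tail is chunk-prefilled.

```ts
// Suppose the request has 1000 input tokens and 900 of them match a cached prefix.
const inputTokens = Array.from({ length: 1000 }, (_, i) => i); // stand-in token ids
const maxMatchedLen = 900; // what matchPrefix() would have reported

// The fork reuses the 900 cached KV entries; only the remainder is forwarded.
const remaining = inputTokens.slice(maxMatchedLen);
console.log(remaining.length); // 100 tokens left to prefill
```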

@@ -626,7 +710,7 @@
       const chunkLen = chunkLens[i];
       const prevFilledLen = this.filledKVCacheLength;
       logits = this.tvm.detachFromCurrentScope(
-        await this.embedAndForward(chunk, chunkLen),
+        await this.embedAndForward(chunk, chunkLen, seqID),
       );
       if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
         throw new Error(
@@ -651,6 +735,41 @@
     this.processNextToken(nextToken, genConfig);
   }
 
+  async prefillConvSequence(
+    messages: ChatCompletionMessageParam[],
+    inp_role_str?: string,
+    genConfig?: GenerationConfig,
+  ): Promise<void> {
+    for (const message of messages) {
+      this.nextSequenceId = this.nextSequenceId + 1;
+      const newSeqId = this.nextSequenceId;
+      // Call the regular prefillStep with the new seqID
+      if (typeof message.content === "string") {
+        // Support long system prompt
+        if (message.role === "system") {
+          await this.prefillStep(
+            message.content,
+            Role.tool,
+            inp_role_str,
+            genConfig,
+            newSeqId,
+          );
+        } else {
+          throw Error(
+            "Invalid role in prefix message: " +
+              message.role +
+              ", expected 'system'.",
+          );
+        }
+      } else {
+        throw Error(
+          "Invalid content in prefix message, does not support image input.",
+        );
+      }
+    }
+    this.conversation.reset();
+  }
+
   async decodeStep(genConfig?: GenerationConfig): Promise<void> {
     if (this.stopTriggered) {
       throw Error("Cannot run decode when stopped");
@@ -869,13 +988,15 @@
    *
    * @param inputData data to embed and forward
    * @param inputDataLen length of this inputData, should smaller than prefill chunk size.
+   * @param seqID sequence ID of the input data in KV cache for prefix caching
    * @returns The logits returned by this forward as tvmjs.NDArray on GPU.
    *
    * @note Precondition: inputData's data length is smaller than prefill chunk size
    */
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
+    seqID = 0,
   ): Promise<tvmjs.NDArray> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
@@ -913,7 +1034,8 @@
 
     // 3. Forward the concatenated embeddings
     const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
-    const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+    // set seqIdsTuple to be childID
+    const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
     this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
     let retValue;
     if (inputDataLen > 1) {
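As the prefillConvSequence() hunk shows, each cached prefix must currently consist of system messages with plain string content; other roles or image content throw. A sketch of what is accepted versus rejected (illustrative values, assuming ChatCompletionMessageParam is re-exported from the package root):

```ts
import { ChatCompletionMessageParam } from "@mlc-ai/web-llm";

// Accepted: system role with string content; each message is prefilled into its own sequence.
const validPrefix: ChatCompletionMessageParam[] = [
  { role: "system", content: "You are a terse assistant." },
];

// Rejected: non-system roles throw "Invalid role in prefix message ...",
// and non-string (e.g. image) content throws "Invalid content in prefix message ...".
const invalidPrefix: ChatCompletionMessageParam[] = [
  { role: "user", content: "This would be rejected by prefillConvSequence()." },
];
```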
