
Commit 88e9a41

Addressed review comments
1 parent 92e00b9 commit 88e9a41

6 files changed: +216 -15 lines changed


examples/prefix-caching/README.md

+14
@@ -0,0 +1,14 @@
# WebLLM App for Prefix Caching Demo

This example demonstrates the use of `cachedPrefixes` in WebLLM.
To try it out, run the following commands in this folder:

```bash
npm install
npm start
```

Note: if you would like to hack on the WebLLM core package, you can change the
`web-llm` dependency to `"file:../.."` and follow the build-from-source
instructions in the project to build WebLLM locally.
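At its core, the example passes `cachedPrefixes` when creating the engine. A condensed sketch of what `src/prefix-caching.ts` in this commit does (model ID and context window size copied from that file):

```ts
import * as webllm from "@mlc-ai/web-llm";

const SYSTEM_PROMPT_PREFIX =
  "You are a helpful assistant running in the user's browser, responsible for answering questions.";

const engine = await webllm.CreateMLCEngine(
  "Llama-3.1-8B-Instruct-q4f32_1-MLC",
  {
    // Prefill the system prompt once at load time and keep its KV cache for reuse
    cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
  },
  { context_window_size: 2048 },
);
```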

examples/prefix-caching/package.json

+20
@@ -0,0 +1,20 @@
{
  "name": "prefix-caching-example",
  "version": "0.1.0",
  "private": true,
  "scripts": {
    "start": "parcel src/prefix-caching.html --port 8888",
    "build": "parcel build src/prefix-caching.html --dist-dir lib"
  },
  "devDependencies": {
    "buffer": "^5.7.1",
    "parcel": "^2.8.3",
    "process": "^0.11.10",
    "tslib": "^2.3.1",
    "typescript": "^4.9.5",
    "url": "^0.11.3"
  },
  "dependencies": {
    "@mlc-ai/web-llm": "../.."
  }
}
examples/prefix-caching/src/prefix-caching.html

+23
@@ -0,0 +1,23 @@
<!doctype html>
<html>
  <script>
    webLLMGlobal = {};
  </script>
  <body>
    <h2>WebLLM Prefix Caching Test Page</h2>
    Open console to see output
    <br />
    <br />
    <label id="init-label"> </label>

    <h3>Prompt</h3>
    <label id="prompt-label"> </label>

    <h3>Response</h3>
    <label id="generate-label"> </label>
    <br />
    <label id="stats-label"> </label>

    <script type="module" src="./prefix-caching.ts"></script>
  </body>
</html>
examples/prefix-caching/src/prefix-caching.ts

+142
@@ -0,0 +1,142 @@
import * as webllm from "@mlc-ai/web-llm";

const SYSTEM_PROMPT_PREFIX =
  "You are a helpful assistant running in the user's browser, responsible for answering questions.";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

async function testPrefix() {
  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };

  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO",
      // Prefilling KV cache for efficiency
      cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]],
    },
    {
      context_window_size: 2048,
    },
  );

  const reply_using_prefix = await engine.chat.completions.create({
    messages: [
      { role: "system", content: SYSTEM_PROMPT_PREFIX },
      { role: "user", content: "List three US states." },
    ],
    // below configurations are all optional
    n: 1,
    temperature: 1.5,
    max_tokens: 64,
    logprobs: true,
    top_logprobs: 2,
  });
  console.log(reply_using_prefix);
  console.log(reply_using_prefix.usage);
}

async function testWithoutPrefix() {
  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };

  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
  // Engine Initialization without cachedPrefixes
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO",
    },
    {
      context_window_size: 2048,
    },
  );

  const reply_without_prefix = await engine.chat.completions.create({
    messages: [
      { role: "system", content: SYSTEM_PROMPT_PREFIX },
      { role: "user", content: "List three US states." },
    ],
    // below configurations are all optional
    n: 1,
    temperature: 1.5,
    max_tokens: 64,
    logprobs: true,
    top_logprobs: 2,
  });
  console.log(reply_without_prefix);
  console.log(reply_without_prefix.usage);
}

async function testMultiRound() {
  const initProgressCallback = (report: webllm.InitProgressReport) => {
    setLabel("init-label", report.text);
  };

  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
  const engine: webllm.MLCEngineInterface = await webllm.CreateMLCEngine(
    selectedModel,
    {
      initProgressCallback: initProgressCallback,
      logLevel: "INFO",
      cachedPrefixes: [[{ role: "system", content: SYSTEM_PROMPT_PREFIX }]], // Prefilling KV cache for efficiency
    },
    {
      context_window_size: 2048,
    },
  );

  // First Completion with cachedPrefixes
  const reply0 = await engine.chat.completions.create({
    messages: [
      { role: "system", content: SYSTEM_PROMPT_PREFIX },
      { role: "user", content: "List three US states." },
    ],
    // below configurations are all optional
    n: 1,
    temperature: 1.5,
    max_tokens: 64,
    logprobs: true,
    top_logprobs: 2,
  });
  console.log(reply0);
  console.log(reply0.usage);

  // Second Completion with cachedPrefixes
  const reply1 = await engine.chat.completions.create({
    messages: [
      { role: "system", content: SYSTEM_PROMPT_PREFIX },
      { role: "user", content: "Where is the US capital?" },
    ],
    // below configurations are all optional
    n: 1,
    temperature: 1.5,
    max_tokens: 64,
    logprobs: true,
    top_logprobs: 2,
  });
  console.log(reply1);
  console.log(reply1.usage);
}

async function main() {
  await testPrefix();

  await testWithoutPrefix();

  await testMultiRound();
}

main();

src/config.ts

+4
@@ -106,9 +106,13 @@ export interface ChatOptions extends Partial<ChatConfig> {}
  * appConfig: Configure the app, including the list of models and whether to use IndexedDB cache.
  * initProgressCallback: A callback for showing the progress of loading the model.
  * logitProcessorRegistry: A register for stateful logit processors, see `webllm.LogitProcessor`.
+ * cachedPrefixes: Specifies system prompts (prefixes) that will be prefilled when loading the engine
+ * to create their corresponding KV caches and store them for reuse. These cached KV pairs persist
+ * until the engine is reloaded.
  *
  * @note All fields are optional, and `logitProcessorRegistry` is only used for `MLCEngine` and not
  * other `MLCEngine`s.
+ * @note cachedPrefixes is experimental. It may change in future versions.
  */
 export interface MLCEngineConfig {
   appConfig?: AppConfig;
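To make the shape of the new field concrete, here is a minimal sketch of an `MLCEngineConfig` using `cachedPrefixes`; the nested-array shape is inferred from `examples/prefix-caching/src/prefix-caching.ts` in this commit:

```ts
import * as webllm from "@mlc-ai/web-llm";

// Each cached prefix is a list of chat messages; cachedPrefixes takes a list of such prefixes.
const engineConfig: webllm.MLCEngineConfig = {
  logLevel: "INFO",
  cachedPrefixes: [
    [{ role: "system", content: "You are a helpful assistant." }],
  ],
};
```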

src/llm_chat.ts

+13 -15
@@ -38,6 +38,9 @@ import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completi

 type ImageURL = ChatCompletionContentPartImage.ImageURL;

+// Default sequence ID for chat completion
+const CHAT_SEQUENCE_ID = 0;
+
 export class LLMChatPipeline {
   private config: ChatConfig;
   private tokenizer: Tokenizer;

@@ -177,7 +180,7 @@
     log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);

     this.seqIdToPrefix = new Map<number, number[]>();
-    this.nextSequenceId = 0;
+    this.nextSequenceId = CHAT_SEQUENCE_ID;
     this.device = this.tvm.webgpu();

     // 1. Create VM and get the core functions

@@ -510,9 +513,9 @@
     msgRole: Role, // either user or tool
     inp_role_str?: string,
     genConfig?: GenerationConfig,
-    seqID = 0,
+    seqID = CHAT_SEQUENCE_ID,
   ): Promise<void> {
-    if (seqID === 0) {
+    if (seqID === CHAT_SEQUENCE_ID) {
       if (msgRole !== Role.user && msgRole !== Role.tool) {
         throw new MessageOrderError(
           "The last message should be from `user` or `tool`.",

@@ -608,7 +611,7 @@
     }

     // 0. Get inputData from conversation
-    if (seqID === 0) {
+    if (seqID === CHAT_SEQUENCE_ID) {
       if (conversation.isTextCompletion) {
         conversation.prompt = inp;
       } else {

@@ -652,13 +655,8 @@

     // If a match is found, fork the sequence
     if (matchedSeqId !== -1 && maxMatchedLen > 0) {
-      console.log(
-        "Forking sequence",
-        matchedSeqId,
-        "at position",
-        maxMatchedLen,
-      );
-      if (seqID === 0) {
+      log.info("Forking sequence", matchedSeqId, "at position", maxMatchedLen);
+      if (seqID === CHAT_SEQUENCE_ID) {
         this.fKVCacheRemoveSequence!(
           this.kvCache,
           new tvmjs.Scalar(seqID, "int64"),

@@ -672,14 +670,14 @@
           new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
         );
         this.tvm.endScope();
-      } else if (seqID !== 0) {
+      } else if (seqID !== CHAT_SEQUENCE_ID) {
         // If no match is found, add the new sequence to the KV cache
-        console.log("Adding prefix to KV cache: ", seqID);
+        log.info("Adding prefix to KV cache: ", seqID);
         this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
       }

       // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
-      if (seqID !== 0) {
+      if (seqID !== CHAT_SEQUENCE_ID) {
         this.seqIdToPrefix.set(seqID, inputTokens);
       }

@@ -996,7 +994,7 @@
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
-    seqID = 0,
+    seqID = CHAT_SEQUENCE_ID,
   ): Promise<tvmjs.NDArray> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
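The hunks above only show fragments of the prefix-matching logic (`matchedSeqId`, `maxMatchedLen`, `seqIdToPrefix`). As an illustrative sketch of how the longest matching cached prefix might be selected before forking (names mirrored from the diff; not the actual `llm_chat.ts` implementation):

```ts
// Illustrative sketch only: pick the cached prefix sharing the longest leading
// token run with the new input. Mirrors the names used in the diff above.
function findLongestMatchedPrefix(
  seqIdToPrefix: Map<number, number[]>,
  inputTokens: number[],
): { matchedSeqId: number; maxMatchedLen: number } {
  let matchedSeqId = -1;
  let maxMatchedLen = 0;
  for (const [seqId, prefixTokens] of seqIdToPrefix) {
    // Count how many leading tokens the input shares with this cached prefix.
    let matched = 0;
    const limit = Math.min(prefixTokens.length, inputTokens.length);
    while (matched < limit && prefixTokens[matched] === inputTokens[matched]) {
      matched++;
    }
    if (matched > maxMatchedLen) {
      maxMatchedLen = matched;
      matchedSeqId = seqId;
    }
  }
  return { matchedSeqId, maxMatchedLen };
}
```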
