@@ -34,6 +34,7 @@ import {
   PrefillChunkSizeSmallerThanImageError,
   CannotFindImageEmbedError,
 } from "./error";
+import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";

 type ImageURL = ChatCompletionContentPartImage.ImageURL;

@@ -128,6 +129,8 @@ export class LLMChatPipeline {
   private curRoundGrammarInitTotalTime = 0;
   // Total time of getting next bitmask and accepting token in seconds
   private curRoundGrammarPerTokenTotalTime = 0;
+  private seqIdToPrefix: Map<number, number[]>;
+  private nextSequenceId: number;

   constructor(
     tvm: tvmjs.Instance,
@@ -173,6 +176,8 @@ export class LLMChatPipeline {
     log.info("token_postproc_method: ", this.token_postproc_method);
     log.info("prepend_space_in_encode: ", this.prepend_space_in_encode);

+    this.seqIdToPrefix = new Map<number, number[]>();
+    this.nextSequenceId = 0;
     this.device = this.tvm.webgpu();

     // 1. Create VM and get the core functions
@@ -344,7 +349,12 @@ export class LLMChatPipeline {
    * Reset KV Cache
    */
   resetKVCache() {
-    this.fclearKVCaches(this.kvCache);
+    // Check whether to keep prefixes in the KV cache
+    if (this.seqIdToPrefix.size === 0) {
+      this.fclearKVCaches(this.kvCache);
+    } else {
+      this.fKVCacheRemoveSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
+    }
     this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
     if (this.slidingWindowSize != -1) {
       this.fKVCacheEnableSlidingWindowForSeq(
@@ -483,6 +493,15 @@ export class LLMChatPipeline {
     await this.tvm.asyncLoadWebGPUPipelines(this.vm.getInternalModule());
   }

+  matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
+    for (let i = 0; i < prefixTokens.length; i++) {
+      if (inputTokens[i] !== prefixTokens[i]) {
+        return i;
+      }
+    }
+    return prefixTokens.length;
+  }
+
   /**
    * Generate the first token given input prompt
    */
@@ -491,11 +510,17 @@ export class LLMChatPipeline {
     msgRole: Role, // either user or tool
     inp_role_str?: string,
     genConfig?: GenerationConfig,
+    seqID = 0,
   ): Promise<void> {
-    if (msgRole !== Role.user && msgRole !== Role.tool) {
-      throw new MessageOrderError(
-        "The last message should be from `user` or `tool`.",
-      );
+    if (seqID === 0) {
+      if (msgRole !== Role.user && msgRole !== Role.tool) {
+        throw new MessageOrderError(
+          "The last message should be from `user` or `tool`.",
+        );
+      }
+    } else {
+      // Set the input as system prompt during prefix prefilling
+      this.conversation.override_system_message = inp;
     }
     if (this.resetStatsPerPrefill) {
       this.resetRuntimeStats();
@@ -583,11 +608,13 @@ export class LLMChatPipeline {
     }

     // 0. Get inputData from conversation
-    if (conversation.isTextCompletion) {
-      conversation.prompt = inp;
-    } else {
-      conversation.appendMessage(msgRole, inp, inp_role_str);
-      conversation.appendReplyHeader(Role.assistant);
+    if (seqID === 0) {
+      if (conversation.isTextCompletion) {
+        conversation.prompt = inp;
+      } else {
+        conversation.appendMessage(msgRole, inp, inp_role_str);
+        conversation.appendReplyHeader(Role.assistant);
+      }
     }
     const retGetInputData = this.getInputData();
     const inputData: Array<Array<number> | ImageURL> = retGetInputData[0];
@@ -610,11 +637,68 @@ export class LLMChatPipeline {
       throw new CannotFindImageEmbedError();
     }

+    let maxMatchedLen = -1;
+    let matchedSeqId = -1;
+
+    // Prefix matching and forking
+    const inputTokens = inputData.flat() as number[];
+    for (const [id, prefixTokens] of this.seqIdToPrefix) {
+      const matchedLen = this.matchPrefix(inputTokens, prefixTokens);
+      if (matchedLen > maxMatchedLen) {
+        maxMatchedLen = matchedLen;
+        matchedSeqId = id;
+      }
+    }
+
+    // If a match is found, fork the sequence
+    if (matchedSeqId !== -1 && maxMatchedLen > 0) {
+      console.log(
+        "Forking sequence",
+        matchedSeqId,
+        "at position",
+        maxMatchedLen,
+      );
+      if (seqID === 0) {
+        this.fKVCacheRemoveSequence!(
+          this.kvCache,
+          new tvmjs.Scalar(seqID, "int64"),
+        );
+      }
+      this.tvm.beginScope();
+      this.tvm.getGlobalFunc("vm.builtin.kv_state_fork_sequence")(
+        this.kvCache,
+        new tvmjs.Scalar(matchedSeqId, "int64"), // fork_parent_id
+        new tvmjs.Scalar(seqID, "int64"), // fork_child_id
+        new tvmjs.Scalar(maxMatchedLen, "int64"), // fork_position
+      );
+      this.tvm.endScope();
+    } else if (seqID !== 0) {
+      // If no match is found, add the new sequence to the KV cache
+      console.log("Adding new sequence to KV cache: ", seqID);
+      this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(seqID, "int64"));
+    }
+
+    // Add the new sequence to the seqIdToPrefix map (if it is a prefix)
+    if (seqID !== 0) {
+      this.seqIdToPrefix.set(seqID, inputTokens);
+    }
+
     // 1. Chunk inputData to embed and forward in one shot for each, minimize intermediate data
-    const retGetChunks = getChunkedPrefillInputData(
-      inputData,
-      this.prefillChunkSize,
-    );
+    let retGetChunks;
+    if (maxMatchedLen === -1) {
+      retGetChunks = getChunkedPrefillInputData(
+        inputData,
+        this.prefillChunkSize,
+      );
+    } else {
+      // If a matched prefix exists, only forward the remaining tokens
+      retGetChunks = getChunkedPrefillInputData(
+        inputData.map((arr) =>
+          Array.isArray(arr) ? arr.slice(maxMatchedLen) : arr,
+        ),
+        this.prefillChunkSize,
+      );
+    }
     const chunks: Array<Array<number> | ImageURL>[] = retGetChunks[0];
     const chunkLens: Array<number> = retGetChunks[1];

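Note: the hunk above picks the cached sequence whose stored prefix shares the longest run of leading tokens with the new prompt, then forks the KV cache at that position so only the remaining tokens are prefilled. A standalone TypeScript sketch of that selection (not part of the diff; token values are invented purely for illustration):

// Illustration only: mirrors matchPrefix and the longest-match loop above.
function matchPrefix(inputTokens: number[], prefixTokens: number[]): number {
  for (let i = 0; i < prefixTokens.length; i++) {
    if (inputTokens[i] !== prefixTokens[i]) return i;
  }
  return prefixTokens.length;
}

const seqIdToPrefix = new Map<number, number[]>([
  [1, [10, 11, 12, 13]], // cached prefix tokens for sequence 1
  [2, [10, 11, 99]], // cached prefix tokens for sequence 2
]);
const inputTokens = [10, 11, 12, 13, 42, 43]; // new prompt = prefix 1 + user turn

let maxMatchedLen = -1;
let matchedSeqId = -1;
for (const [id, prefixTokens] of seqIdToPrefix) {
  const matchedLen = matchPrefix(inputTokens, prefixTokens);
  if (matchedLen > maxMatchedLen) {
    maxMatchedLen = matchedLen;
    matchedSeqId = id;
  }
}
// matchedSeqId === 1 and maxMatchedLen === 4, so the prefill path forks
// sequence 1 at position 4 and only forwards the remaining tokens [42, 43].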
@@ -626,7 +710,7 @@ export class LLMChatPipeline {
       const chunkLen = chunkLens[i];
       const prevFilledLen = this.filledKVCacheLength;
       logits = this.tvm.detachFromCurrentScope(
-        await this.embedAndForward(chunk, chunkLen),
+        await this.embedAndForward(chunk, chunkLen, seqID),
       );
       if (this.filledKVCacheLength !== prevFilledLen + chunkLen) {
         throw new Error(
@@ -651,6 +735,41 @@ export class LLMChatPipeline {
     this.processNextToken(nextToken, genConfig);
   }

+  async prefillConvSequence(
+    messages: ChatCompletionMessageParam[],
+    inp_role_str?: string,
+    genConfig?: GenerationConfig,
+  ): Promise<void> {
+    for (const message of messages) {
+      this.nextSequenceId = this.nextSequenceId + 1;
+      const newSeqId = this.nextSequenceId;
+      // Call the regular prefillStep with the new seqID
+      if (typeof message.content === "string") {
+        // Support long system prompt
+        if (message.role === "system") {
+          await this.prefillStep(
+            message.content,
+            Role.tool,
+            inp_role_str,
+            genConfig,
+            newSeqId,
+          );
+        } else {
+          throw Error(
+            "Invalid role in prefix message: " +
+              message.role +
+              ", expected 'system'.",
+          );
+        }
+      } else {
+        throw Error(
+          "Invalid content in prefix message, does not support image input.",
+        );
+      }
+    }
+    this.conversation.reset();
+  }
+
   async decodeStep(genConfig?: GenerationConfig): Promise<void> {
     if (this.stopTriggered) {
       throw Error("Cannot run decode when stopped");
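Note: a minimal caller-side sketch of how the new prefillConvSequence entry point might be combined with prefillStep (not part of the diff; the pipeline instance, prompt strings, and import paths are assumptions made for illustration):

// Illustration only: seed the KV cache with a shared system prompt once,
// then run a normal chat turn that can fork from the cached prefix.
import { LLMChatPipeline } from "./llm_chat"; // assumed path
import { Role } from "./conversation"; // assumed path
import { ChatCompletionMessageParam } from "./openai_api_protocols/chat_completion";

async function seedPrefixThenChat(pipeline: LLMChatPipeline): Promise<void> {
  const prefixMessages: ChatCompletionMessageParam[] = [
    // Only string-content "system" messages are accepted by prefillConvSequence.
    { role: "system", content: "You are a helpful assistant. <long shared instructions>" },
  ];

  // Prefill the shared prefix under a fresh sequence id (> 0); its tokens are
  // recorded in seqIdToPrefix so later prompts can fork from the cached KV entries.
  await pipeline.prefillConvSequence(prefixMessages);

  // A regular user turn runs with the default seqID = 0; when its tokens share a
  // prefix with a cached sequence, prefillStep forks the KV cache at the match
  // point and only forwards the remaining tokens.
  await pipeline.prefillStep("What is prefix caching?", Role.user);
}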
@@ -869,13 +988,15 @@ export class LLMChatPipeline {
    *
    * @param inputData data to embed and forward
    * @param inputDataLen length of this inputData, should smaller than prefill chunk size.
+   * @param seqID sequence ID of the input data in KV cache for prefix caching
    * @returns The logits returned by this forward as tvmjs.NDArray on GPU.
    *
    * @note Precondition: inputData's data length is smaller than prefill chunk size
    */
   private async embedAndForward(
     inputData: Array<Array<number> | ImageURL>,
     inputDataLen: number,
+    seqID = 0,
   ): Promise<tvmjs.NDArray> {
     if (inputDataLen > this.prefillChunkSize) {
       throw new Error(
@@ -913,7 +1034,8 @@ export class LLMChatPipeline {

     // 3. Forward the concatenated embeddings
     const inputLenShape = this.tvm.makeShapeTuple([inputDataLen]);
-    const seqIdsTuple = this.tvm.makeShapeTuple([0]);
+    // set seqIdsTuple to be childID
+    const seqIdsTuple = this.tvm.makeShapeTuple([seqID]);
     this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
     let retValue;
     if (inputDataLen > 1) {