@@ -20,6 +20,9 @@
 from . import llama_cpp
 from .llama_types import *
 
+import numpy as np
+import numpy.typing as npt
+
 
 class LlamaCache:
     """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
 
@@ -207,27 +214,27 @@ def __init__(
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [],
+            dtype=np.dtype(
+                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+            ),
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
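
Note: the structured dtype with align=True is what lets a plain numpy array stand in for a C array of llama_token_data structs, so the buffer can be filled with vectorized writes and handed to llama.cpp through a single pointer. A minimal standalone sketch; the TokenData Structure below is a hypothetical stand-in mirroring llama.cpp's llama_token_data, not something imported from the bindings:

    import ctypes
    import numpy as np

    # Stand-in for llama.cpp's llama_token_data struct (id, logit, p).
    class TokenData(ctypes.Structure):
        _fields_ = [
            ("id", ctypes.c_int),
            ("logit", ctypes.c_float),
            ("p", ctypes.c_float),
        ]

    # align=True pads the dtype following C alignment rules, so the array's
    # memory layout matches an array of TokenData structs.
    dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )
    assert dtype.itemsize == ctypes.sizeof(TokenData)

    n_vocab = 8
    arr = np.zeros(n_vocab, dtype=dtype)
    arr["id"] = np.arange(n_vocab, dtype=np.intc)
    arr["logit"] = np.linspace(-1.0, 1.0, n_vocab, dtype=np.single)

    # Hand the buffer to C without copying: one pointer, n_vocab structs.
    ptr = arr.ctypes.data_as(ctypes.POINTER(TokenData))
    assert ptr[3].id == 3 and abs(ptr[3].logit - arr["logit"][3]) < 1e-6

For this particular field mix (int32, float32, float32) the packed and aligned layouts happen to coincide, but align=True keeps the guarantee if the struct ever changes.
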
@@ -295,6 +302,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -306,7 +315,7 @@ def eval(self, tokens: Sequence[int]):
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -319,13 +328,19 @@ def eval(self, tokens: Sequence[int]):
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )
 
     def _sample(
         self,
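
Note: _scores is created as a (0, n_vocab) array precisely so that the first np.concatenate along axis 0 already agrees on the column count; every eval then appends one row of logits per evaluated token (or a single row when logits_all is off). A small sketch of the same accumulation pattern with made-up sizes:

    import numpy as np

    n_vocab = 5
    scores = np.ndarray((0, n_vocab), dtype=np.single)  # empty, but 5 columns
    input_ids = np.array([], dtype=np.intc)

    for batch in ([1, 2, 3], [4]):
        # Pretend llama_eval produced one row of logits per token in the batch.
        logits = np.random.rand(len(batch), n_vocab).astype(np.single)
        input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)))
        scores = np.concatenate((scores, logits), axis=0)

    assert input_ids.tolist() == [1, 2, 3, 4]
    assert scores.shape == (4, n_vocab)  # one logit row per evaluated token
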
@@ -346,6 +361,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -354,18 +370,23 @@ def _sample(
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]
 
         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
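
Note: the removed loop crossed the Python/ctypes boundary 3 * n_vocab times per sampled token; the structured-array version replaces it with three bulk column writes into the same memory the C pointer refers to. An equivalence sketch (field names as in the diff, sizes made up):

    import numpy as np

    n_vocab = 4
    logits = np.array([0.5, -1.0, 2.0, 0.0], dtype=np.single)

    dtype = np.dtype(
        [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
    )
    candidates = np.zeros(n_vocab, dtype=dtype)

    # Vectorized fill: three bulk writes instead of 3 * n_vocab attribute
    # assignments through ctypes.
    candidates["id"] = np.arange(n_vocab, dtype=np.intc)
    candidates["logit"] = logits
    candidates["p"] = np.zeros(n_vocab, dtype=np.single)

    # Same result as the removed per-element loop.
    for i in range(n_vocab):
        assert candidates[i]["id"] == i
        assert candidates[i]["logit"] == logits[i]
        assert candidates[i]["p"] == 0.0
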
@@ -483,8 +504,8 @@ def sample(
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
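
Note: the window handed to _sample must always hold exactly last_n_tokens_size entries, so short histories are left-padded with token 0 before slicing; the numpy version slices the array and converts to a list once. A sketch of the padding arithmetic, assuming last_n > 0:

    import numpy as np

    def last_n_window(input_ids: np.ndarray, last_n: int) -> list:
        # Left-pad with 0 so the window is always exactly `last_n` long.
        return [0] * max(0, last_n - len(input_ids)) + input_ids[-last_n:].tolist()

    ids = np.array([11, 12, 13], dtype=np.intc)
    assert last_n_window(ids, 5) == [0, 0, 11, 12, 13]   # shorter history: padded
    assert last_n_window(ids, 2) == [12, 13]             # longer history: truncated
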
@@ -542,9 +563,9 @@ def generate(
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -554,6 +575,8 @@ def generate(
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                     try:
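
Note: on a prefix-match hit only the unseen suffix of the prompt is re-evaluated; the numpy state is truncated with two slices where the deques are popped element by element. The zip against tokens[:-1] caps the prefix so at least one token is always re-evaluated, keeping fresh logits available. A sketch of the reuse logic with a hypothetical new prompt:

    import numpy as np

    n_vocab = 3
    input_ids = np.array([1, 2, 3, 4], dtype=np.intc)       # previously evaluated
    scores = np.random.rand(4, n_vocab).astype(np.single)   # one row per token

    tokens = [1, 2, 9, 9]  # new prompt sharing a 2-token prefix

    longest_prefix = 0
    for a, b in zip(input_ids, tokens[:-1]):  # leave >= 1 token to re-eval
        if a == b:
            longest_prefix += 1
        else:
            break

    # Keep cached state for the shared prefix, re-evaluate only the rest.
    tokens = tokens[longest_prefix:]
    input_ids = input_ids[:longest_prefix]
    scores = scores[:longest_prefix, :]

    assert longest_prefix == 2
    assert tokens == [9, 9]
    assert input_ids.tolist() == [1, 2] and scores.shape == (2, n_vocab)
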
@@ -580,7 +603,7 @@ def generate(
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
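
Note: a stopping_criteria callback now receives plain Python lists derived from the numpy state: the full token history and the last row of scores. A hypothetical callback matching that calling convention:

    from typing import List

    # Hypothetical criterion: stop once a chosen token id has been generated.
    def stop_on_token(stop_id: int):
        def criteria(input_ids: List[int], logits: List[float]) -> bool:
            return len(input_ids) > 0 and input_ids[-1] == stop_id
        return criteria

    criteria = stop_on_token(stop_id=2)
    assert criteria([5, 7, 2], [0.1, 0.2, 0.3]) is True
    assert criteria([5, 7, 9], [0.1, 0.2, 0.3]) is False
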
@@ -715,10 +738,10 @@ def _create_completion(
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.eval_tokens, prompt_tokens
+                    cache_item.input_ids.tolist(), prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self.eval_tokens, prompt_tokens
+                    self._input_ids.tolist(), prompt_tokens
                 )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
@@ -807,7 +830,7 @@ def _create_completion(
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -856,7 +879,7 @@ def _create_completion(
                 break
 
         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -886,7 +909,7 @@ def _create_completion(
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -988,8 +1011,7 @@ def _create_completion(
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
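
Note: each row of _scores passes through Llama.logits_to_logprobs, i.e. a log-softmax over the vocabulary. A numerically stable numpy sketch of the transform (not necessarily the library's exact implementation):

    import numpy as np

    def logits_to_logprobs(row: np.ndarray) -> np.ndarray:
        # Stable log-softmax: shift by the max before exponentiating.
        shifted = row - np.max(row)
        return shifted - np.log(np.sum(np.exp(shifted)))

    row = np.array([1.0, 2.0, 3.0], dtype=np.single)
    logprobs = logits_to_logprobs(row)
    assert np.allclose(np.exp(logprobs).sum(), 1.0)  # probabilities sum to 1
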
@@ -1373,6 +1395,8 @@ def save_state(self) -> LlamaState:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1381,6 +1405,8 @@ def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
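
Note: both save_state and load_state take .copy() of the arrays, since plain numpy assignment would otherwise alias the live buffers and later evals would silently rewrite a saved snapshot. A sketch of the failure the copies prevent:

    import numpy as np

    live_scores = np.zeros((2, 3), dtype=np.single)

    snapshot_aliased = live_scores          # shares memory with the live state
    snapshot_copied = live_scores.copy()    # independent buffer

    live_scores[0, 0] = 42.0                # later eval mutates the live state

    assert snapshot_aliased[0, 0] == 42.0   # alias silently changed
    assert snapshot_copied[0, 0] == 0.0     # copy preserved the saved state
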