@@ -198,6 +198,8 @@ def __init__(
             sorted=sorted,
         )
         self._candidates = candidates
+        self._token_nl = Llama.token_nl()
+        self._token_eos = Llama.token_eos()
 
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
@@ -327,7 +329,7 @@ def _sample(
             else last_n_tokens_size
         )
         logits = self.eval_logits[-1]
-        nl_logit = logits[Llama.token_nl()]
+        nl_logit = logits[self._token_nl]
         candidates = self._candidates
         for i, logit in enumerate(logits):
             candidates.data[i].id = llama_cpp.llama_token(i)
@@ -351,7 +353,7 @@ def _sample(
             alpha_presence=presence_penalty,
         )
         if not penalize_nl:
-            candidates.data[Llama.token_nl()].logit = llama_cpp.c_float(nl_logit)
+            candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
         if temp.value == 0.0:
             return llama_cpp.llama_sample_token_greedy(
                 ctx=self.ctx,
@@ -688,7 +690,7 @@ def _create_completion(
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
         ):
-            if token == Llama.token_eos():
+            if token == self._token_eos:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
                 break
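
For context: the `_sample` hunks read the cached newline id so that the newline logit saved before the repetition/presence penalties can be restored when `penalize_nl` is false, and `_create_completion` compares each sampled token against the cached EOS id to stop generation. Below is a rough, self-contained sketch of that cache-and-restore pattern under assumed values; the token id 13 and the subtraction-based penalty are placeholders, not the real `llama_cpp` calls.

    # Toy illustration: cache an invariant token id once, then protect its logit
    # from penalties, mirroring what the diff does with self._token_nl.
    NL_TOKEN = 13  # assumption: stands in for self._token_nl, looked up once at init

    def apply_repeat_penalty(logits, recent_tokens, penalty=1.1, penalize_nl=False):
        nl_logit = logits[NL_TOKEN]      # remember the newline logit up front
        out = list(logits)
        for tok in set(recent_tokens):
            out[tok] -= penalty          # toy penalty, not llama.cpp's formula
        if not penalize_nl:
            out[NL_TOKEN] = nl_logit     # put the original newline logit back
        return out

    print(apply_repeat_penalty([0.0] * 32, [13, 5, 13]))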