Skip to content

Commit cd102e9

Browse files
committed
Cache shared library function calls for static tokens
1 parent b895511 commit cd102e9

File tree

1 file changed

+5
-3
lines changed

llama_cpp/llama.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ def __init__(
198198
sorted=sorted,
199199
)
200200
self._candidates = candidates
201+
self._token_nl = Llama.token_nl()
202+
self._token_eos = Llama.token_eos()
201203

202204
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
203205
"""Tokenize a string.
@@ -327,7 +329,7 @@ def _sample(
327329
else last_n_tokens_size
328330
)
329331
logits = self.eval_logits[-1]
330-
nl_logit = logits[Llama.token_nl()]
332+
nl_logit = logits[self._token_nl]
331333
candidates = self._candidates
332334
for i, logit in enumerate(logits):
333335
candidates.data[i].id = llama_cpp.llama_token(i)
@@ -351,7 +353,7 @@ def _sample(
351353
alpha_presence=presence_penalty,
352354
)
353355
if not penalize_nl:
354-
candidates.data[Llama.token_nl()].logit = llama_cpp.c_float(nl_logit)
356+
candidates.data[self._token_nl].logit = llama_cpp.c_float(nl_logit)
355357
if temp.value == 0.0:
356358
return llama_cpp.llama_sample_token_greedy(
357359
ctx=self.ctx,
@@ -688,7 +690,7 @@ def _create_completion(
688690
presence_penalty=presence_penalty,
689691
repeat_penalty=repeat_penalty,
690692
):
691-
if token == Llama.token_eos():
693+
if token == self._token_eos:
692694
text = self.detokenize(completion_tokens)
693695
finish_reason = "stop"
694696
break

Comments (0)