
Commit 828f9ec

Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

2 parents: b1daf56 + 825912a

File tree

5 files changed: +91, -62 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions

@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Added first version of the changelog
 - Server: Use async routes
+- Use numpy for internal buffers to reduce memory usage and improve performance.
 
 ### Fixed
 
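The memory claim in that changelog entry is easy to sanity-check: a Python list stores each logit as a boxed float object, while a float32 array stores 4 bytes per element in one contiguous block. A minimal sketch, not from this commit (sizes are approximate and the vocab size is illustrative):

```python
import sys
import numpy as np

n_vocab = 32000  # illustrative vocab size

row_list = [float(i) for i in range(n_vocab)]  # one boxed float per logit
row_np = np.arange(n_vocab, dtype=np.single)   # 4 bytes per logit

# Each distinct Python float costs ~24 bytes plus an 8-byte list slot;
# the numpy array is a single contiguous buffer.
list_bytes = sys.getsizeof(row_list) + sum(sys.getsizeof(x) for x in row_list)
print(f"list of floats: ~{list_bytes / 1e6:.2f} MB")   # ~1.0 MB
print(f"float32 array:  ~{row_np.nbytes / 1e6:.2f} MB")  # ~0.13 MB
```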

llama_cpp/llama.py

Lines changed: 58 additions & 32 deletions
@@ -20,6 +20,9 @@
 from . import llama_cpp
 from .llama_types import *
 
+import numpy as np
+import numpy.typing as npt
+
 
 class LlamaCache:
     """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ def __init__(
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
 
@@ -207,27 +214,27 @@ def __init__(
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [],
+            dtype=np.dtype(
+                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+            ),
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
         self._candidates = candidates
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
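The structured dtype above is the heart of this change: with align=True its memory layout matches llama.cpp's llama_token_data struct, so ctypes.data_as can hand the C API a pointer into the numpy buffer instead of rebuilding a ctypes array of 32k objects per call. A self-contained sketch of the pattern, where TokenData is a stand-in for llama_cpp.llama_token_data rather than the real binding:

```python
import ctypes
import numpy as np

class TokenData(ctypes.Structure):  # stand-in for llama_cpp.llama_token_data
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]

# Aligned structured dtype whose layout mirrors the C struct.
dtype = np.dtype(
    [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
)
assert dtype.itemsize == ctypes.sizeof(TokenData)

n_vocab = 8
candidates = np.zeros(n_vocab, dtype=dtype)
candidates["id"] = np.arange(n_vocab, dtype=np.intc)
candidates["logit"] = np.linspace(-1.0, 1.0, n_vocab, dtype=np.single)

# Reinterpret the numpy buffer as a C array pointer -- no copy involved.
ptr = candidates.ctypes.data_as(ctypes.POINTER(TokenData))
print(ptr[3].id, ptr[3].logit)
```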
@@ -295,6 +302,8 @@ def reset(self):
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -306,7 +315,7 @@
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -319,13 +328,19 @@
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
             cols = n_vocab
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )
 
     def _sample(
         self,
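eval() now keeps the token and logit history twice, in the existing deques and in the new numpy buffers; the numpy side grows by concatenating each batch onto the existing arrays. A toy sketch of that append pattern, with a tiny n_vocab and random logits standing in for llama_cpp.llama_get_logits output:

```python
import numpy as np

n_vocab = 4
input_ids = np.array([], dtype=np.intc)
scores = np.ndarray((0, n_vocab), dtype=np.single)

for batch in ([1, 2, 3], [4, 5]):
    # One logits row per token, as in the logits_all=True case.
    batch_logits = np.random.rand(len(batch), n_vocab).astype(np.single)
    input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)))
    scores = np.concatenate((scores, batch_logits), axis=0)

print(input_ids.shape, scores.shape)  # (5,) (5, 4)
```

Note that each np.concatenate allocates a fresh array and copies both operands, so the buffers are rebuilt on every eval call.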
@@ -346,6 +361,7 @@ def _sample(
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -354,18 +370,23 @@ def _sample(
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]
 
         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
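This hunk replaces a per-token Python loop (one ctypes attribute write per field, for every vocab entry, on every sample) with three whole-column assignments on the structured array. A minimal sketch of the before and after, using plain numpy:

```python
import numpy as np

n_vocab = 6
dtype = np.dtype(
    [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
)
candidates_data = np.zeros(n_vocab, dtype=dtype)
logits = np.random.rand(n_vocab).astype(np.single)

# Old approach: one Python-level iteration per vocab entry.
# for i, logit in enumerate(logits):
#     candidates_data[i] = (i, logit, 0.0)

# New approach: three whole-column writes, no Python loop.
candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"] = logits
candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
print(candidates_data[:3])
```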
@@ -483,8 +504,8 @@ def sample(
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
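For the repetition-penalty window, sample() left-pads the history with token 0 whenever fewer than last_n_tokens_size tokens have been evaluated; the numpy slice-then-tolist replaces the old full copy of the deque. A worked example of the padding:

```python
import numpy as np

last_n_tokens_size = 8
input_ids = np.array([5, 7, 9], dtype=np.intc)  # only 3 tokens so far

last_n_tokens_data = [0] * max(
    0, last_n_tokens_size - len(input_ids)
) + input_ids[-last_n_tokens_size:].tolist()
print(last_n_tokens_data)  # [0, 0, 0, 0, 0, 5, 7, 9]
```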
@@ -542,9 +563,9 @@ def generate(
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -554,6 +575,8 @@
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
         try:
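Together, the last two hunks keep the prefix-match path consistent: the shared prefix between the cached context and the new prompt is counted, only the unseen suffix is queued for eval, and the numpy buffers are truncated to the matching rows. A compact walk-through with toy values:

```python
import numpy as np

n_vocab = 4
input_ids = np.array([1, 2, 3, 4], dtype=np.intc)
scores = np.random.rand(4, n_vocab).astype(np.single)
tokens = [1, 2, 9, 9]  # new prompt; diverges at position 2

longest_prefix = 0
for a, b in zip(input_ids, tokens[:-1]):
    if a == b:
        longest_prefix += 1
    else:
        break

tokens = tokens[longest_prefix:]        # only evaluate the new suffix
input_ids = input_ids[:longest_prefix]  # keep the matching history
scores = scores[:longest_prefix, :]
print(longest_prefix, tokens, scores.shape)  # 2 [9, 9] (2, 4)
```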
@@ -580,7 +603,7 @@
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
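stopping_criteria is invoked here with the full token history and the latest logits row, both as plain Python lists. A hedged example of a callable with that signature; the trigger token id is arbitrary, and the assumption is that any callable matching the call site works:

```python
from typing import List

def stop_on_token_42(input_ids: List[int], logits: List[float]) -> bool:
    # Stop once the model has just emitted token id 42.
    return len(input_ids) > 0 and input_ids[-1] == 42

# Would be passed as stopping_criteria=stop_on_token_42.
```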
@@ -715,10 +738,10 @@ def _create_completion(
             try:
                 cache_item = self.cache[prompt_tokens]
                 cache_prefix_len = Llama.longest_token_prefix(
-                    cache_item.eval_tokens, prompt_tokens
+                    cache_item.input_ids.tolist(), prompt_tokens
                 )
                 eval_prefix_len = Llama.longest_token_prefix(
-                    self.eval_tokens, prompt_tokens
+                    self._input_ids.tolist(), prompt_tokens
                )
                 if cache_prefix_len > eval_prefix_len:
                     self.load_state(cache_item)
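The cache lookup compares two prefix lengths: how much of the prompt the cached snapshot already covers versus how much the live context covers, and loads the snapshot only when it wins. A sketch with illustrative values; longest_token_prefix is modeled on its call sites here, since the exact implementation is outside this diff:

```python
from typing import Sequence

def longest_token_prefix(a: Sequence[int], b: Sequence[int]) -> int:
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

prompt_tokens = [1, 2, 3, 4, 5]
cached_input_ids = [1, 2, 3, 4]  # tokens captured in the cached state
current_input_ids = [1, 2, 9]    # tokens already in the live context

cache_prefix_len = longest_token_prefix(cached_input_ids, prompt_tokens)
eval_prefix_len = longest_token_prefix(current_input_ids, prompt_tokens)
print(cache_prefix_len > eval_prefix_len)  # True: loading the cache saves evals
```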
@@ -807,7 +830,7 @@
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
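Llama.logits_to_logprobs converts a logits row into log-probabilities. Mathematically that is a log-softmax, logprob_i = logit_i - log(sum_j exp(logit_j)); a numerically stable numpy sketch, not necessarily the library's exact implementation:

```python
import numpy as np

def logits_to_logprobs(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max()  # guard against exp() overflow
    return shifted - np.log(np.exp(shifted).sum())

logits = np.array([2.0, 1.0, 0.1], dtype=np.single)
logprobs = logits_to_logprobs(logits)
print(np.exp(logprobs).sum())  # ~1.0: the probabilities sum to one
```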
@@ -856,7 +879,7 @@
                 break
 
         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -886,7 +909,7 @@
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -988,8 +1011,7 @@
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
@@ -1373,6 +1395,8 @@ def save_state(self) -> LlamaState:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1381,6 +1405,8 @@ def load_state(self, state: LlamaState) -> None:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
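save_state() and load_state() now round-trip the numpy buffers alongside the deques and the llama.cpp state blob. A hedged usage sketch; the model path and prompts are placeholders:

```python
from llama_cpp import Llama

llm = Llama(model_path="./models/ggml-model.bin")  # placeholder path
llm.eval(llm.tokenize(b"The quick brown fox"))

state = llm.save_state()  # copies eval_tokens/eval_logits plus the new
                          # input_ids/scores numpy buffers
llm.eval(llm.tokenize(b" jumps over", add_bos=False))
llm.load_state(state)     # rewind the context to the snapshot
```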

poetry.lock

Lines changed: 26 additions & 22 deletions (generated file; diff not rendered)
