
Commit fb762a6

Add speculative decoding (#1120)
* Add draft model param to llama class, implement basic prompt lookup decoding draft model
* Use samplingcontext for sampling
* Use 1d array
* Use draft model for sampling
* Fix dumb mistake
* Allow for later extensions to the LlamaDraftModel api
* Cleanup
* Adaptive candidate prediction
* Update implementation to match hf transformers
* Tuning
* Fix bug where last token was not used for ngram prediction
* Remove heuristic for num_pred_tokens (no benefit)
* fix: n_candidates bug.
* Add draft_model_num_pred_tokens server setting
* Cleanup
* Update README
1 parent 71e3e4c commit fb762a6

File tree

6 files changed: +207 -90 lines changed


README.md

+18
@@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
 )
 ```
 
+### Speculative Decoding
+
+`llama-cpp-python` supports speculative decoding, which allows the model to generate completions based on a draft model.
+
+The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class.
+
+Just pass this as a draft model to the `Llama` class during initialization.
+
+```python
+from llama_cpp import Llama
+from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+
+llama = Llama(
+    model_path="path/to/model.gguf",
+    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10)  # num_pred_tokens is the number of tokens to predict. 10 is the default and generally good for GPU; 2 performs better for CPU-only machines.
+)
+```
+
 ### Adjusting the Context Window
 
 The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
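As a follow-up to the README snippet, prompt lookup decoding tends to pay off most when the completion reuses spans of the prompt (summarization, extraction, code editing). A minimal usage sketch using only the public `Llama` call interface shown above; the model path and prompt are placeholders:

```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

# num_pred_tokens=2 is suggested for CPU-only machines; 10 (the default) for GPU.
llama = Llama(
    model_path="path/to/model.gguf",
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=2),
)

# Extraction-style prompts repeat prompt text in the output, which is where
# prompt lookup drafts are most likely to be accepted.
out = llama(
    "Summarize in one sentence: The quick brown fox jumps over the lazy dog. Summary:",
    max_tokens=32,
)
print(out["choices"][0]["text"])
```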

llama_cpp/llama.py

+91 -90
@@ -30,6 +30,8 @@
 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format
 
+from llama_cpp.llama_speculative import LlamaDraftModel
+
 import numpy as np
 import numpy.typing as npt
 
@@ -39,6 +41,8 @@
     _LlamaContext,  # type: ignore
     _LlamaBatch,  # type: ignore
     _LlamaTokenDataArray,  # type: ignore
+    _LlamaSamplingParams,  # type: ignore
+    _LlamaSamplingContext,  # type: ignore
 )
 
 
@@ -89,6 +93,8 @@ def __init__(
         # Chat Format Params
         chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
+        # Speculative Decoding
+        draft_model: Optional[LlamaDraftModel] = None,
         # Misc
         verbose: bool = True,
         # Extra Params
@@ -152,6 +158,7 @@ def __init__(
            numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
            chat_format: String specifying the chat format to use when calling create_chat_completion.
            chat_handler: Optional chat handler to use when calling create_chat_completion.
+           draft_model: Optional draft model to use for speculative decoding.
            verbose: Print verbose output to stderr.
 
        Raises:
@@ -315,6 +322,8 @@ def __init__(
         self.chat_format = chat_format
         self.chat_handler = chat_handler
 
+        self.draft_model = draft_model
+
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
 
@@ -503,6 +512,7 @@ def sample(
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        idx: Optional[int] = None,
     ):
         """Sample a token from the model.
 
@@ -517,77 +527,46 @@
         """
         assert self._ctx is not None
         assert self.n_tokens > 0
-        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - self.n_tokens
-        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
-        last_n_tokens_size = len(last_n_tokens_data)
-        n_vocab = self._n_vocab
-        n_ctx = self._n_ctx
-        top_k = n_vocab if top_k <= 0 else top_k
-        last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
-        last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
-            *last_n_tokens_data
-        )
-        logits: npt.NDArray[np.single] = self._scores[-1, :]
+
+        if idx is None:
+            logits: npt.NDArray[np.single] = self._scores[-1, :]
+        else:
+            logits = self._scores[idx, :]
 
         if logits_processor is not None:
-            logits[:] = logits_processor(self._input_ids, logits)
-
-        nl_logit = logits[self._token_nl]
-        self._candidates.copy_logits(logits)
-        self._ctx.sample_repetition_penalties(
-            candidates=self._candidates,
-            last_tokens_data=last_n_tokens_data_c,
-            penalty_last_n=last_n_tokens_size,
+            logits[:] = (
+                logits_processor(self._input_ids, logits)
+                if idx is None
+                else logits_processor(self._input_ids[:idx], logits)
+            )
+
+        sampling_params = _LlamaSamplingParams(
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            tfs_z=tfs_z,
+            typical_p=typical_p,
+            temp=temp,
+            penalty_last_n=self.last_n_tokens_size,
             penalty_repeat=repeat_penalty,
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
+            mirostat=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+        )
+        sampling_context = _LlamaSamplingContext(
+            params=sampling_params,
+            grammar=grammar,
+        )
+        sampling_context.prev = list(self.eval_tokens)
+        id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
+        sampling_context.accept(
+            ctx_main=self._ctx,
+            id=id,
+            apply_grammar=grammar is not None,
         )
-        if not penalize_nl:
-            self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
-                nl_logit
-            )
-
-        if grammar is not None:
-            self._ctx.sample_grammar(
-                candidates=self._candidates,
-                grammar=grammar,
-            )
-
-        if temp < 0.0:
-            self._ctx.sample_softmax(candidates=self._candidates)
-            id = self._candidates.candidates.data[0].id
-        elif temp == 0.0:
-            id = self._ctx.sample_token_greedy(candidates=self._candidates)
-        elif mirostat_mode == 1:
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token_mirostat(
-                candidates=self._candidates,
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=ctypes.pointer(self._mirostat_mu),
-                m=100,
-            )
-        elif mirostat_mode == 2:
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token_mirostat_v2(
-                candidates=self._candidates,
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=ctypes.pointer(self._mirostat_mu),
-            )
-        else:
-            self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
-            self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(
-                candidates=self._candidates, p=typical_p, min_keep=1
-            )
-            self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
-            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token(candidates=self._candidates)
-        if grammar is not None:
-            self._ctx.grammar_accept_token(grammar=grammar, token=id)
         return id
 
     def generate(
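The key change to `sample()` is the optional `idx` argument: instead of always reading the last row of the scores buffer, the caller can sample at any position that was evaluated in the current batch, which is what lets `generate()` verify several drafted tokens after a single eval. A toy sketch of that row selection (greedy only, `greedy_pick` is a hypothetical name, independent of the `_LlamaSamplingContext` internals):

```python
from typing import Optional

import numpy as np

# Toy illustration, not library code: after a batched eval, the scores buffer
# holds one row of logits per evaluated position, and sample(idx=i) reads row i
# instead of always taking the last row.
n_vocab = 8
scores = np.random.randn(4, n_vocab).astype(np.single)  # 4 positions evaluated in one batch

def greedy_pick(scores: np.ndarray, idx: Optional[int] = None) -> int:
    logits = scores[-1, :] if idx is None else scores[idx, :]
    return int(np.argmax(logits))

# One sampled token per evaluated position, as the verification loop in generate() needs.
print([greedy_pick(scores, idx=i) for i in range(scores.shape[0])])
```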
@@ -656,34 +635,56 @@ def generate(
         if grammar is not None:
             grammar.reset()
 
+        sample_idx = self.n_tokens + len(tokens) - 1
+        tokens = list(tokens)
+
         # Eval and sample
         while True:
             self.eval(tokens)
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                typical_p=typical_p,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                penalize_nl=penalize_nl,
-            )
-            if stopping_criteria is not None and stopping_criteria(
-                self._input_ids, self._scores[-1, :]
-            ):
-                return
-            tokens_or_none = yield token
-            tokens = [token]
-            if tokens_or_none is not None:
-                tokens.extend(tokens_or_none)
+            while sample_idx < self.n_tokens:
+                token = self.sample(
+                    top_k=top_k,
+                    top_p=top_p,
+                    min_p=min_p,
+                    typical_p=typical_p,
+                    temp=temp,
+                    repeat_penalty=repeat_penalty,
+                    frequency_penalty=frequency_penalty,
+                    presence_penalty=presence_penalty,
+                    tfs_z=tfs_z,
+                    mirostat_mode=mirostat_mode,
+                    mirostat_tau=mirostat_tau,
+                    mirostat_eta=mirostat_eta,
+                    logits_processor=logits_processor,
+                    grammar=grammar,
+                    penalize_nl=penalize_nl,
+                    idx=sample_idx,
+                )
+
+                sample_idx += 1
+                if stopping_criteria is not None and stopping_criteria(
+                    self._input_ids, self._scores[-1, :]
+                ):
+                    return
+                tokens_or_none = yield token
+                tokens.clear()
+                tokens.append(token)
+                if tokens_or_none is not None:
+                    tokens.extend(tokens_or_none)
+
+                if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
+                    self.n_tokens = sample_idx
+                    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+                    break
+
+            if self.draft_model is not None:
+                self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
+                draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
+                tokens.extend(
+                    draft_tokens.astype(int)[
+                        : self._n_ctx - self.n_tokens - len(tokens)
+                    ]
+                )
 
     def create_embedding(
         self, input: Union[str, List[str]], model: Optional[str] = None
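In the new `generate()` loop, the tokens proposed by `draft_model` are appended to the batch, evaluated in one `eval()` call, and then verified position by position; on the first disagreement the loop truncates `n_tokens` and calls `kv_cache_seq_rm` to roll the KV cache back. A self-contained toy of that accept/reject pattern (pure Python, hypothetical `verify` helper, ignoring the off-by-one index bookkeeping of the real loop):

```python
from typing import List, Tuple

def verify(sampled: List[int], drafted: List[int]) -> Tuple[List[int], int]:
    """Keep the model's sampled tokens while they agree with the drafted ones.

    Returns the emitted tokens and how many drafted tokens were accepted;
    the first mismatch still emits the model's own token as the correction.
    """
    emitted: List[int] = []
    accepted = 0
    for s, d in zip(sampled, drafted):
        emitted.append(s)        # the model's token is always what gets yielded
        if s != d:               # draft diverged: the rest of it must be discarded
            break
        accepted += 1
    return emitted, accepted

# The draft guessed [5, 7, 1, 2] but the model samples [5, 7, 9, 2]:
# two draft tokens are accepted, the third is replaced by 9, the rest is dropped.
print(verify(sampled=[5, 7, 9, 2], drafted=[5, 7, 1, 2]))  # ([5, 7, 9], 2)
```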

llama_cpp/llama_speculative.py

+64
@@ -0,0 +1,64 @@
+import abc
+
+from typing import Any
+
+import numpy as np
+import numpy.typing as npt
+
+
+class LlamaDraftModel(abc.ABC):
+    @abc.abstractmethod
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
+    ) -> npt.NDArray[np.intc]:
+        raise NotImplementedError()
+
+
+class LlamaPromptLookupDecoding(LlamaDraftModel):
+    """Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
+
+    def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
+        self.max_ngram_size = max_ngram_size
+        self.num_pred_tokens = num_pred_tokens
+
+    @staticmethod
+    def find_candidate_pred_tokens(
+        input_ids: npt.NDArray[np.intc],
+        max_ngram_size: int,
+        num_pred_tokens: int,
+    ):
+        input_length = input_ids.shape[0]
+
+        for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
+            # Create sliding windows of size ngram_size
+            windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
+
+            # Convert ngram to an array for comparison
+            ngram_array = input_ids[-ngram_size:]
+
+            # Find where the windows match the ngram
+            matches = np.all(windows == ngram_array, axis=1)
+
+            # Get the indices of matches
+            match_indices = np.nonzero(matches)[0]
+
+            # Iterate through match indices to find a valid continuation
+            for idx in match_indices:
+                start_idx = idx + ngram_size
+                end_idx = start_idx + num_pred_tokens
+                end_idx = min(end_idx, input_length)
+
+                if start_idx < end_idx:
+                    return input_ids[start_idx:end_idx]
+
+        # If no match is found, return an empty array
+        return np.array([], dtype=np.intc)
+
+    def __call__(
+        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
+    ) -> npt.NDArray[np.intc]:
+        return self.find_candidate_pred_tokens(
+            input_ids=input_ids,
+            max_ngram_size=self.max_ngram_size,
+            num_pred_tokens=self.num_pred_tokens,
+        )
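`LlamaPromptLookupDecoding` operates on token ids alone, so it can be exercised without loading a model. A standalone example of the n-gram lookup (the token ids are chosen purely for illustration):

```python
import numpy as np

from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

# The prompt ends with the bigram (2, 3), which also appears earlier at index 1;
# the tokens that followed it there, [4, 1], become the drafted continuation.
draft = LlamaPromptLookupDecoding(max_ngram_size=2, num_pred_tokens=2)
input_ids = np.array([1, 2, 3, 4, 1, 2, 3], dtype=np.intc)
print(draft(input_ids))  # [4 1]
```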

llama_cpp/server/model.py

+9
@@ -5,6 +5,7 @@
 from typing import Dict, Optional, Union, List
 
 import llama_cpp
+import llama_cpp.llama_speculative as llama_speculative
 
 from llama_cpp.server.settings import ModelSettings
 
@@ -92,6 +93,12 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
        )
    )
 
+    draft_model = None
+    if settings.draft_model is not None:
+        draft_model = llama_speculative.LlamaPromptLookupDecoding(
+            num_pred_tokens=settings.draft_model_num_pred_tokens
+        )
+
     kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
     if settings.kv_overrides is not None:
         assert isinstance(settings.kv_overrides, list)
@@ -147,6 +154,8 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama:
         # Chat Format Params
         chat_format=settings.chat_format,
         chat_handler=chat_handler,
+        # Speculative Decoding
+        draft_model=draft_model,
         # Misc
         verbose=settings.verbose,
     )

llama_cpp/server/settings.py

+9
@@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
         default=None,
         description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
     )
+    # Speculative Decoding
+    draft_model: Optional[str] = Field(
+        default=None,
+        description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
+    )
+    draft_model_num_pred_tokens: int = Field(
+        default=10,
+        description="Number of tokens to predict using the draft model.",
+    )
     # Misc
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
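Together with the `model.py` change above, these settings let the server enable prompt lookup decoding from configuration. A programmatic sketch of that path (the model path is a placeholder; per the check in `load_llama_from_model_settings`, any non-None `draft_model` value selects `LlamaPromptLookupDecoding`):

```python
from llama_cpp.server.settings import ModelSettings
from llama_cpp.server.model import load_llama_from_model_settings

settings = ModelSettings(
    model="path/to/model.gguf",
    draft_model="prompt-lookup-decoding",   # any non-None value enables the prompt lookup draft model
    draft_model_num_pred_tokens=2,          # smaller drafts tend to do better on CPU-only machines
)
llama = load_llama_from_model_settings(settings)
```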

tests/test_llama_speculative.py

+16
@@ -0,0 +1,16 @@
+import numpy as np
+
+from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+
+def test_find_candidate_pred_tokens():
+    find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
+
+    # Test Case 1: Matching ngram is found
+    input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
+    result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
+    assert np.array_equal(result1, np.array([1, 2]))
+
+    # Test Case 2: Matching ngram is not found
+    input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
+    result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
+    assert np.array_equal(result2, np.array([]))
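A third case, not part of the commit but following from the implementation above: when fewer than `num_pred_tokens` tokens remain after the matched n-gram, the prediction is clipped at the end of the prompt. A standalone sketch of such a test:

```python
import numpy as np

from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

def test_find_candidate_pred_tokens_clips_to_prompt_end():
    # With max_ngram_size=3 the trailing trigram (3, 1, 2) only matches itself,
    # so the lookup falls back to the bigram (1, 2) at index 0; only three tokens
    # follow it, so num_pred_tokens=5 is clipped to [3, 1, 2].
    input_ids = np.array([1, 2, 3, 1, 2])
    result = LlamaPromptLookupDecoding.find_candidate_pred_tokens(
        input_ids, max_ngram_size=3, num_pred_tokens=5
    )
    assert np.array_equal(result, np.array([3, 1, 2]))
```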
