Skip to content

Commit d957422

Browse files
committed
Implement sampling as in llama.cpp main example
1 parent 93a9019 commit d957422

File tree

1 file changed

+70 -80 lines changed

llama_cpp/llama.py

Lines changed: 70 additions & 80 deletions
Original file line number | Diff line number | Diff line change
@@ -268,14 +268,13 @@ def _sample(
268268
top_k: llama_cpp.c_int,
269269
top_p: llama_cpp.c_float,
270270
temp: llama_cpp.c_float,
271-
mirostat_mode: llama_cpp.c_int,
272-
mirostat_tau: llama_cpp.c_float,
273-
mirostat_eta: llama_cpp.c_float,
274-
mirostat_mu: llama_cpp.c_float,
275-
mirostat_m: llama_cpp.c_int,
271+
tfs_z: llama_cpp.c_float,
276272
repeat_penalty: llama_cpp.c_float,
277273
frequency_penalty: llama_cpp.c_float,
278274
presence_penalty: llama_cpp.c_float,
275+
mirostat_mode: llama_cpp.c_int,
276+
mirostat_tau: llama_cpp.c_float,
277+
mirostat_eta: llama_cpp.c_float,
279278
):
280279
assert self.ctx is not None
281280
assert len(self.eval_logits) > 0
@@ -305,45 +304,48 @@ def _sample(
305304
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
306305
penalty=repeat_penalty,
307306
)
308-
if mirostat_mode.value == 1:
307+
llama_cpp.llama_sample_frequency_and_presence_penalties(
308+
ctx=self.ctx,
309+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
310+
last_tokens_data=last_n_tokens_data,
311+
last_tokens_size=last_n_tokens_size,
312+
alpha_frequency=frequency_penalty,
313+
alpha_presence=presence_penalty,
314+
)
315+
if temp.value == 0.0:
316+
return llama_cpp.llama_sample_token_greedy(
317+
ctx=self.ctx,
318+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
319+
)
320+
elif mirostat_mode.value == 1:
321+
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
322+
mirostat_m = llama_cpp.c_int(100)
309323
llama_cpp.llama_sample_temperature(
310324
ctx=self.ctx,
311-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
325+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
312326
temp=temp,
313327
)
314-
llama_cpp.llama_sample_token_mirostat(
328+
return llama_cpp.llama_sample_token_mirostat(
315329
ctx=self.ctx,
316-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
330+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
317331
tau=mirostat_tau,
318332
eta=mirostat_eta,
319-
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
320-
m=mirostat_m
333+
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
334+
m=mirostat_m,
321335
)
322336
elif mirostat_mode.value == 2:
337+
mirostat_mu = llama_cpp.c_float(2.0 * mirostat_tau.value)
323338
llama_cpp.llama_sample_temperature(
324339
ctx=self.ctx,
325340
candidates=llama_cpp.ctypes.pointer(candidates),
326341
temp=temp,
327342
)
328-
llama_cpp.llama_sample_token_mirostat_v2(
343+
return llama_cpp.llama_sample_token_mirostat_v2(
329344
ctx=self.ctx,
330-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
345+
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
331346
tau=mirostat_tau,
332347
eta=mirostat_eta,
333-
mu=llama_cpp.ctypes.byref(mirostat_mu) # type: ignore
334-
)
335-
llama_cpp.llama_sample_frequency_and_presence_penalties(
336-
ctx=self.ctx,
337-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
338-
last_tokens_data=last_n_tokens_data,
339-
last_tokens_size=last_n_tokens_size,
340-
alpha_frequency=frequency_penalty,
341-
alpha_presence=presence_penalty,
342-
)
343-
if float(temp.value) == 0.0:
344-
return llama_cpp.llama_sample_token_greedy(
345-
ctx=self.ctx,
346-
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
348+
mu=llama_cpp.ctypes.byref(mirostat_mu), # type: ignore
347349
)
348350
else:
349351
llama_cpp.llama_sample_top_k(
@@ -355,7 +357,7 @@ def _sample(
355357
llama_cpp.llama_sample_tail_free(
356358
ctx=self.ctx,
357359
candidates=llama_cpp.ctypes.byref(candidates), # type: ignore
358-
z=llama_cpp.c_float(1.0),
360+
z=tfs_z,
359361
min_keep=llama_cpp.c_size_t(1),
360362
)
361363
llama_cpp.llama_sample_typical(
@@ -382,17 +384,16 @@ def _sample(
382384

383385
def sample(
384386
self,
385-
top_k: int,
386-
top_p: float,
387-
temp: float,
388-
mirostat_mode: int,
389-
mirostat_tau: float,
390-
mirostat_eta: float,
391-
mirostat_mu: float,
392-
mirostat_m: int,
393-
repeat_penalty: float,
387+
top_k: int = 40,
388+
top_p: float = 0.95,
389+
temp: float = 0.80,
390+
repeat_penalty: float = 1.1,
394391
frequency_penalty: float = 0.0,
395392
presence_penalty: float = 0.0,
393+
tfs_z: float = 1.0,
394+
mirostat_mode: int = 0,
395+
mirostat_eta: float = 0.1,
396+
mirostat_tau: float = 5.0,
396397
):
397398
"""Sample a token from the model.
398399
@@ -417,14 +418,13 @@ def sample(
417418
top_k=llama_cpp.c_int(top_k),
418419
top_p=llama_cpp.c_float(top_p),
419420
temp=llama_cpp.c_float(temp),
420-
mirostat_mode=llama_cpp.c_int(mirostat_mode),
421-
mirostat_mu=llama_cpp.c_float(mirostat_mu),
422-
mirostat_tau=llama_cpp.c_float(mirostat_tau),
423-
mirostat_eta=llama_cpp.c_float(mirostat_eta),
424-
mirostat_m=llama_cpp.c_int(mirostat_m),
421+
tfs_z=llama_cpp.c_float(tfs_z),
425422
repeat_penalty=llama_cpp.c_float(repeat_penalty),
426423
frequency_penalty=llama_cpp.c_float(frequency_penalty),
427424
presence_penalty=llama_cpp.c_float(presence_penalty),
425+
mirostat_mode=llama_cpp.c_int(mirostat_mode),
426+
mirostat_tau=llama_cpp.c_float(mirostat_tau),
427+
mirostat_eta=llama_cpp.c_float(mirostat_eta),
428428
)
429429

430430
def generate(
@@ -433,15 +433,13 @@ def generate(
433433
top_k: int,
434434
top_p: float,
435435
temp: float,
436-
mirostat_mode: int,
437-
mirostat_tau: float,
438-
mirostat_eta: float,
439-
mirostat_mu: float,
440-
mirostat_m: int,
441436
repeat_penalty: float,
437+
reset: bool = True,
442438
frequency_penalty: float = 0.0,
443439
presence_penalty: float = 0.0,
444-
reset: bool = True,
440+
mirostat_mode: int = 0,
441+
mirostat_tau: float = 5.0,
442+
mirostat_eta: float = 0.1,
445443
) -> Generator[
446444
llama_cpp.llama_token, Optional[Sequence[llama_cpp.llama_token]], None
447445
]:
@@ -494,14 +492,12 @@ def generate(
494492
top_k=top_k,
495493
top_p=top_p,
496494
temp=temp,
495+
repeat_penalty=repeat_penalty,
496+
frequency_penalty=frequency_penalty,
497+
presence_penalty=presence_penalty,
497498
mirostat_mode=mirostat_mode,
498499
mirostat_tau=mirostat_tau,
499500
mirostat_eta=mirostat_eta,
500-
mirostat_mu=mirostat_mu,
501-
mirostat_m=mirostat_m,
502-
frequency_penalty=frequency_penalty,
503-
presence_penalty=presence_penalty,
504-
repeat_penalty=repeat_penalty,
505501
)
506502
tokens_or_none = yield token
507503
tokens = [token]
@@ -571,11 +567,6 @@ def _create_completion(
571567
suffix: Optional[str] = None,
572568
max_tokens: int = 16,
573569
temperature: float = 0.8,
574-
mirostat_mode: int = 0,
575-
mirostat_tau: float = 5.0,
576-
mirostat_eta: float = 0.1,
577-
mirostat_mu: float = 10,
578-
mirostat_m: int = 100,
579570
top_p: float = 0.95,
580571
logprobs: Optional[int] = None,
581572
echo: bool = False,
@@ -585,6 +576,9 @@ def _create_completion(
585576
repeat_penalty: float = 1.1,
586577
top_k: int = 40,
587578
stream: bool = False,
579+
mirostat_mode: int = 0,
580+
mirostat_tau: float = 5.0,
581+
mirostat_eta: float = 0.1,
588582
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
589583
assert self.ctx is not None
590584
completion_id: str = f"cmpl-{str(uuid.uuid4())}"
@@ -643,8 +637,6 @@ def _create_completion(
643637
mirostat_mode=mirostat_mode,
644638
mirostat_tau=mirostat_tau,
645639
mirostat_eta=mirostat_eta,
646-
mirostat_mu=mirostat_mu,
647-
mirostat_m=mirostat_m,
648640
frequency_penalty=frequency_penalty,
649641
presence_penalty=presence_penalty,
650642
repeat_penalty=repeat_penalty,
@@ -817,11 +809,6 @@ def create_completion(
817809
suffix: Optional[str] = None,
818810
max_tokens: int = 128,
819811
temperature: float = 0.8,
820-
mirostat_mode: int = 0,
821-
mirostat_tau: float = 5.0,
822-
mirostat_eta: float = 0.1,
823-
mirostat_mu: float = 10,
824-
mirostat_m: int = 100,
825812
top_p: float = 0.95,
826813
logprobs: Optional[int] = None,
827814
echo: bool = False,
@@ -831,6 +818,9 @@ def create_completion(
831818
repeat_penalty: float = 1.1,
832819
top_k: int = 40,
833820
stream: bool = False,
821+
mirostat_mode: int = 0,
822+
mirostat_tau: float = 5.0,
823+
mirostat_eta: float = 0.1,
834824
) -> Union[Completion, Iterator[CompletionChunk]]:
835825
"""Generate text from a prompt.
836826
@@ -859,11 +849,6 @@ def create_completion(
859849
suffix=suffix,
860850
max_tokens=max_tokens,
861851
temperature=temperature,
862-
mirostat_mode=mirostat_mode,
863-
mirostat_tau=mirostat_tau,
864-
mirostat_eta=mirostat_eta,
865-
mirostat_mu=mirostat_mu,
866-
mirostat_m=mirostat_m,
867852
top_p=top_p,
868853
logprobs=logprobs,
869854
echo=echo,
@@ -873,6 +858,9 @@ def create_completion(
873858
repeat_penalty=repeat_penalty,
874859
top_k=top_k,
875860
stream=stream,
861+
mirostat_mode=mirostat_mode,
862+
mirostat_tau=mirostat_tau,
863+
mirostat_eta=mirostat_eta,
876864
)
877865
if stream:
878866
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -886,11 +874,6 @@ def __call__(
886874
suffix: Optional[str] = None,
887875
max_tokens: int = 128,
888876
temperature: float = 0.8,
889-
mirostat_mode: int = 0,
890-
mirostat_tau: float = 5.0,
891-
mirostat_eta: float = 0.1,
892-
mirostat_mu: float = 10,
893-
mirostat_m: int = 100,
894877
top_p: float = 0.95,
895878
logprobs: Optional[int] = None,
896879
echo: bool = False,
@@ -900,6 +883,9 @@ def __call__(
900883
repeat_penalty: float = 1.1,
901884
top_k: int = 40,
902885
stream: bool = False,
886+
mirostat_mode: int = 0,
887+
mirostat_tau: float = 5.0,
888+
mirostat_eta: float = 0.1,
903889
) -> Union[Completion, Iterator[CompletionChunk]]:
904890
"""Generate text from a prompt.
905891
@@ -928,11 +914,6 @@ def __call__(
928914
suffix=suffix,
929915
max_tokens=max_tokens,
930916
temperature=temperature,
931-
mirostat_mode=mirostat_mode,
932-
mirostat_tau=mirostat_tau,
933-
mirostat_eta=mirostat_eta,
934-
mirostat_mu=mirostat_mu,
935-
mirostat_m=mirostat_m,
936917
top_p=top_p,
937918
logprobs=logprobs,
938919
echo=echo,
@@ -942,6 +923,9 @@ def __call__(
942923
repeat_penalty=repeat_penalty,
943924
top_k=top_k,
944925
stream=stream,
926+
mirostat_mode=mirostat_mode,
927+
mirostat_tau=mirostat_tau,
928+
mirostat_eta=mirostat_eta,
945929
)
946930

947931
def _convert_text_completion_to_chat(
@@ -1014,6 +998,9 @@ def create_chat_completion(
1014998
presence_penalty: float = 0.0,
1015999
frequency_penalty: float = 0.0,
10161000
repeat_penalty: float = 1.1,
1001+
mirostat_mode: int = 0,
1002+
mirostat_tau: float = 5.0,
1003+
mirostat_eta: float = 0.1,
10171004
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
10181005
"""Generate a chat completion from a list of messages.
10191006
@@ -1048,6 +1035,9 @@ def create_chat_completion(
10481035
repeat_penalty=repeat_penalty,
10491036
presence_penalty=presence_penalty,
10501037
frequency_penalty=frequency_penalty,
1038+
mirostat_mode=mirostat_mode,
1039+
mirostat_tau=mirostat_tau,
1040+
mirostat_eta=mirostat_eta,
10511041
)
10521042
if stream:
10531043
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore

0 commit comments

Comments (0)