@@ -268,14 +268,13 @@ def _sample(
268
268
top_k : llama_cpp .c_int ,
269
269
top_p : llama_cpp .c_float ,
270
270
temp : llama_cpp .c_float ,
271
- mirostat_mode : llama_cpp .c_int ,
272
- mirostat_tau : llama_cpp .c_float ,
273
- mirostat_eta : llama_cpp .c_float ,
274
- mirostat_mu : llama_cpp .c_float ,
275
- mirostat_m : llama_cpp .c_int ,
271
+ tfs_z : llama_cpp .c_float ,
276
272
repeat_penalty : llama_cpp .c_float ,
277
273
frequency_penalty : llama_cpp .c_float ,
278
274
presence_penalty : llama_cpp .c_float ,
275
+ mirostat_mode : llama_cpp .c_int ,
276
+ mirostat_tau : llama_cpp .c_float ,
277
+ mirostat_eta : llama_cpp .c_float ,
279
278
):
280
279
assert self .ctx is not None
281
280
assert len (self .eval_logits ) > 0
@@ -305,45 +304,48 @@ def _sample(
305
304
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
306
305
penalty = repeat_penalty ,
307
306
)
308
- if mirostat_mode .value == 1 :
307
+ llama_cpp .llama_sample_frequency_and_presence_penalties (
308
+ ctx = self .ctx ,
309
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
310
+ last_tokens_data = last_n_tokens_data ,
311
+ last_tokens_size = last_n_tokens_size ,
312
+ alpha_frequency = frequency_penalty ,
313
+ alpha_presence = presence_penalty ,
314
+ )
315
+ if temp .value == 0.0 :
316
+ return llama_cpp .llama_sample_token_greedy (
317
+ ctx = self .ctx ,
318
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
319
+ )
320
+ elif mirostat_mode .value == 1 :
321
+ mirostat_mu = llama_cpp .c_float (2.0 * mirostat_tau .value )
322
+ mirostat_m = llama_cpp .c_int (100 )
309
323
llama_cpp .llama_sample_temperature (
310
324
ctx = self .ctx ,
311
- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
325
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
312
326
temp = temp ,
313
327
)
314
- llama_cpp .llama_sample_token_mirostat (
328
+ return llama_cpp .llama_sample_token_mirostat (
315
329
ctx = self .ctx ,
316
- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
330
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
317
331
tau = mirostat_tau ,
318
332
eta = mirostat_eta ,
319
- mu = llama_cpp .ctypes .byref (mirostat_mu ), # type: ignore
320
- m = mirostat_m
333
+ mu = llama_cpp .ctypes .byref (mirostat_mu ), # type: ignore
334
+ m = mirostat_m ,
321
335
)
322
336
elif mirostat_mode .value == 2 :
337
+ mirostat_mu = llama_cpp .c_float (2.0 * mirostat_tau .value )
323
338
llama_cpp .llama_sample_temperature (
324
339
ctx = self .ctx ,
325
340
candidates = llama_cpp .ctypes .pointer (candidates ),
326
341
temp = temp ,
327
342
)
328
- llama_cpp .llama_sample_token_mirostat_v2 (
343
+ return llama_cpp .llama_sample_token_mirostat_v2 (
329
344
ctx = self .ctx ,
330
- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
345
+ candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
331
346
tau = mirostat_tau ,
332
347
eta = mirostat_eta ,
333
- mu = llama_cpp .ctypes .byref (mirostat_mu ) # type: ignore
334
- )
335
- llama_cpp .llama_sample_frequency_and_presence_penalties (
336
- ctx = self .ctx ,
337
- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
338
- last_tokens_data = last_n_tokens_data ,
339
- last_tokens_size = last_n_tokens_size ,
340
- alpha_frequency = frequency_penalty ,
341
- alpha_presence = presence_penalty ,
342
- )
343
- if float (temp .value ) == 0.0 :
344
- return llama_cpp .llama_sample_token_greedy (
345
- ctx = self .ctx ,
346
- candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
348
+ mu = llama_cpp .ctypes .byref (mirostat_mu ), # type: ignore
347
349
)
348
350
else :
349
351
llama_cpp .llama_sample_top_k (
@@ -355,7 +357,7 @@ def _sample(
355
357
llama_cpp .llama_sample_tail_free (
356
358
ctx = self .ctx ,
357
359
candidates = llama_cpp .ctypes .byref (candidates ), # type: ignore
358
- z = llama_cpp . c_float ( 1.0 ) ,
360
+ z = tfs_z ,
359
361
min_keep = llama_cpp .c_size_t (1 ),
360
362
)
361
363
llama_cpp .llama_sample_typical (
@@ -382,17 +384,16 @@ def _sample(
382
384
383
385
def sample (
384
386
self ,
385
- top_k : int ,
386
- top_p : float ,
387
- temp : float ,
388
- mirostat_mode : int ,
389
- mirostat_tau : float ,
390
- mirostat_eta : float ,
391
- mirostat_mu : float ,
392
- mirostat_m : int ,
393
- repeat_penalty : float ,
387
+ top_k : int = 40 ,
388
+ top_p : float = 0.95 ,
389
+ temp : float = 0.80 ,
390
+ repeat_penalty : float = 1.1 ,
394
391
frequency_penalty : float = 0.0 ,
395
392
presence_penalty : float = 0.0 ,
393
+ tfs_z : float = 1.0 ,
394
+ mirostat_mode : int = 0 ,
395
+ mirostat_eta : float = 0.1 ,
396
+ mirostat_tau : float = 5.0 ,
396
397
):
397
398
"""Sample a token from the model.
398
399
@@ -417,14 +418,13 @@ def sample(
417
418
top_k = llama_cpp .c_int (top_k ),
418
419
top_p = llama_cpp .c_float (top_p ),
419
420
temp = llama_cpp .c_float (temp ),
420
- mirostat_mode = llama_cpp .c_int (mirostat_mode ),
421
- mirostat_mu = llama_cpp .c_float (mirostat_mu ),
422
- mirostat_tau = llama_cpp .c_float (mirostat_tau ),
423
- mirostat_eta = llama_cpp .c_float (mirostat_eta ),
424
- mirostat_m = llama_cpp .c_int (mirostat_m ),
421
+ tfs_z = llama_cpp .c_float (tfs_z ),
425
422
repeat_penalty = llama_cpp .c_float (repeat_penalty ),
426
423
frequency_penalty = llama_cpp .c_float (frequency_penalty ),
427
424
presence_penalty = llama_cpp .c_float (presence_penalty ),
425
+ mirostat_mode = llama_cpp .c_int (mirostat_mode ),
426
+ mirostat_tau = llama_cpp .c_float (mirostat_tau ),
427
+ mirostat_eta = llama_cpp .c_float (mirostat_eta ),
428
428
)
429
429
430
430
def generate (
@@ -433,15 +433,13 @@ def generate(
433
433
top_k : int ,
434
434
top_p : float ,
435
435
temp : float ,
436
- mirostat_mode : int ,
437
- mirostat_tau : float ,
438
- mirostat_eta : float ,
439
- mirostat_mu : float ,
440
- mirostat_m : int ,
441
436
repeat_penalty : float ,
437
+ reset : bool = True ,
442
438
frequency_penalty : float = 0.0 ,
443
439
presence_penalty : float = 0.0 ,
444
- reset : bool = True ,
440
+ mirostat_mode : int = 0 ,
441
+ mirostat_tau : float = 5.0 ,
442
+ mirostat_eta : float = 0.1 ,
445
443
) -> Generator [
446
444
llama_cpp .llama_token , Optional [Sequence [llama_cpp .llama_token ]], None
447
445
]:
@@ -494,14 +492,12 @@ def generate(
494
492
top_k = top_k ,
495
493
top_p = top_p ,
496
494
temp = temp ,
495
+ repeat_penalty = repeat_penalty ,
496
+ frequency_penalty = frequency_penalty ,
497
+ presence_penalty = presence_penalty ,
497
498
mirostat_mode = mirostat_mode ,
498
499
mirostat_tau = mirostat_tau ,
499
500
mirostat_eta = mirostat_eta ,
500
- mirostat_mu = mirostat_mu ,
501
- mirostat_m = mirostat_m ,
502
- frequency_penalty = frequency_penalty ,
503
- presence_penalty = presence_penalty ,
504
- repeat_penalty = repeat_penalty ,
505
501
)
506
502
tokens_or_none = yield token
507
503
tokens = [token ]
@@ -571,11 +567,6 @@ def _create_completion(
571
567
suffix : Optional [str ] = None ,
572
568
max_tokens : int = 16 ,
573
569
temperature : float = 0.8 ,
574
- mirostat_mode : int = 0 ,
575
- mirostat_tau : float = 5.0 ,
576
- mirostat_eta : float = 0.1 ,
577
- mirostat_mu : float = 10 ,
578
- mirostat_m : int = 100 ,
579
570
top_p : float = 0.95 ,
580
571
logprobs : Optional [int ] = None ,
581
572
echo : bool = False ,
@@ -585,6 +576,9 @@ def _create_completion(
585
576
repeat_penalty : float = 1.1 ,
586
577
top_k : int = 40 ,
587
578
stream : bool = False ,
579
+ mirostat_mode : int = 0 ,
580
+ mirostat_tau : float = 5.0 ,
581
+ mirostat_eta : float = 0.1 ,
588
582
) -> Union [Iterator [Completion ], Iterator [CompletionChunk ]]:
589
583
assert self .ctx is not None
590
584
completion_id : str = f"cmpl-{ str (uuid .uuid4 ())} "
@@ -643,8 +637,6 @@ def _create_completion(
643
637
mirostat_mode = mirostat_mode ,
644
638
mirostat_tau = mirostat_tau ,
645
639
mirostat_eta = mirostat_eta ,
646
- mirostat_mu = mirostat_mu ,
647
- mirostat_m = mirostat_m ,
648
640
frequency_penalty = frequency_penalty ,
649
641
presence_penalty = presence_penalty ,
650
642
repeat_penalty = repeat_penalty ,
@@ -817,11 +809,6 @@ def create_completion(
817
809
suffix : Optional [str ] = None ,
818
810
max_tokens : int = 128 ,
819
811
temperature : float = 0.8 ,
820
- mirostat_mode : int = 0 ,
821
- mirostat_tau : float = 5.0 ,
822
- mirostat_eta : float = 0.1 ,
823
- mirostat_mu : float = 10 ,
824
- mirostat_m : int = 100 ,
825
812
top_p : float = 0.95 ,
826
813
logprobs : Optional [int ] = None ,
827
814
echo : bool = False ,
@@ -831,6 +818,9 @@ def create_completion(
831
818
repeat_penalty : float = 1.1 ,
832
819
top_k : int = 40 ,
833
820
stream : bool = False ,
821
+ mirostat_mode : int = 0 ,
822
+ mirostat_tau : float = 5.0 ,
823
+ mirostat_eta : float = 0.1 ,
834
824
) -> Union [Completion , Iterator [CompletionChunk ]]:
835
825
"""Generate text from a prompt.
836
826
@@ -859,11 +849,6 @@ def create_completion(
859
849
suffix = suffix ,
860
850
max_tokens = max_tokens ,
861
851
temperature = temperature ,
862
- mirostat_mode = mirostat_mode ,
863
- mirostat_tau = mirostat_tau ,
864
- mirostat_eta = mirostat_eta ,
865
- mirostat_mu = mirostat_mu ,
866
- mirostat_m = mirostat_m ,
867
852
top_p = top_p ,
868
853
logprobs = logprobs ,
869
854
echo = echo ,
@@ -873,6 +858,9 @@ def create_completion(
873
858
repeat_penalty = repeat_penalty ,
874
859
top_k = top_k ,
875
860
stream = stream ,
861
+ mirostat_mode = mirostat_mode ,
862
+ mirostat_tau = mirostat_tau ,
863
+ mirostat_eta = mirostat_eta ,
876
864
)
877
865
if stream :
878
866
chunks : Iterator [CompletionChunk ] = completion_or_chunks
@@ -886,11 +874,6 @@ def __call__(
886
874
suffix : Optional [str ] = None ,
887
875
max_tokens : int = 128 ,
888
876
temperature : float = 0.8 ,
889
- mirostat_mode : int = 0 ,
890
- mirostat_tau : float = 5.0 ,
891
- mirostat_eta : float = 0.1 ,
892
- mirostat_mu : float = 10 ,
893
- mirostat_m : int = 100 ,
894
877
top_p : float = 0.95 ,
895
878
logprobs : Optional [int ] = None ,
896
879
echo : bool = False ,
@@ -900,6 +883,9 @@ def __call__(
900
883
repeat_penalty : float = 1.1 ,
901
884
top_k : int = 40 ,
902
885
stream : bool = False ,
886
+ mirostat_mode : int = 0 ,
887
+ mirostat_tau : float = 5.0 ,
888
+ mirostat_eta : float = 0.1 ,
903
889
) -> Union [Completion , Iterator [CompletionChunk ]]:
904
890
"""Generate text from a prompt.
905
891
@@ -928,11 +914,6 @@ def __call__(
928
914
suffix = suffix ,
929
915
max_tokens = max_tokens ,
930
916
temperature = temperature ,
931
- mirostat_mode = mirostat_mode ,
932
- mirostat_tau = mirostat_tau ,
933
- mirostat_eta = mirostat_eta ,
934
- mirostat_mu = mirostat_mu ,
935
- mirostat_m = mirostat_m ,
936
917
top_p = top_p ,
937
918
logprobs = logprobs ,
938
919
echo = echo ,
@@ -942,6 +923,9 @@ def __call__(
942
923
repeat_penalty = repeat_penalty ,
943
924
top_k = top_k ,
944
925
stream = stream ,
926
+ mirostat_mode = mirostat_mode ,
927
+ mirostat_tau = mirostat_tau ,
928
+ mirostat_eta = mirostat_eta ,
945
929
)
946
930
947
931
def _convert_text_completion_to_chat (
@@ -1014,6 +998,9 @@ def create_chat_completion(
1014
998
presence_penalty : float = 0.0 ,
1015
999
frequency_penalty : float = 0.0 ,
1016
1000
repeat_penalty : float = 1.1 ,
1001
+ mirostat_mode : int = 0 ,
1002
+ mirostat_tau : float = 5.0 ,
1003
+ mirostat_eta : float = 0.1 ,
1017
1004
) -> Union [ChatCompletion , Iterator [ChatCompletionChunk ]]:
1018
1005
"""Generate a chat completion from a list of messages.
1019
1006
@@ -1048,6 +1035,9 @@ def create_chat_completion(
1048
1035
repeat_penalty = repeat_penalty ,
1049
1036
presence_penalty = presence_penalty ,
1050
1037
frequency_penalty = frequency_penalty ,
1038
+ mirostat_mode = mirostat_mode ,
1039
+ mirostat_tau = mirostat_tau ,
1040
+ mirostat_eta = mirostat_eta ,
1051
1041
)
1052
1042
if stream :
1053
1043
chunks : Iterator [CompletionChunk ] = completion_or_chunks # type: ignore
0 commit comments