Skip to content

Commit 433a2e3

Browse files
committed
Add extra logits_processor and stopping_criteria
1 parent 30bf8ec commit 433a2e3

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

llama_cpp/llama.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -677,6 +677,8 @@ def _create_completion(
677677
mirostat_tau: float = 5.0,
678678
mirostat_eta: float = 0.1,
679679
model: Optional[str] = None,
680+
stopping_criteria: Optional[StoppingCriteriaList] = None,
681+
logits_processor: Optional[LogitsProcessorList] = None,
680682
) -> Union[Iterator[Completion], Iterator[CompletionChunk]]:
681683
assert self.ctx is not None
682684

@@ -739,6 +741,8 @@ def _create_completion(
739741
frequency_penalty=frequency_penalty,
740742
presence_penalty=presence_penalty,
741743
repeat_penalty=repeat_penalty,
744+
stopping_criteria=stopping_criteria,
745+
logits_processor=logits_processor,
742746
):
743747
if token == self._token_eos:
744748
text = self.detokenize(completion_tokens)
@@ -848,6 +852,11 @@ def _create_completion(
848852
finish_reason = "length"
849853
break
850854

855+
if stopping_criteria is not None and stopping_criteria(
856+
list(self.eval_tokens), self.eval_logits[-1]
857+
):
858+
finish_reason = "stop"
859+
851860
if self.verbose:
852861
llama_cpp.llama_print_timings(self.ctx)
853862

@@ -1049,6 +1058,8 @@ def create_completion(
10491058
mirostat_tau: float = 5.0,
10501059
mirostat_eta: float = 0.1,
10511060
model: Optional[str] = None,
1061+
stopping_criteria: Optional[StoppingCriteriaList] = None,
1062+
logits_processor: Optional[LogitsProcessorList] = None,
10521063
) -> Union[Completion, Iterator[CompletionChunk]]:
10531064
"""Generate text from a prompt.
10541065
@@ -1091,6 +1102,8 @@ def create_completion(
10911102
mirostat_tau=mirostat_tau,
10921103
mirostat_eta=mirostat_eta,
10931104
model=model,
1105+
stopping_criteria=stopping_criteria,
1106+
logits_processor=logits_processor,
10941107
)
10951108
if stream:
10961109
chunks: Iterator[CompletionChunk] = completion_or_chunks
@@ -1118,6 +1131,8 @@ def __call__(
11181131
mirostat_tau: float = 5.0,
11191132
mirostat_eta: float = 0.1,
11201133
model: Optional[str] = None,
1134+
stopping_criteria: Optional[StoppingCriteriaList] = None,
1135+
logits_processor: Optional[LogitsProcessorList] = None,
11211136
) -> Union[Completion, Iterator[CompletionChunk]]:
11221137
"""Generate text from a prompt.
11231138
@@ -1160,6 +1175,8 @@ def __call__(
11601175
mirostat_tau=mirostat_tau,
11611176
mirostat_eta=mirostat_eta,
11621177
model=model,
1178+
stopping_criteria=stopping_criteria,
1179+
logits_processor=logits_processor,
11631180
)
11641181

11651182
def _convert_text_completion_to_chat(

0 commit comments

Comments (0)