diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py
index 343581dce..ddbac7725 100644
--- a/llama_cpp/_internals.py
+++ b/llama_cpp/_internals.py
@@ -806,6 +806,22 @@ def add_mirostat_v2(self, seed: int, tau: float, eta: float):
         sampler = llama_cpp.llama_sampler_init_mirostat_v2(seed, tau, eta)
         self._add_sampler(sampler)
 
+    def add_xtc(self, probability: float, threshold: float, min_keep: int, seed: int):
+        sampler = llama_cpp.llama_sampler_init_xtc(probability, threshold, min_keep, seed)
+        self._add_sampler(sampler)
+
+    def add_dry(self, model: LlamaModel, ctx: LlamaContext, multiplier: float, base: float,
+                allowed_length: int, penalty_last_n: int, seq_breakers: list[str] = []):
+
+        # Convert Python strings to bytes
+        seq_breakers_bytes = [s.encode('utf-8') for s in seq_breakers]
+        # Create array of char*
+        arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)
+        sampler = llama_cpp.llama_sampler_init_dry(model.vocab, ctx.n_ctx(), multiplier, base,
+                                                   allowed_length, penalty_last_n,
+                                                   arr, len(seq_breakers))
+        self._add_sampler(sampler)
+
     def add_grammar(self, model: LlamaModel, grammar: LlamaGrammar):
         sampler = llama_cpp.llama_sampler_init_grammar(
             model.vocab, grammar._grammar.encode("utf-8"), grammar._root.encode("utf-8")
diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py
index 7e9a6af23..d05c421a4 100644
--- a/llama_cpp/llama.py
+++ b/llama_cpp/llama.py
@@ -680,6 +680,13 @@ def _init_sampler(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -747,11 +754,13 @@ def apply_func(token_data_array: llama_cpp.llama_token_data_array_p):
             else:
                 n_probs = 0
                 min_keep = max(1, n_probs)
+                sampler.add_dry(self._model, self._ctx, dry_multiplier, dry_base, dry_allowed_length, dry_range, dry_seq_breakers)
                 sampler.add_top_k(top_k)
                 sampler.add_typical(typical_p, min_keep)
                 sampler.add_top_p(top_p, min_keep)
                 sampler.add_min_p(min_p, min_keep)
                 sampler.add_temp(temp)
+                sampler.add_xtc(xtc_probability, xtc_threshold, min_keep, self._seed)
                 sampler.add_dist(self._seed)
         return sampler
 
@@ -769,6 +778,13 @@ def sample(
         mirostat_mode: int = 0,
         mirostat_eta: float = 0.1,
         mirostat_tau: float = 5.0,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -804,6 +820,13 @@ def sample(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             penalize_nl=penalize_nl,
             logits_processor=logits_processor,
             grammar=grammar,
@@ -833,6 +856,13 @@ def generate(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
@@ -872,6 +902,13 @@ def generate(
                 mirostat_mode=mirostat_mode,
                 mirostat_tau=mirostat_tau,
                 mirostat_eta=mirostat_eta,
+                xtc_probability=xtc_probability,
+                xtc_threshold=xtc_threshold,
+                dry_multiplier=dry_multiplier,
+                dry_allowed_length=dry_allowed_length,
+                dry_base=dry_base,
+                dry_range=dry_range,
+                dry_seq_breakers=dry_seq_breakers,
                 penalize_nl=penalize_nl,
                 logits_processor=logits_processor,
                 grammar=grammar,
@@ -924,6 +961,13 @@ def generate(
                     mirostat_mode=mirostat_mode,
                     mirostat_tau=mirostat_tau,
                     mirostat_eta=mirostat_eta,
+                    xtc_probability=xtc_probability,
+                    xtc_threshold=xtc_threshold,
+                    dry_multiplier=dry_multiplier,
+                    dry_allowed_length=dry_allowed_length,
+                    dry_base=dry_base,
+                    dry_range=dry_range,
+                    dry_seq_breakers=dry_seq_breakers,
                     logits_processor=logits_processor,
                     grammar=grammar,
                     penalize_nl=penalize_nl,
@@ -1140,6 +1184,13 @@ def _create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1328,6 +1379,13 @@ def logit_bias_processor(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             frequency_penalty=frequency_penalty,
             presence_penalty=presence_penalty,
             repeat_penalty=repeat_penalty,
@@ -1760,6 +1818,13 @@ def create_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1823,6 +1888,13 @@ def create_completion(
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
+           xtc_probability=xtc_probability,
+           xtc_threshold=xtc_threshold,
+           dry_multiplier=dry_multiplier,
+           dry_allowed_length=dry_allowed_length,
+           dry_base=dry_base,
+           dry_range=dry_range,
+           dry_seq_breakers=dry_seq_breakers,
            model=model,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
@@ -1857,6 +1929,13 @@ def __call__(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         stopping_criteria: Optional[StoppingCriteriaList] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
@@ -1920,6 +1999,13 @@ def __call__(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             stopping_criteria=stopping_criteria,
             logits_processor=logits_processor,
@@ -1951,6 +2037,13 @@ def create_chat_completion(
         mirostat_mode: int = 0,
         mirostat_tau: float = 5.0,
         mirostat_eta: float = 0.1,
+        xtc_probability: float = 0.0,
+        xtc_threshold: float = 0.1,
+        dry_multiplier: float = 0.0,
+        dry_allowed_length: int = 2,
+        dry_base: float = 1.75,
+        dry_range: int = 0,
+        dry_seq_breakers: list[str] = [],
         model: Optional[str] = None,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
@@ -2024,6 +2117,13 @@ def create_chat_completion(
             mirostat_mode=mirostat_mode,
             mirostat_tau=mirostat_tau,
             mirostat_eta=mirostat_eta,
+            xtc_probability=xtc_probability,
+            xtc_threshold=xtc_threshold,
+            dry_multiplier=dry_multiplier,
+            dry_allowed_length=dry_allowed_length,
+            dry_base=dry_base,
+            dry_range=dry_range,
+            dry_seq_breakers=dry_seq_breakers,
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py
index 710bd83c8..f7949b121 100644
--- a/llama_cpp/llama_cpp.py
+++ b/llama_cpp/llama_cpp.py
@@ -3882,6 +3882,41 @@ def llama_sampler_init_xtc(
 ) -> llama_sampler_p: ...
 
+# LLAMA_API struct llama_sampler * llama_sampler_init_dry(
+#            const struct llama_vocab *  vocab,
+#                             int32_t    context_size,
+#                               float    dry_multiplier,
+#                               float    dry_base,
+#                             int32_t    dry_allowed_length,
+#                             int32_t    dry_penalty_last_n,
+#                        const char **   seq_breakers,
+#                              size_t    num_breakers);
+@ctypes_function(
+    "llama_sampler_init_dry",
+    [
+        llama_vocab_p_ctypes,
+        ctypes.c_int32,
+        ctypes.c_float,
+        ctypes.c_float,
+        ctypes.c_int32,
+        ctypes.c_int32,
+        ctypes.POINTER(ctypes.c_char_p),
+        ctypes.c_size_t
+    ],
+    llama_sampler_p_ctypes,
+)
+def llama_sampler_init_dry(
+    vocab: llama_vocab_p,
+    context_size: int,
+    dry_multiplier: float,
+    dry_base: float,
+    dry_allowed_length: int,
+    dry_penalty_last_n: int,
+    seq_breakers: list[str],
+    num_breakers: int,
+) -> llama_sampler_p:
+    ...
+
 # /// @details Top n sigma sampling as described in academic paper "Top-nσ: Not All Logits Are You Need" https://arxiv.org/pdf/2411.07641
 # LLAMA_API struct llama_sampler * llama_sampler_init_top_n_sigma(float n);
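
The seq_breakers handling in add_dry above is the only non-trivial marshalling in this change: a Python list[str] has to become the const char ** that llama_sampler_init_dry declares via ctypes.POINTER(ctypes.c_char_p). A minimal standalone sketch of that conversion follows; the breaker strings are illustrative only and are not defaults taken from this patch.

import ctypes

# Illustrative breaker strings, not values defined anywhere in this patch.
seq_breakers = ["\n", ":", '"', "*"]

# Encode each Python str to UTF-8 bytes, then pack the byte strings into a
# ctypes array of char*. This matches the layout expected by the
# const char ** seq_breakers parameter declared above as
# ctypes.POINTER(ctypes.c_char_p).
seq_breakers_bytes = [s.encode("utf-8") for s in seq_breakers]
arr = (ctypes.c_char_p * len(seq_breakers_bytes))(*seq_breakers_bytes)

# `arr` is what add_dry passes as seq_breakers; len(seq_breakers) supplies
# num_breakers.
assert len(arr) == len(seq_breakers_bytes)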
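
With the parameters threaded through sample, generate, _create_completion, create_completion, __call__, and create_chat_completion, the new settings are reachable from the high-level API. Below is a hedged usage sketch; the model path, prompt, and parameter values are placeholders for illustration, not recommendations from this change.

from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # placeholder path

output = llm(
    "Write a short story about a lighthouse keeper.",
    max_tokens=256,
    temperature=0.8,
    # XTC parameters added by this change; the default xtc_probability=0.0
    # is intended to leave the sampler effectively inactive.
    xtc_probability=0.5,
    xtc_threshold=0.1,
    # DRY parameters added by this change; the default dry_multiplier=0.0
    # is intended to leave the penalty effectively inactive.
    dry_multiplier=0.8,
    dry_base=1.75,
    dry_allowed_length=2,
    dry_range=1024,
    dry_seq_breakers=["\n", ":", '"', "*"],
)
print(output["choices"][0]["text"])

Note that dry_range is forwarded as the penalty_last_n argument of the underlying sampler (see add_dry above), and the defaults added here (xtc_probability=0.0, dry_multiplier=0.0) are meant to keep sampling behaviour unchanged unless the new samplers are explicitly enabled.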