From a149b6d7fdc7a7d87a3bcce747c8ae34ea35c5fb Mon Sep 17 00:00:00 2001
From: Tianshu Bao
Date: Thu, 6 Mar 2025 11:39:20 -0800
Subject: [PATCH] add top_p sampling in gemma example

PiperOrigin-RevId: 734226133
---
 examples/gemma/sampler.py      | 80 +++++++++++++++++++++++++++++-----
 examples/gemma/sampler_test.py | 38 ++++++++++++++--
 2 files changed, 103 insertions(+), 15 deletions(-)

diff --git a/examples/gemma/sampler.py b/examples/gemma/sampler.py
index 5efb5c698..adbcb1267 100644
--- a/examples/gemma/sampler.py
+++ b/examples/gemma/sampler.py
@@ -22,7 +22,7 @@
 from collections.abc import Sequence
 import dataclasses
 
-import chex
+import flax
 from flax import nnx
 import modules
 import sow_lib
@@ -33,6 +33,21 @@
 import sentencepiece as spm
 
 
+def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.ndarray:
+  """Sample a token using top-p sampling."""
+  probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
+  cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
+  mask = cumsum_probs - probs_sorted > p
+  probs_sorted = jnp.where(mask, 0.0, probs_sorted)
+  probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)
+
+  next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted))
+
+  next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1)
+  next_token = jnp.squeeze(next_token, axis=-1)
+  return next_token
+
+
 def _compute_attention_masks(
     time_step: jax.Array, seq_len: int, input_mask: jax.Array
 ) -> jax.Array:
@@ -60,7 +75,7 @@
   return ~attention_mask
 
 
-@chex.dataclass
+@flax.struct.dataclass
 class _SamplingState:
   """Internal sampling state."""
 
@@ -86,13 +101,22 @@ class _SamplingState:
   total_sampling_steps: int
 
   # Fixed-size buffer for accumulating the output logits.
-  logits_buffer: jnp.ndarray | None = None  # [B, L, V]
+  logits_buffer: jnp.ndarray | None  # [B, L, V]
 
   # List of tokens that are forbidden to be generated.
-  forbidden_token_ids: Sequence[int] | None = None
+  forbidden_token_ids: Sequence[int] | None
 
   # Intermediate activations from the model if requested.
-  intermediates: sow_lib.TransformerIntermediates | None = None
+  intermediates: sow_lib.TransformerIntermediates | None
+
+  # Random seed for sampling.
+  seed: jax.Array
+
+  # Temperature for top-p sampling.
+  temperature: float = flax.struct.field(pytree_node=False)
+
+  # Top-p sampling threshold.
+  top_p: float = flax.struct.field(pytree_node=False)
 
 
 @dataclasses.dataclass
@@ -161,8 +185,21 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
     if sampler_state.forbidden_token_ids:
       logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)
 
-    next_token_candidate = jnp.argmax(logits, axis=-1)  # [B, 1]
-    next_token_candidate = next_token_candidate[:, 0]  # [B,]
+    def sample_top_p(logits, key):
+      probs = jax.nn.softmax(logits[:, -1] / sampler_state.temperature, axis=-1)
+      next_token = _sample_top_p(probs, sampler_state.top_p, key)
+      return next_token
+
+    def sample_best(logits):
+      next_token = jnp.argmax(logits, axis=-1)
+      next_token = next_token[:, 0]
+      return next_token
+
+    if sampler_state.temperature > 0:
+      key = jax.random.fold_in(sampler_state.seed, decoding_step)
+      next_token_candidate = sample_top_p(logits, key)
+    else:
+      next_token_candidate = sample_best(logits)
 
     next_token_candidate = jnp.where(
         decoding_step < sampler_state.num_input_tokens - 1,
@@ -200,14 +237,20 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
         total_sampling_steps=sampler_state.total_sampling_steps,
         forbidden_token_ids=sampler_state.forbidden_token_ids,
         intermediates=sampler_state.intermediates,
+        temperature=sampler_state.temperature,
+        top_p=sampler_state.top_p,
+        seed=sampler_state.seed,
     )
 
   def init_sample_state(
       self,
       all_input_ids: list[jax.Array],
       total_sampling_steps: int,
-      include_logits: bool = False,
-      forbidden_token_ids: Sequence[int] | None = None,
+      include_logits: bool,
+      forbidden_token_ids: Sequence[int] | None,
+      temperature: float,
+      top_p: float,
+      seed: jax.Array,
   ) -> _SamplingState:
     """Initializes the sampling state given input prompts."""
     batch_size = len(all_input_ids)
@@ -259,6 +302,9 @@ def init_sample_state(
         intermediates=self.transformer.init_intermediates(
             batch_size, buffer_size, self.transformer.sow_config
         ),
+        temperature=temperature,
+        top_p=top_p,
+        seed=seed,
     )
 
   def tokenize(self, input_string: str) -> jax.Array:
@@ -281,7 +327,7 @@ def mask_tokens_after_eos_ids(self, token_buffer):
     mask = jnp.less_equal(
         jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
     )
-    masked_token_buffer = token_buffer * mask + self.vocab.pad_id()*(1 - mask)
+    masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask)
 
     return masked_token_buffer
 
@@ -310,6 +356,9 @@ def __call__(
       echo: bool = False,
      return_logits: bool = True,
       forbidden_tokens: Sequence[str] | None = None,
+      temperature: float = 0.0,
+      top_p: float = 0.95,
+      seed: jax.Array | None = None,
   ) -> SamplerOutput:
     """Samples a completion of the input string.
 
@@ -317,10 +366,13 @@ def __call__(
       input_strings: input prompts to feed to the model for sampling.
       total_generation_steps: number of generation steps; generation begins
         after the longest prompt in the batch.
       echo: whether to return the prompt as part of the output sample.
       return_logits: whether to return per-step logits used during generation.
       forbidden_tokens: list of tokens that are forbidden to be generated.
        Each token must map to a single token id in the vocab.
+      temperature: sampling temperature; 0 selects greedy (argmax) decoding.
+      top_p: top-p (nucleus) sampling threshold; used when temperature > 0.
+      seed: PRNG key for sampling; defaults to jax.random.PRNGKey(0) if None.
 
     Returns:
       sampler_output: A SamplerOutput object containing the generated samples.
@@ -339,11 +391,17 @@ def __call__( all_input_ids = [self.tokenize(x) for x in input_strings] max_input_length = max(len(input_ids) for input_ids in all_input_ids) total_sampling_steps = max_input_length + total_generation_steps + + if seed is None: + seed = jax.random.PRNGKey(0) initial_sampling_state = self.init_sample_state( all_input_ids, include_logits=return_logits, total_sampling_steps=total_sampling_steps, forbidden_token_ids=forbidden_token_ids, + temperature=temperature, + top_p=top_p, + seed=seed, ) sampling_state = self._compiled_sample_fn(initial_sampling_state) diff --git a/examples/gemma/sampler_test.py b/examples/gemma/sampler_test.py index 7a7dcc95b..e3de0eca1 100644 --- a/examples/gemma/sampler_test.py +++ b/examples/gemma/sampler_test.py @@ -22,6 +22,7 @@ import sampler as sampler_lib import sow_lib import transformer as transformer_lib +import jax import jax.numpy as jnp import numpy as np @@ -108,6 +109,25 @@ def test_samples(self): result = sampler(['input string', 'hello world'], total_generation_steps=10) self.assertIsNotNone(result) + top_p_result = sampler( + ['input string', 'hello world'], + total_generation_steps=10, + temperature=9, + top_p=0.95, + ) + self.assertIsNotNone(top_p_result) + self.assertNotEqual(result.text, top_p_result.text) + + top_p_result_2 = sampler( + ['input string', 'hello world'], + total_generation_steps=10, + temperature=9, + top_p=0.95, + seed=jax.random.PRNGKey(42), + ) + self.assertIsNotNone(top_p_result_2) + self.assertNotEqual(top_p_result.text, top_p_result_2.text) + def test_forbidden_tokens(self): vocab = MockVocab() transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types @@ -246,6 +266,11 @@ def test_sampler_init_sample_state(self): sample_state = sampler.init_sample_state( all_input_ids, total_sampling_steps=total_sampling_steps, + include_logits=True, + forbidden_token_ids=None, + temperature=0.0, + top_p=0.95, + seed=jax.random.PRNGKey(0), ) # Check that the position indices correctly ignore padding @@ -282,6 +307,11 @@ def test_sampler_mask_tokens_after_eos_ids(self): sample_state = sampler.init_sample_state( all_input_ids, total_sampling_steps=total_sampling_steps, + include_logits=True, + forbidden_token_ids=None, + temperature=0.0, + top_p=0.95, + seed=jax.random.PRNGKey(0), ) masked_token_buffer = sampler.mask_tokens_after_eos_ids( @@ -389,8 +419,8 @@ def test_compute_attention_mask(self): time_step, seq_len, input_mask ) expected_attn_mask = jnp.array( - [[0, 0, 1, 1, 1, 0, 0, 0], - [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_) + [[0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_ + ) self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all()) @@ -403,8 +433,8 @@ def test_compute_attention_mask(self): ) print(attn_mask) expected_attn_mask = jnp.array( - [[0, 1, 1, 1], - [0, 1, 0, 1]], dtype=jnp.bool_) + [[0, 1, 1, 1], [0, 1, 0, 1]], dtype=jnp.bool_ + ) self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())
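
Reviewer note: the nucleus-sampling step added in `_sample_top_p` can be
exercised on its own. The sketch below mirrors the patch's logic against a
toy distribution; the local name `sample_top_p`, the `probs` values, and the
`p=0.7` threshold are illustrative only, and only public jax/jnp APIs are
used.

    import jax
    import jax.numpy as jnp

    def sample_top_p(probs, p, key):
      # top_k with k equal to the vocab size is a full descending sort that
      # also returns the original vocabulary indices.
      probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
      cumsum = jnp.cumsum(probs_sorted, axis=-1)
      # Drop every token whose preceding cumulative mass already exceeds p,
      # i.e. keep the smallest prefix of sorted tokens covering at least p.
      probs_sorted = jnp.where(cumsum - probs_sorted > p, 0.0, probs_sorted)
      probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)
      # log(0) = -inf, so masked-out tokens can never be drawn.
      choice = jax.random.categorical(key, logits=jnp.log(probs_sorted))
      return jnp.take_along_axis(indices, choice[..., None], axis=-1)[..., 0]

    probs = jnp.array([[0.5, 0.3, 0.15, 0.05]])  # [B=1, V=4], toy values
    token = sample_top_p(probs, p=0.7, key=jax.random.PRNGKey(0))
    # Preceding masses are [0, 0.5, 0.8, 0.95]; with p=0.7 only the first
    # two tokens survive, so `token` is always 0 or 1.
    print(token)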
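
Usage note: a hedged sketch of driving the extended `__call__` signature;
the prompt strings and parameter values are arbitrary, and `sampler` stands
for any `sampler_lib.Sampler` constructed as in the existing example (the
transformer, vocab, and cache setup are elided here):

    # Greedy decoding (unchanged default): temperature=0.0 ignores top_p.
    greedy = sampler(['hello world'], total_generation_steps=10)

    # Any temperature > 0 switches to top-p (nucleus) sampling.
    sampled = sampler(
        ['hello world'],
        total_generation_steps=10,
        temperature=0.9,
        top_p=0.95,
        seed=jax.random.PRNGKey(42),  # omitted -> jax.random.PRNGKey(0)
    )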