add top_p sampling in gemma example
PiperOrigin-RevId: 733047305
Flax Team committed Mar 5, 2025
1 parent 0769411 commit ece7238
Showing 2 changed files with 103 additions and 15 deletions.
80 changes: 69 additions & 11 deletions examples/gemma/sampler.py
@@ -22,7 +22,7 @@
from collections.abc import Sequence
import dataclasses

-import chex
+import flax
from flax import nnx
import modules
import sow_lib
@@ -33,6 +33,21 @@
import sentencepiece as spm


def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.ndarray:
"""Sample a token using top-p sampling."""
probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
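# cumsum_probs - probs_sorted is the mass of strictly more probable tokens;
# a token is dropped once that mass already exceeds p, which keeps the
# smallest "nucleus" of tokens whose total probability reaches p. E.g. for
# sorted probs [0.6, 0.25, 0.1, 0.05] and p=0.8, only [0.6, 0.25] survive.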
mask = cumsum_probs - probs_sorted > p
probs_sorted = jnp.where(mask, 0.0, probs_sorted)
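# Renormalize the surviving nucleus so it sums to 1 again.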
probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)

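# jax.random.categorical expects logits, so pass log-probabilities; masked
# entries become log(0) = -inf and are never sampled.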
next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted))

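# Map the sampled rank in the sorted order back to its original vocab id.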
next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1)
next_token = jnp.squeeze(next_token, axis=-1)
return next_token


def _compute_attention_masks(
time_step: jax.Array, seq_len: int, input_mask: jax.Array
) -> jax.Array:
@@ -60,7 +75,7 @@ def _compute_attention_masks(
return ~attention_mask


-@chex.dataclass
+@flax.struct.dataclass
class _SamplingState:
"""Internal sampling state."""

@@ -86,13 +101,22 @@ class _SamplingState:
total_sampling_steps: int

# Fixed-size buffer for accumulating the output logits.
-logits_buffer: jnp.ndarray | None = None  # [B, L, V]
+logits_buffer: jnp.ndarray | None  # [B, L, V]

# List of tokens that are forbidden to be generated.
-forbidden_token_ids: Sequence[int] | None = None
+forbidden_token_ids: Sequence[int] | None

# Intermediate activations from the model if requested.
-intermediates: sow_lib.TransformerIntermediates | None = None
+intermediates: sow_lib.TransformerIntermediates | None

# Random seed for sampling.
seed: jax.Array

# Temperature for top_p sampling.
temperature: float = flax.struct.field(pytree_node=False)

# Top-p sampling threshold.
top_p: float = flax.struct.field(pytree_node=False)
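# (pytree_node=False marks temperature and top_p as static fields: they are
# treated as compile-time constants rather than traced pytree leaves.)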


@dataclasses.dataclass
@@ -161,8 +185,21 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
if sampler_state.forbidden_token_ids:
logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)

-next_token_candidate = jnp.argmax(logits, axis=-1)  # [B, 1]
-next_token_candidate = next_token_candidate[:, 0]  # [B,]
def sample_top_p(logits, key):
probs = jax.nn.softmax(logits[:, -1] / sampler_state.temperature, axis=-1)
next_token = _sample_top_p(probs, sampler_state.top_p, key)
return next_token

def sample_best(logits):
next_token = jnp.argmax(logits, axis=-1)
next_token = next_token[:, 0]
return next_token

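# temperature == 0 keeps the previous greedy (argmax) behavior; a positive
# temperature enables stochastic top-p sampling, folding the step index into
# the base seed so each decoding step gets a distinct key.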
if sampler_state.temperature > 0:
key = jax.random.fold_in(sampler_state.seed, decoding_step)
next_token_candidate = sample_top_p(logits, key)
else:
next_token_candidate = sample_best(logits)

next_token_candidate = jnp.where(
decoding_step < sampler_state.num_input_tokens - 1,
@@ -200,14 +237,20 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
total_sampling_steps=sampler_state.total_sampling_steps,
forbidden_token_ids=sampler_state.forbidden_token_ids,
intermediates=sampler_state.intermediates,
temperature=sampler_state.temperature,
top_p=sampler_state.top_p,
seed=sampler_state.seed,
)

def init_sample_state(
self,
all_input_ids: list[jax.Array],
total_sampling_steps: int,
-include_logits: bool = False,
-forbidden_token_ids: Sequence[int] | None = None,
+include_logits: bool,
+forbidden_token_ids: Sequence[int] | None,
+temperature: float,
+top_p: float,
+seed: jax.Array,
) -> _SamplingState:
"""Initializes the sampling state given input prompts."""
batch_size = len(all_input_ids)
@@ -259,6 +302,9 @@ def init_sample_state(
intermediates=self.transformer.init_intermediates(
batch_size, buffer_size, self.transformer.sow_config
),
temperature=temperature,
top_p=top_p,
seed=seed,
)

def tokenize(self, input_string: str) -> jax.Array:
@@ -281,7 +327,7 @@ def mask_tokens_after_eos_ids(self, token_buffer):
mask = jnp.less_equal(
jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
)
-masked_token_buffer = token_buffer * mask + self.vocab.pad_id()*(1 - mask)
+masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask)

return masked_token_buffer

@@ -310,17 +356,23 @@ def __call__(
echo: bool = False,
return_logits: bool = True,
forbidden_tokens: Sequence[str] | None = None,
temperature: float = 0.0,
top_p: float = 0.95,
seed: jax.Array | None = None,
) -> SamplerOutput:
"""Samples a completion of the input string.
Args:
input_strings: input prompts to feed to the model for sampling.
total_generation_steps: number of generation steps. Will correspond to the
longest prompt in the batch.
echo: whether to return the prompt as part of the output sample.
return_logits: whether to return per-step logits used during generation.
forbidden_tokens: list of tokens that are forbidden to be generated. Each
token must map to a single token id in the vocab.
temperature: temperature for sampling.
top_p: top-p sampling threshold.
seed: random seed for sampling.
Returns:
sampler_output: A SamplerOutput object containing the generated samples.
@@ -339,11 +391,17 @@
all_input_ids = [self.tokenize(x) for x in input_strings]
max_input_length = max(len(input_ids) for input_ids in all_input_ids)
total_sampling_steps = max_input_length + total_generation_steps

if seed is None:
seed = jax.random.PRNGKey(0)
initial_sampling_state = self.init_sample_state(
all_input_ids,
include_logits=return_logits,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
temperature=temperature,
top_p=top_p,
seed=seed,
)

sampling_state = self._compiled_sample_fn(initial_sampling_state)
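For orientation, a minimal usage sketch of the new sampling parameters (the
prompt, constructor arguments, and values below are illustrative, not part of
the commit):

sampler = sampler_lib.Sampler(...)  # constructed as in the rest of the example
output = sampler(
    ['an example prompt'],
    total_generation_steps=10,
    temperature=0.8,  # any temperature > 0 switches from greedy to top-p sampling
    top_p=0.95,
    seed=jax.random.PRNGKey(1234),
)
print(output.text)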
38 changes: 34 additions & 4 deletions examples/gemma/sampler_test.py
@@ -22,6 +22,7 @@
import sampler as sampler_lib
import sow_lib
import transformer as transformer_lib
import jax
import jax.numpy as jnp
import numpy as np

@@ -108,6 +109,25 @@ def test_samples(self):
result = sampler(['input string', 'hello world'], total_generation_steps=10)
self.assertIsNotNone(result)

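# With temperature > 0 the sampler is stochastic, so a high-temperature
# top-p sample is expected to differ from the greedy result above.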
top_p_result = sampler(
['input string', 'hello world'],
total_generation_steps=10,
temperature=9,
top_p=0.95,
)
self.assertIsNotNone(top_p_result)
self.assertNotEqual(result.text, top_p_result.text)

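# A different explicit seed should produce a different sample.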
top_p_result_2 = sampler(
['input string', 'hello world'],
total_generation_steps=10,
temperature=9,
top_p=0.95,
seed=jax.random.PRNGKey(42),
)
self.assertIsNotNone(top_p_result_2)
self.assertNotEqual(top_p_result.text, top_p_result_2.text)

def test_forbidden_tokens(self):
vocab = MockVocab()
transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types
@@ -246,6 +266,11 @@ def test_sampler_init_sample_state(self):
sample_state = sampler.init_sample_state(
all_input_ids,
total_sampling_steps=total_sampling_steps,
include_logits=True,
forbidden_token_ids=None,
temperature=0.0,
top_p=0.95,
seed=jax.random.PRNGKey(0),
)

# Check that the position indices correctly ignore padding
@@ -282,6 +307,11 @@ def test_sampler_mask_tokens_after_eos_ids(self):
sample_state = sampler.init_sample_state(
all_input_ids,
total_sampling_steps=total_sampling_steps,
include_logits=True,
forbidden_token_ids=None,
temperature=0.0,
top_p=0.95,
seed=jax.random.PRNGKey(0),
)

masked_token_buffer = sampler.mask_tokens_after_eos_ids(
@@ -389,8 +419,8 @@ def test_compute_attention_mask(self):
time_step, seq_len, input_mask
)
expected_attn_mask = jnp.array(
-[[0, 0, 1, 1, 1, 0, 0, 0],
-[0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_)
+[[0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_
+)

self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())

@@ -403,8 +433,8 @@ def test_compute_attention_mask(self):
)
print(attn_mask)
expected_attn_mask = jnp.array(
-[[0, 1, 1, 1],
-[0, 1, 0, 1]], dtype=jnp.bool_)
+[[0, 1, 1, 1], [0, 1, 0, 1]], dtype=jnp.bool_
+)

self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())

