add top_p sampling in gemma example #4591

Merged (1 commit) on Mar 6, 2025
examples/gemma/sampler.py: 69 additions & 11 deletions
@@ -22,7 +22,7 @@
from collections.abc import Sequence
import dataclasses

import chex
import flax
from flax import nnx
import modules
import sow_lib
@@ -33,6 +33,21 @@
import sentencepiece as spm


def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.ndarray:
"""Sample a token using top-p sampling."""
probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
mask = cumsum_probs - probs_sorted > p
probs_sorted = jnp.where(mask, 0.0, probs_sorted)
probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)

next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted))

next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1)
next_token = jnp.squeeze(next_token, axis=-1)
return next_token
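
For intuition, here is a minimal standalone sketch (not part of the diff) of the same nucleus-filtering arithmetic on a toy, already-sorted distribution with p = 0.6; `cumsum - probs > p` masks every token whose preceding cumulative mass already exceeds p:

    import jax
    import jax.numpy as jnp

    probs = jnp.array([0.5, 0.2, 0.2, 0.1])  # sorted descending, sums to 1
    cumsum = jnp.cumsum(probs)                # [0.5, 0.7, 0.9, 1.0]
    mask = cumsum - probs > 0.6               # [False, False, True, True]
    kept = jnp.where(mask, 0.0, probs)        # [0.5, 0.2, 0.0, 0.0]
    kept = kept / kept.sum()                  # renormalize -> [0.714, 0.286, 0.0, 0.0]
    # log(0) = -inf, which categorical treats as probability zero.
    token = jax.random.categorical(jax.random.PRNGKey(0), logits=jnp.log(kept))

Only the smallest prefix of tokens whose cumulative probability reaches p stays in play; the helper above additionally maps the sampled index back through `indices`, since the incoming `probs` are not pre-sorted.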


def _compute_attention_masks(
time_step: jax.Array, seq_len: int, input_mask: jax.Array
) -> jax.Array:
@@ -60,7 +75,7 @@ def _compute_attention_masks(
return ~attention_mask


@chex.dataclass
@flax.struct.dataclass
class _SamplingState:
"""Internal sampling state."""

@@ -86,13 +101,22 @@ class _SamplingState:
total_sampling_steps: int

# Fixed-size buffer for accumulating the output logits.
logits_buffer: jnp.ndarray | None = None # [B, L, V]
logits_buffer: jnp.ndarray | None # [B, L, V]

# List of tokens that are forbidden to be generated.
forbidden_token_ids: Sequence[int] | None = None
forbidden_token_ids: Sequence[int] | None

# Intermediate activations from the model if requested.
intermediates: sow_lib.TransformerIntermediates | None = None
intermediates: sow_lib.TransformerIntermediates | None

# Random seed for sampling.
seed: jax.Array

# Temperature for top-p sampling.
temperature: float = flax.struct.field(pytree_node=False)

# Top-p sampling threshold.
top_p: float = flax.struct.field(pytree_node=False)
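
Note: `temperature` and `top_p` are declared with `flax.struct.field(pytree_node=False)`, so they travel as static auxiliary data rather than traced pytree leaves. That is what makes the plain Python branch `if sampler_state.temperature > 0:` in `_sample_step` legal under `jax.jit`; the trade-off is that sampling with a new temperature or top_p value triggers a recompile.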


@dataclasses.dataclass
@@ -161,8 +185,21 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
if sampler_state.forbidden_token_ids:
logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)

next_token_candidate = jnp.argmax(logits, axis=-1) # [B, 1]
next_token_candidate = next_token_candidate[:, 0] # [B,]
def sample_top_p(logits, key):
probs = jax.nn.softmax(logits[:, -1] / sampler_state.temperature, axis=-1)
next_token = _sample_top_p(probs, sampler_state.top_p, key)
return next_token

def sample_best(logits):
next_token = jnp.argmax(logits, axis=-1)
next_token = next_token[:, 0]
return next_token

if sampler_state.temperature > 0:
key = jax.random.fold_in(sampler_state.seed, decoding_step)
next_token_candidate = sample_top_p(logits, key)
else:
next_token_candidate = sample_best(logits)

next_token_candidate = jnp.where(
decoding_step < sampler_state.num_input_tokens - 1,
@@ -200,14 +237,20 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
total_sampling_steps=sampler_state.total_sampling_steps,
forbidden_token_ids=sampler_state.forbidden_token_ids,
intermediates=sampler_state.intermediates,
temperature=sampler_state.temperature,
top_p=sampler_state.top_p,
seed=sampler_state.seed,
)

def init_sample_state(
self,
all_input_ids: list[jax.Array],
total_sampling_steps: int,
include_logits: bool = False,
forbidden_token_ids: Sequence[int] | None = None,
include_logits: bool,
forbidden_token_ids: Sequence[int] | None,
temperature: float,
top_p: float,
seed: jax.Array,
) -> _SamplingState:
"""Initializes the sampling state given input prompts."""
batch_size = len(all_input_ids)
@@ -259,6 +302,9 @@ def init_sample_state(
intermediates=self.transformer.init_intermediates(
batch_size, buffer_size, self.transformer.sow_config
),
temperature=temperature,
top_p=top_p,
seed=seed,
)

def tokenize(self, input_string: str) -> jax.Array:
@@ -281,7 +327,7 @@ def mask_tokens_after_eos_ids(self, token_buffer):
mask = jnp.less_equal(
jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
)
masked_token_buffer = token_buffer * mask + self.vocab.pad_id()*(1 - mask)
masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask)

return masked_token_buffer

@@ -310,17 +356,23 @@ def __call__(
echo: bool = False,
return_logits: bool = True,
forbidden_tokens: Sequence[str] | None = None,
temperature: float = 0.0,
top_p: float = 0.95,
seed: jax.Array | None = None,
) -> SamplerOutput:
"""Samples a completion of the input string.
Args:
input_strings: input prompts to feed to the model for sampling.
total_generation_steps: number of generation steps. Will correspond to the
longest prompt in the batch.
echo: whether to return the prompt as part of the output sample.
return_logits: whether to return per-step logits used during generation.
forbidden_tokens: list of tokens that are forbidden to be generated. Each
token must map to a single token id in the vocab.
temperature: temperature for sampling.
top_p: top-p sampling threshold.
seed: random seed for sampling.
Returns:
sampler_output: A SamplerOutput object containing the generated samples.
@@ -339,11 +391,17 @@
all_input_ids = [self.tokenize(x) for x in input_strings]
max_input_length = max(len(input_ids) for input_ids in all_input_ids)
total_sampling_steps = max_input_length + total_generation_steps

if seed is None:
seed = jax.random.PRNGKey(0)
initial_sampling_state = self.init_sample_state(
all_input_ids,
include_logits=return_logits,
total_sampling_steps=total_sampling_steps,
forbidden_token_ids=forbidden_token_ids,
temperature=temperature,
top_p=top_p,
seed=seed,
)

sampling_state = self._compiled_sample_fn(initial_sampling_state)
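Taken together, the public `__call__` API stays greedy by default (temperature=0.0) and switches to nucleus sampling whenever a positive temperature is passed. A hypothetical usage sketch, with transformer/vocab construction elided and all argument names taken from the diff above:

    # Greedy decoding, identical to the old behavior.
    out = sampler(['a tale of'], total_generation_steps=30)

    # Nucleus sampling: a positive temperature enables the top-p path, and an
    # explicit PRNG key makes the draw reproducible (the default is PRNGKey(0)).
    out = sampler(
        ['a tale of'],
        total_generation_steps=30,
        temperature=0.8,
        top_p=0.95,
        seed=jax.random.PRNGKey(1234),
    )
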
examples/gemma/sampler_test.py: 34 additions & 4 deletions
@@ -22,6 +22,7 @@
import sampler as sampler_lib
import sow_lib
import transformer as transformer_lib
import jax
import jax.numpy as jnp
import numpy as np

@@ -108,6 +109,25 @@ def test_samples(self):
result = sampler(['input string', 'hello world'], total_generation_steps=10)
self.assertIsNotNone(result)

top_p_result = sampler(
['input string', 'hello world'],
total_generation_steps=10,
temperature=9,
top_p=0.95,
)
self.assertIsNotNone(top_p_result)
self.assertNotEqual(result.text, top_p_result.text)

top_p_result_2 = sampler(
['input string', 'hello world'],
total_generation_steps=10,
temperature=9,
top_p=0.95,
seed=jax.random.PRNGKey(42),
)
self.assertIsNotNone(top_p_result_2)
self.assertNotEqual(top_p_result.text, top_p_result_2.text)
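
(The extreme temperature=9 flattens the next-token distribution so the stochastic path is very unlikely to reproduce the greedy output; the second call passes jax.random.PRNGKey(42) to check that a seed other than the default PRNGKey(0) changes the generated text.)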

def test_forbidden_tokens(self):
vocab = MockVocab()
transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types
@@ -246,6 +266,11 @@ def test_sampler_init_sample_state(self):
sample_state = sampler.init_sample_state(
all_input_ids,
total_sampling_steps=total_sampling_steps,
include_logits=True,
forbidden_token_ids=None,
temperature=0.0,
top_p=0.95,
seed=jax.random.PRNGKey(0),
)

# Check that the position indices correctly ignore padding
@@ -282,6 +307,11 @@ def test_sampler_mask_tokens_after_eos_ids(self):
sample_state = sampler.init_sample_state(
all_input_ids,
total_sampling_steps=total_sampling_steps,
include_logits=True,
forbidden_token_ids=None,
temperature=0.0,
top_p=0.95,
seed=jax.random.PRNGKey(0),
)

masked_token_buffer = sampler.mask_tokens_after_eos_ids(
@@ -389,8 +419,8 @@ def test_compute_attention_mask(self):
time_step, seq_len, input_mask
)
expected_attn_mask = jnp.array(
[[0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_)
[[0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_
)

self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())

)
print(attn_mask)
expected_attn_mask = jnp.array(
[[0, 1, 1, 1],
[0, 1, 0, 1]], dtype=jnp.bool_)
[[0, 1, 1, 1], [0, 1, 0, 1]], dtype=jnp.bool_
)

self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())
