Skip to content

Commit ece7238

Browse files
author
Flax Team
committed
add top_p sampling in gemma example
PiperOrigin-RevId: 733047305
1 parent 0769411 commit ece7238

File tree

2 files changed

+103
-15
lines changed

2 files changed

+103
-15
lines changed

examples/gemma/sampler.py

Lines changed: 69 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections.abc import Sequence
2323
import dataclasses
2424

25-
import chex
25+
import flax
2626
from flax import nnx
2727
import modules
2828
import sow_lib
@@ -33,6 +33,21 @@
3333
import sentencepiece as spm
3434

3535

36+
def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.ndarray:
37+
"""Sample a token using top-p sampling."""
38+
probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
39+
cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
40+
mask = cumsum_probs - probs_sorted > p
41+
probs_sorted = jnp.where(mask, 0.0, probs_sorted)
42+
probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)
43+
44+
next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted))
45+
46+
next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1)
47+
next_token = jnp.squeeze(next_token, axis=-1)
48+
return next_token
49+
50+
3651
def _compute_attention_masks(
3752
time_step: jax.Array, seq_len: int, input_mask: jax.Array
3853
) -> jax.Array:
@@ -60,7 +75,7 @@ def _compute_attention_masks(
6075
return ~attention_mask
6176

6277

63-
@chex.dataclass
78+
@flax.struct.dataclass
6479
class _SamplingState:
6580
"""Internal sampling state."""
6681

@@ -86,13 +101,22 @@ class _SamplingState:
86101
total_sampling_steps: int
87102

88103
# Fixed-size buffer for accumulating the output logits.
89-
logits_buffer: jnp.ndarray | None = None # [B, L, V]
104+
logits_buffer: jnp.ndarray | None # [B, L, V]
90105

91106
# List of tokens that are forbidden to be generated.
92-
forbidden_token_ids: Sequence[int] | None = None
107+
forbidden_token_ids: Sequence[int] | None
93108

94109
# Intermediate activations from the model if requested.
95-
intermediates: sow_lib.TransformerIntermediates | None = None
110+
intermediates: sow_lib.TransformerIntermediates | None
111+
112+
# Random seed for sampling.
113+
seed: jax.Array
114+
115+
# Sampling temperature; top-p sampling is used when > 0, else greedy argmax.
116+
temperature: float = flax.struct.field(pytree_node=False)
117+
118+
# Top-p sampling threshold.
119+
top_p: float = flax.struct.field(pytree_node=False)
96120

97121

98122
@dataclasses.dataclass
@@ -161,8 +185,21 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
161185
if sampler_state.forbidden_token_ids:
162186
logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)
163187

164-
next_token_candidate = jnp.argmax(logits, axis=-1) # [B, 1]
165-
next_token_candidate = next_token_candidate[:, 0] # [B,]
188+
def sample_top_p(logits, key):
189+
probs = jax.nn.softmax(logits[:, -1] / sampler_state.temperature, axis=-1)
190+
next_token = _sample_top_p(probs, sampler_state.top_p, key)
191+
return next_token
192+
193+
def sample_best(logits):
194+
next_token = jnp.argmax(logits, axis=-1)
195+
next_token = next_token[:, 0]
196+
return next_token
197+
198+
if sampler_state.temperature > 0:
199+
key = jax.random.fold_in(sampler_state.seed, decoding_step)
200+
next_token_candidate = sample_top_p(logits, key)
201+
else:
202+
next_token_candidate = sample_best(logits)
166203

167204
next_token_candidate = jnp.where(
168205
decoding_step < sampler_state.num_input_tokens - 1,
@@ -200,14 +237,20 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
200237
total_sampling_steps=sampler_state.total_sampling_steps,
201238
forbidden_token_ids=sampler_state.forbidden_token_ids,
202239
intermediates=sampler_state.intermediates,
240+
temperature=sampler_state.temperature,
241+
top_p=sampler_state.top_p,
242+
seed=sampler_state.seed,
203243
)
204244

205245
def init_sample_state(
206246
self,
207247
all_input_ids: list[jax.Array],
208248
total_sampling_steps: int,
209-
include_logits: bool = False,
210-
forbidden_token_ids: Sequence[int] | None = None,
249+
include_logits: bool,
250+
forbidden_token_ids: Sequence[int] | None,
251+
temperature: float,
252+
top_p: float,
253+
seed: jax.Array,
211254
) -> _SamplingState:
212255
"""Initializes the sampling state given input prompts."""
213256
batch_size = len(all_input_ids)
@@ -259,6 +302,9 @@ def init_sample_state(
259302
intermediates=self.transformer.init_intermediates(
260303
batch_size, buffer_size, self.transformer.sow_config
261304
),
305+
temperature=temperature,
306+
top_p=top_p,
307+
seed=seed,
262308
)
263309

264310
def tokenize(self, input_string: str) -> jax.Array:
@@ -281,7 +327,7 @@ def mask_tokens_after_eos_ids(self, token_buffer):
281327
mask = jnp.less_equal(
282328
jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
283329
)
284-
masked_token_buffer = token_buffer * mask + self.vocab.pad_id()*(1 - mask)
330+
masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask)
285331

286332
return masked_token_buffer
287333

@@ -310,17 +356,23 @@ def __call__(
310356
echo: bool = False,
311357
return_logits: bool = True,
312358
forbidden_tokens: Sequence[str] | None = None,
359+
temperature: float = 0.0,
360+
top_p: float = 0.95,
361+
seed: jax.Array | None = None,
313362
) -> SamplerOutput:
314363
"""Samples a completion of the input string.
315364
316365
Args:
317366
input_strings: input prompts to feed to the model for sampling.
318367
total_generation_steps: number of generation steps. will correspond to the
319368
longest prompt in the batch.
320-
echo: whether to return the prompt as part of the output sample.
369+
echo: whether to return the prompt as part of the output sample.
321370
return_logits: whether to return per-step logits used during generation.
322371
forbidden_tokens: list of tokens that are forbidden to be generated. Each
323372
token must map to a single token id in the vocab.
373+
temperature: softmax temperature; greedy argmax decoding is used if <= 0.
374+
top_p: top-p sampling threshold.
375+
seed: PRNG key for sampling; defaults to jax.random.PRNGKey(0) if None.
324376
325377
Returns:
326378
sampler_output: A SamplerOutput object containing the generated samples.
@@ -339,11 +391,17 @@ def __call__(
339391
all_input_ids = [self.tokenize(x) for x in input_strings]
340392
max_input_length = max(len(input_ids) for input_ids in all_input_ids)
341393
total_sampling_steps = max_input_length + total_generation_steps
394+
395+
if seed is None:
396+
seed = jax.random.PRNGKey(0)
342397
initial_sampling_state = self.init_sample_state(
343398
all_input_ids,
344399
include_logits=return_logits,
345400
total_sampling_steps=total_sampling_steps,
346401
forbidden_token_ids=forbidden_token_ids,
402+
temperature=temperature,
403+
top_p=top_p,
404+
seed=seed,
347405
)
348406

349407
sampling_state = self._compiled_sample_fn(initial_sampling_state)

examples/gemma/sampler_test.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import sampler as sampler_lib
2323
import sow_lib
2424
import transformer as transformer_lib
25+
import jax
2526
import jax.numpy as jnp
2627
import numpy as np
2728

@@ -108,6 +109,25 @@ def test_samples(self):
108109
result = sampler(['input string', 'hello world'], total_generation_steps=10)
109110
self.assertIsNotNone(result)
110111

112+
top_p_result = sampler(
113+
['input string', 'hello world'],
114+
total_generation_steps=10,
115+
temperature=9,
116+
top_p=0.95,
117+
)
118+
self.assertIsNotNone(top_p_result)
119+
self.assertNotEqual(result.text, top_p_result.text)
120+
121+
top_p_result_2 = sampler(
122+
['input string', 'hello world'],
123+
total_generation_steps=10,
124+
temperature=9,
125+
top_p=0.95,
126+
seed=jax.random.PRNGKey(42),
127+
)
128+
self.assertIsNotNone(top_p_result_2)
129+
self.assertNotEqual(top_p_result.text, top_p_result_2.text)
130+
111131
def test_forbidden_tokens(self):
112132
vocab = MockVocab()
113133
transformer_config = transformer_lib.TransformerConfig( # pytype: disable=wrong-arg-types
@@ -246,6 +266,11 @@ def test_sampler_init_sample_state(self):
246266
sample_state = sampler.init_sample_state(
247267
all_input_ids,
248268
total_sampling_steps=total_sampling_steps,
269+
include_logits=True,
270+
forbidden_token_ids=None,
271+
temperature=0.0,
272+
top_p=0.95,
273+
seed=jax.random.PRNGKey(0),
249274
)
250275

251276
# Check that the position indices correctly ignore padding
@@ -282,6 +307,11 @@ def test_sampler_mask_tokens_after_eos_ids(self):
282307
sample_state = sampler.init_sample_state(
283308
all_input_ids,
284309
total_sampling_steps=total_sampling_steps,
310+
include_logits=True,
311+
forbidden_token_ids=None,
312+
temperature=0.0,
313+
top_p=0.95,
314+
seed=jax.random.PRNGKey(0),
285315
)
286316

287317
masked_token_buffer = sampler.mask_tokens_after_eos_ids(
@@ -389,8 +419,8 @@ def test_compute_attention_mask(self):
389419
time_step, seq_len, input_mask
390420
)
391421
expected_attn_mask = jnp.array(
392-
[[0, 0, 1, 1, 1, 0, 0, 0],
393-
[0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_)
422+
[[0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0]], dtype=jnp.bool_
423+
)
394424

395425
self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())
396426

@@ -403,8 +433,8 @@ def test_compute_attention_mask(self):
403433
)
404434
print(attn_mask)
405435
expected_attn_mask = jnp.array(
406-
[[0, 1, 1, 1],
407-
[0, 1, 0, 1]], dtype=jnp.bool_)
436+
[[0, 1, 1, 1], [0, 1, 0, 1]], dtype=jnp.bool_
437+
)
408438

409439
self.assertTrue((attn_mask.squeeze(1) == expected_attn_mask).all())
410440

0 commit comments

Comments
 (0)