@@ -22,7 +22,7 @@
 from collections.abc import Sequence
 import dataclasses

-import chex
+import flax
 from flax import nnx
 import modules
 import sow_lib
@@ -33,6 +33,21 @@
 import sentencepiece as spm


+def _sample_top_p(probs: jnp.ndarray, p: float, key: jax.Array) -> jnp.ndarray:
+  """Sample a token using top-p sampling."""
+  probs_sorted, indices = jax.lax.top_k(probs, k=probs.shape[-1])
+  cumsum_probs = jnp.cumsum(probs_sorted, axis=-1)
+  mask = cumsum_probs - probs_sorted > p
+  probs_sorted = jnp.where(mask, 0.0, probs_sorted)
+  probs_sorted /= jnp.sum(probs_sorted, axis=-1, keepdims=True)
+
+  next_token = jax.random.categorical(key, logits=jnp.log(probs_sorted))
+
+  next_token = jnp.take_along_axis(indices, next_token[..., None], axis=-1)
+  next_token = jnp.squeeze(next_token, axis=-1)
+  return next_token
+
+
 def _compute_attention_masks(
     time_step: jax.Array, seq_len: int, input_mask: jax.Array
 ) -> jax.Array:
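The new `_sample_top_p` helper implements nucleus sampling: it sorts the distribution, zeroes out every token whose cumulative probability already exceeds `p` before it is reached, renormalizes, and samples within the remaining nucleus. A minimal, self-contained sketch of that filtering step on a toy distribution (the probabilities and the `p=0.7` cutoff are illustrative, not values from the diff):

```python
import jax
import jax.numpy as jnp

# Toy distribution over a 5-token vocabulary, already sorted for clarity.
probs = jnp.array([0.5, 0.3, 0.1, 0.05, 0.05])
p = 0.7

cumsum = jnp.cumsum(probs)            # [0.5, 0.8, 0.9, 0.95, 1.0]
mask = cumsum - probs > p             # keeps the first token that crosses p
nucleus = jnp.where(mask, 0.0, probs)
nucleus /= nucleus.sum()              # [0.625, 0.375, 0.0, 0.0, 0.0]

# Only the first two tokens can be drawn; log(0) = -inf removes the rest.
next_token = jax.random.categorical(jax.random.PRNGKey(0), jnp.log(nucleus))
print(int(next_token))  # 0 or 1
```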
@@ -60,7 +75,7 @@ def _compute_attention_masks(
   return ~attention_mask


-@chex.dataclass
+@flax.struct.dataclass
 class _SamplingState:
   """Internal sampling state."""

@@ -86,13 +101,22 @@ class _SamplingState:
   total_sampling_steps: int

   # Fixed-size buffer for accumulating the output logits.
-  logits_buffer: jnp.ndarray | None = None  # [B, L, V]
+  logits_buffer: jnp.ndarray | None  # [B, L, V]

   # List of tokens that are forbidden to be generated.
-  forbidden_token_ids: Sequence[int] | None = None
+  forbidden_token_ids: Sequence[int] | None

   # Intermediate activations from the model if requested.
-  intermediates: sow_lib.TransformerIntermediates | None = None
+  intermediates: sow_lib.TransformerIntermediates | None
+
+  # Random seed for sampling.
+  seed: jax.Array
+
+  # Temperature for top-p sampling.
+  temperature: float = flax.struct.field(pytree_node=False)
+
+  # Top-p sampling threshold.
+  top_p: float = flax.struct.field(pytree_node=False)


 @dataclasses.dataclass
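Switching `_SamplingState` from `@chex.dataclass` to `@flax.struct.dataclass` keeps the state a JAX pytree while letting individual fields opt out of tracing: `flax.struct.field(pytree_node=False)` marks `temperature` and `top_p` as static metadata rather than array leaves, so they become part of the compiled function's signature instead of traced values. A small sketch of that distinction with a toy struct (not the actual `_SamplingState`):

```python
import jax
import jax.numpy as jnp
import flax

@flax.struct.dataclass
class ToyState:
  tokens: jax.Array                                          # traced pytree leaf
  temperature: float = flax.struct.field(pytree_node=False)  # static metadata

state = ToyState(tokens=jnp.zeros((2, 4), dtype=jnp.int32), temperature=0.7)

# Only `tokens` appears as a leaf; `temperature` lives in the treedef and acts
# as part of jit's cache key rather than as a device value.
print(jax.tree_util.tree_leaves(state))      # [Array of shape (2, 4)]
print(jax.tree_util.tree_structure(state))
```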
@@ -161,8 +185,21 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
     if sampler_state.forbidden_token_ids:
       logits = logits.at[:, :, sampler_state.forbidden_token_ids].set(-jnp.inf)

-    next_token_candidate = jnp.argmax(logits, axis=-1)  # [B, 1]
-    next_token_candidate = next_token_candidate[:, 0]  # [B,]
+    def sample_top_p(logits, key):
+      probs = jax.nn.softmax(logits[:, -1] / sampler_state.temperature, axis=-1)
+      next_token = _sample_top_p(probs, sampler_state.top_p, key)
+      return next_token
+
+    def sample_best(logits):
+      next_token = jnp.argmax(logits, axis=-1)
+      next_token = next_token[:, 0]
+      return next_token
+
+    if sampler_state.temperature > 0:
+      key = jax.random.fold_in(sampler_state.seed, decoding_step)
+      next_token_candidate = sample_top_p(logits, key)
+    else:
+      next_token_candidate = sample_best(logits)

     next_token_candidate = jnp.where(
         decoding_step < sampler_state.num_input_tokens - 1,
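When `temperature > 0`, the step derives its randomness by folding the current `decoding_step` into the single base `seed`, giving each step an independent, reproducible key without threading split keys through the sampling state. A brief sketch of that pattern in isolation (toy uniform logits, not the sampler's):

```python
import jax
import jax.numpy as jnp

seed = jax.random.PRNGKey(0)
logits = jnp.ones(5)  # uniform toy distribution over 5 tokens

# One deterministic key per decoding step, all derived from the same base seed.
keys = [jax.random.fold_in(seed, step) for step in range(3)]
draws = [int(jax.random.categorical(k, logits)) for k in keys]
print(draws)  # identical on every run with PRNGKey(0); changes with a new seed
```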
@@ -200,14 +237,20 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
         total_sampling_steps=sampler_state.total_sampling_steps,
         forbidden_token_ids=sampler_state.forbidden_token_ids,
         intermediates=sampler_state.intermediates,
+        temperature=sampler_state.temperature,
+        top_p=sampler_state.top_p,
+        seed=sampler_state.seed,
     )

   def init_sample_state(
       self,
       all_input_ids: list[jax.Array],
       total_sampling_steps: int,
-      include_logits: bool = False,
-      forbidden_token_ids: Sequence[int] | None = None,
+      include_logits: bool,
+      forbidden_token_ids: Sequence[int] | None,
+      temperature: float,
+      top_p: float,
+      seed: jax.Array,
   ) -> _SamplingState:
     """Initializes the sampling state given input prompts."""
     batch_size = len(all_input_ids)
@@ -259,6 +302,9 @@ def init_sample_state(
         intermediates=self.transformer.init_intermediates(
             batch_size, buffer_size, self.transformer.sow_config
         ),
+        temperature=temperature,
+        top_p=top_p,
+        seed=seed,
     )

   def tokenize(self, input_string: str) -> jax.Array:
@@ -281,7 +327,7 @@ def mask_tokens_after_eos_ids(self, token_buffer):
     mask = jnp.less_equal(
         jnp.arange(token_buffer.shape[-1]), eos_indices[:, None]
     )
-    masked_token_buffer = token_buffer * mask + self.vocab.pad_id()*(1 - mask)
+    masked_token_buffer = token_buffer * mask + self.vocab.pad_id() * (1 - mask)

     return masked_token_buffer

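The masking arithmetic keeps every token up to and including the first end-of-sequence id and overwrites everything after it with the pad id. A toy illustration with made-up ids (`pad_id=0`, `eos_id=2`; the real ids come from the SentencePiece vocab, and the `eos_indices` computation here is a simplified stand-in):

```python
import jax.numpy as jnp

pad_id, eos_id = 0, 2
token_buffer = jnp.array([[5, 7, 2, 9, 4]])

eos_indices = jnp.argmax(token_buffer == eos_id, axis=-1)  # first EOS at index 2
mask = jnp.less_equal(jnp.arange(token_buffer.shape[-1]), eos_indices[:, None])
masked = token_buffer * mask + pad_id * (1 - mask)
print(masked)  # [[5 7 2 0 0]] -- everything after the EOS becomes padding
```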
@@ -310,17 +356,23 @@ def __call__(
       echo: bool = False,
       return_logits: bool = True,
       forbidden_tokens: Sequence[str] | None = None,
+      temperature: float = 0.0,
+      top_p: float = 0.95,
+      seed: jax.Array | None = None,
   ) -> SamplerOutput:
     """Samples a completion of the input string.

     Args:
       input_strings: input prompts to feed to the model for sampling.
       total_generation_steps: number of generation steps. will correspond to the
         longest prompt in the batch.
       echo: whether to return the prompt as part of the output sample.
       return_logits: whether to return per-step logits used during generation.
       forbidden_tokens: list of tokens that are forbidden to be generated. Each
         token must map to a single token id in the vocab.
+      temperature: temperature for sampling.
+      top_p: top-p sampling threshold.
+      seed: random seed for sampling.

     Returns:
       sampler_output: A SamplerOutput object containing the generated samples.
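With the new arguments, the default `temperature=0.0` preserves the previous greedy behaviour, and any positive temperature switches decoding to top-p sampling. A hypothetical invocation, assuming a `sampler` instance has already been constructed (its constructor is outside this diff):

```python
import jax

# Greedy decoding: identical to the sampler's behaviour before this change.
greedy = sampler(
    ["The capital of France is"],
    total_generation_steps=32,
)

# Nucleus sampling: temperature > 0 enables top-p with the given threshold.
sampled = sampler(
    ["The capital of France is"],
    total_generation_steps=32,
    temperature=0.9,
    top_p=0.95,
    seed=jax.random.PRNGKey(1234),
)
```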
@@ -339,11 +391,17 @@ def __call__(
     all_input_ids = [self.tokenize(x) for x in input_strings]
     max_input_length = max(len(input_ids) for input_ids in all_input_ids)
     total_sampling_steps = max_input_length + total_generation_steps
+
+    if seed is None:
+      seed = jax.random.PRNGKey(0)
     initial_sampling_state = self.init_sample_state(
         all_input_ids,
         include_logits=return_logits,
         total_sampling_steps=total_sampling_steps,
         forbidden_token_ids=forbidden_token_ids,
+        temperature=temperature,
+        top_p=top_p,
+        seed=seed,
     )

     sampling_state = self._compiled_sample_fn(initial_sampling_state)
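Because a missing `seed` falls back to `jax.random.PRNGKey(0)` and every per-step key is folded from it, calls with the same arguments reproduce the same samples; pass a different seed to vary the completions. Again assuming an existing `sampler` instance:

```python
import jax

out_a = sampler(["Hello"], total_generation_steps=8, temperature=0.9)
out_b = sampler(["Hello"], total_generation_steps=8, temperature=0.9)
# out_a and out_b match: both runs used the implicit default PRNGKey(0).

out_c = sampler(
    ["Hello"],
    total_generation_steps=8,
    temperature=0.9,
    seed=jax.random.PRNGKey(42),
)
# out_c can differ: a new seed changes every per-step sampling key.
```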