Dramatically speed up sampling compilation time
On DD:2x2 hardware, this reduces compilation time from roughly 120s to 6s.

We separate the parameters from the model graph so that they are passed as an argument to the jitted function rather than being captured as static constants.

PiperOrigin-RevId: 730875380
Flax Team committed Feb 25, 2025
1 parent 41bef07 commit 7e62fcd
Showing 1 changed file with 23 additions and 6 deletions.
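For context, here is a minimal standalone sketch of the pattern this change adopts (an illustrative example built around nnx.Linear, not code from this commit): when a jitted function closes over an nnx module, its weights are traced as compile-time constants, whereas splitting the module with nnx.split and passing the state as an argument keeps the weights as ordinary traced inputs.

from flax import nnx
import jax
import jax.numpy as jnp

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
x = jnp.ones((2, 4))

# Before: the module is captured in the closure, so its weights are baked
# into the traced computation as constants, which inflates the HLO.
@jax.jit
def step_closed_over(x):
  return model(x)

# After: split the module into a static graphdef and an array-only state,
# pass the state as a regular (traced) argument, and merge inside the trace.
graphdef, state = nnx.split(model)

@jax.jit
def step_with_state(state, x):
  module = nnx.merge(graphdef, state)
  return module(x)

step_closed_over(x)          # weights compiled in as constants
step_with_state(state, x)    # weights passed in as traced arrays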
examples/gemma/sampler.py (29 changes: 23 additions & 6 deletions)
@@ -27,6 +27,8 @@
 import modules
 import sow_lib
 import transformer as transformer_lib
+from flax.nnx import graph
+from flax.nnx import statelib
 import jax
 import jax.numpy as jnp

@@ -128,17 +130,28 @@ def __init__(
       vocab: vocabulary of the given model.
       cache_size: size of the cache for the transformer.
     """
-    self.transformer = transformer
     self.vocab = vocab
     self.cache_size = cache_size
+    graphdef, state = nnx.split(transformer)
+    self._transformer_graphdef: graph.NodeDef = graphdef
+    self._transformer_state: statelib.State = state
+    # we separate out state and graph def so that the state can be passed as an
+    # argument to _sample_fn, resulting in it not being treated as a static
+    # arg. This greatly reduces the size of the HLO and reduces compile time
     self._compiled_sample_fn = jax.jit(self._sample_fn)

+  @property
+  def transformer(self) -> transformer_lib.Transformer:
+    return nnx.merge(self._transformer_graphdef, self._transformer_state)
+
   @property
   def dtype(self) -> jnp.dtype:
     params_state = nnx.state(self.transformer, nnx.Param)
     return jax.tree_util.tree_leaves(nnx.to_flat_state(params_state))[0].dtype

-  def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
+  def _sample_step(
+      self, params: statelib.State, sampler_state: _SamplingState
+  ) -> _SamplingState:
     """Performs a single sampling step."""
     batch_size = sampler_state.token_buffer.shape[0]
     decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32)
@@ -152,7 +165,8 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
     )
     last_token = last_token.reshape((batch_size, 1))

-    logits, cache = self.transformer(
+    transformer = nnx.merge(self._transformer_graphdef, params)
+    logits, cache = transformer(
         last_token,
         step_positions,
         sampler_state.cache,
@@ -183,7 +197,7 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
       logits_buffer = sampler_state.logits_buffer

     if sampler_state.intermediates is not None:
-      sampler_state.intermediates.merge(decoding_step, self.transformer)
+      sampler_state.intermediates.merge(decoding_step, transformer)

     done = sampler_state.done | jnp.equal(
         token_buffer[:, decoding_step + 1], self.vocab.eos_id()
@@ -287,12 +301,13 @@ def mask_tokens_after_eos_ids(self, token_buffer):

   def _sample_fn(
       self,
+      params: statelib.State,
       initial_sampling_state: _SamplingState,
   ) -> _SamplingState:
     """Internal sampling function (to be jitted)."""

     def sample_with_params(sampler_state: _SamplingState):
-      return self._sample_step(sampler_state)
+      return self._sample_step(params, sampler_state)

     def cond_fn(sampler_state: _SamplingState):
       return (
@@ -346,7 +361,9 @@ def __call__(
         forbidden_token_ids=forbidden_token_ids,
     )

-    sampling_state = self._compiled_sample_fn(initial_sampling_state)
+    sampling_state = self._compiled_sample_fn(
+        self._transformer_state, initial_sampling_state
+    )

     masked_token_buffer = self.mask_tokens_after_eos_ids(
         sampling_state.token_buffer
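The comment added in __init__ attributes the win to a smaller HLO. A self-contained way to observe the underlying mechanism (an illustrative snippet, not part of this commit): arrays closed over by a traced function are embedded in the jaxpr as constants, while arrays passed as arguments remain ordinary inputs whose contents never enter the traced program.

import jax
import jax.numpy as jnp

w = jnp.ones((8, 8))
x = jnp.ones((2, 8))

def closed_over(x):
  return x @ w        # `w` is captured and becomes an embedded constant

def passed_in(w, x):
  return x @ w        # `w` is an ordinary traced input

print(jax.make_jaxpr(closed_over)(x))    # `w` shows up as a constvar
print(jax.make_jaxpr(passed_in)(w, x))   # `w` is a formal input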
