From 28b5b425ee7f62fc3ca9088be045b6285e98d763 Mon Sep 17 00:00:00 2001
From: Flax Team
Date: Mon, 10 Mar 2025 03:54:00 -0700
Subject: [PATCH] Dramatically speed up sampling compilation time

On DD:2x2 hardware, this reduces compilation time from roughly 120s to 6s

We separate out the parameters from the model graph so that the parameters
are passed as a parameter to the jitted function, rather than being kept
static.

PiperOrigin-RevId: 735316560
---
 examples/gemma/sampler.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/examples/gemma/sampler.py b/examples/gemma/sampler.py
index adbcb126..89dffef5 100644
--- a/examples/gemma/sampler.py
+++ b/examples/gemma/sampler.py
@@ -27,6 +27,8 @@
 import modules
 import sow_lib
 import transformer as transformer_lib
+from flax.nnx import graph
+from flax.nnx import statelib
 import jax
 import jax.numpy as jnp
@@ -152,17 +154,28 @@ def __init__(
       vocab: vocabulary of the given model.
       cache_size: size of the cache for the transformer.
     """
-    self.transformer = transformer
     self.vocab = vocab
     self.cache_size = cache_size
+    graphdef, state = nnx.split(transformer)
+    self._transformer_graphdef: graph.NodeDef = graphdef
+    self._transformer_state: statelib.State = state
+    # we separate out state and graph def so that the state can be passed as an
+    # argument to _sample_fn, resulting in it not being treated as a static
+    # arg. This greatly reduces the size of the HLO and reduces compile time
     self._compiled_sample_fn = jax.jit(self._sample_fn)

+  @property
+  def transformer(self) -> transformer_lib.Transformer:
+    return nnx.merge(self._transformer_graphdef, self._transformer_state)
+
   @property
   def dtype(self) -> jnp.dtype:
     params_state = nnx.state(self.transformer, nnx.Param)
     return jax.tree_util.tree_leaves(nnx.to_flat_state(params_state))[0].dtype

-  def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
+  def _sample_step(
+      self, params: statelib.State, sampler_state: _SamplingState
+  ) -> _SamplingState:
     """Performs a single sampling step."""
     batch_size = sampler_state.token_buffer.shape[0]
     decoding_step = jnp.asarray(sampler_state.decoding_step, dtype=jnp.int32)
@@ -176,7 +189,8 @@ def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
     )
     last_token = last_token.reshape((batch_size, 1))

-    logits, cache = self.transformer(
+    transformer = nnx.merge(self._transformer_graphdef, params)
+    logits, cache = transformer(
       last_token,
       step_positions,
       sampler_state.cache,
@@ -220,7 +234,7 @@ def sample_best(logits):
       logits_buffer = sampler_state.logits_buffer

     if sampler_state.intermediates is not None:
-      sampler_state.intermediates.merge(decoding_step, self.transformer)
+      sampler_state.intermediates.merge(decoding_step, transformer)

     done = sampler_state.done | jnp.equal(
         token_buffer[:, decoding_step + 1], self.vocab.eos_id()
@@ -333,12 +347,13 @@ def mask_tokens_after_eos_ids(self, token_buffer):

   def _sample_fn(
       self,
+      params: statelib.State,
       initial_sampling_state: _SamplingState,
   ) -> _SamplingState:
     """Internal sampling function (to be jitted)."""

     def sample_with_params(sampler_state: _SamplingState):
-      return self._sample_step(sampler_state)
+      return self._sample_step(params, sampler_state)

     def cond_fn(sampler_state: _SamplingState):
       return (
@@ -404,7 +419,9 @@ def __call__(
         seed=seed,
     )

-    sampling_state = self._compiled_sample_fn(initial_sampling_state)
+    sampling_state = self._compiled_sample_fn(
+        self._transformer_state, initial_sampling_state
+    )

     masked_token_buffer = self.mask_tokens_after_eos_ids(
         sampling_state.token_buffer