From 787a820f838c75809eaf63fa7189b0957c28340e Mon Sep 17 00:00:00 2001 From: Flax Team Date: Tue, 4 Mar 2025 04:37:20 -0800 Subject: [PATCH] Add Sow Config to from_params constructor. PiperOrigin-RevId: 733277422 --- examples/gemma/modules.py | 61 +++++++++++++++++------------- examples/gemma/modules_test.py | 44 ++++++++++++--------- examples/gemma/sow_lib.py | 2 +- examples/gemma/transformer.py | 55 +++++++++++++++++++++++++-- examples/gemma/transformer_test.py | 2 +- 5 files changed, 115 insertions(+), 49 deletions(-) diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py index 5eb6fa578..1cdf44401 100644 --- a/examples/gemma/modules.py +++ b/examples/gemma/modules.py @@ -33,6 +33,7 @@ Shape = Sequence[Union[int, Any]] K_MASK = -2.3819763e38 # Set to a large negative number. +DEFAULT_ROPE_BASE_FREQUENCY = 10_000 class AttentionType(enum.Enum): @@ -80,9 +81,11 @@ def __init__( num_kv_heads: int, features: int, head_dim: int, + query_pre_attn_scalar: float, attn_type: AttentionType, *, rngs: nnx.Rngs, + rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY, attn_logits_soft_cap: float | None = None, sliding_window_size: int | None = None, sow_config: sow_lib.SowConfig = sow_lib.SowConfig() @@ -92,6 +95,7 @@ def __init__( '`sliding_window_size` must be set if `attn_type` is Local Sliding.' ) + self.query_pre_attn_scalar = query_pre_attn_scalar self.attn_type = attn_type self.sliding_window_size = sliding_window_size self.attn_logits_soft_cap = attn_logits_soft_cap @@ -100,6 +104,7 @@ def __init__( shape=(num_heads, head_dim, features), rngs=rngs, ) + self.rope_base_frequency = rope_base_frequency self.sow_config = sow_config if num_heads == num_kv_heads: @@ -139,12 +144,14 @@ def __call__( query_proj, segment_pos, head_dim=self.head_dim, + max_wavelength=self.rope_base_frequency, ) - query_scaled = query_proj * self.head_dim**-0.5 + query_scaled = query_proj * self.query_pre_attn_scalar key_proj = positional_embeddings.apply_rope( key_proj, segment_pos, head_dim=self.head_dim, + max_wavelength=self.rope_base_frequency, ) # Cache is left aligned. 
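Usage sketch for the extended Attention constructor above (assumptions: the examples/gemma package is importable as `modules`, and the dimensions are illustrative; only the keyword names and defaults come from the hunks above):

    from flax import nnx
    import modules  # examples/gemma/modules.py

    attn = modules.Attention(
        num_heads=4,
        num_kv_heads=4,
        features=32,
        head_dim=16,
        query_pre_attn_scalar=16**-0.5,  # 1/sqrt(head_dim), matching the previously hard-coded scaling
        attn_type=modules.AttentionType.LOCAL_SLIDING,
        sliding_window_size=128,         # still required when attn_type is LOCAL_SLIDING
        rope_base_frequency=modules.DEFAULT_ROPE_BASE_FREQUENCY,  # 10_000 unless overridden
        rngs=nnx.Rngs(params=0),
    )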
@@ -253,21 +260,21 @@ def __init__( out_features=hidden_dim, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.initializers.normal(), ) self.up_proj = nnx.Linear( in_features=features, out_features=hidden_dim, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.initializers.normal(), ) self.down_proj = nnx.Linear( in_features=hidden_dim, out_features=features, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.initializers.normal(), ) self.sow_config = sow_config @@ -295,9 +302,11 @@ def __init__( hidden_dim: int, use_post_attn_norm: bool, use_post_ffw_norm: bool, + query_pre_attn_scalar: float, attn_type: AttentionType, *, rngs: nnx.Rngs, + rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY, attn_logits_soft_cap: float | None = None, sliding_window_size: int | None = None, sow_config: sow_lib.SowConfig = sow_lib.SowConfig() @@ -308,14 +317,17 @@ def __init__( num_kv_heads=num_kv_heads, features=embed_dim, head_dim=head_dim, + query_pre_attn_scalar=query_pre_attn_scalar, attn_type=attn_type, + rope_base_frequency=rope_base_frequency, attn_logits_soft_cap=attn_logits_soft_cap, sliding_window_size=sliding_window_size, rngs=rngs, sow_config=sow_config, ) - if use_post_attn_norm: - self.post_attn_norm = layers.RMSNorm(embed_dim, rngs=rngs) + self.use_post_attn_norm = use_post_attn_norm + if self.use_post_attn_norm: + self.post_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs) self.pre_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs) self.mlp = FeedForward( @@ -324,7 +336,8 @@ def __init__( rngs=rngs, sow_config=sow_config, ) - if use_post_ffw_norm: + self.use_post_ffw_norm = use_post_ffw_norm + if self.use_post_ffw_norm: self.post_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs) self.sow_config = sow_config @@ -335,35 +348,29 @@ def __call__( cache: LayerCache | None, attn_mask: jax.Array, ) -> tuple[LayerCache | None, jax.Array]: - inputs_normalized = self.pre_attention_norm(x) + + # Attention. + attn_inputs = self.pre_attention_norm(x) cache, attn_output = self.attn( - inputs_normalized, + attn_inputs, segment_pos, cache, attn_mask, ) - attn_output += x - residual = attn_output - attn_output = self.pre_ffw_norm(attn_output) - if self.use_post_attn_norm: - attn_output = self.post_attn_norm(attn_output) - self.sow_config.maybe_sow_rs_after_attention(attn_output, self) + attn_output = self.post_attention_norm(attn_output) + x += attn_output + self.sow_config.maybe_sow_rs_after_attention(x, self) - outputs = self.mlp(attn_output) + # Feed forward. 
+ ffw_inputs = self.pre_ffw_norm(x) + ffw_outputs = self.mlp(ffw_inputs) if self.use_post_ffw_norm: - outputs = self.post_ffw_norm(outputs) - outputs = residual + outputs - self.sow_config.maybe_sow_rs_after_ffw(outputs, self) - return cache, outputs + ffw_outputs = self.post_ffw_norm(ffw_outputs) + x += ffw_outputs + self.sow_config.maybe_sow_rs_after_ffw(x, self) - @property - def use_post_attn_norm(self): - return hasattr(self, 'post_attn_norm') and self.post_attn_norm is not None - - @property - def use_post_ffw_norm(self): - return hasattr(self, 'post_ffw_norm') and self.post_ffw_norm is not None + return cache, x def init_cache( self, diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py index 7439585cc..8e5f77f22 100644 --- a/examples/gemma/modules_test.py +++ b/examples/gemma/modules_test.py @@ -78,6 +78,7 @@ def test_head_dim(self, head_dim): num_kv_heads=4, features=5, head_dim=head_dim, + query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -107,6 +108,7 @@ def test_use_qkv_einsum( num_kv_heads=num_kv_heads, features=5, head_dim=8, + query_pre_attn_scalar=1.0, attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -144,7 +146,8 @@ def test_attention( num_heads, features, head_dim, - modules.AttentionType.GLOBAL, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) cache = attn.init_cache( @@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size): num_heads, features, head_dim, - modules.AttentionType.GLOBAL, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) cache = attn.init_cache( @@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size): num_heads, features, head_dim, - modules.AttentionType.LOCAL_SLIDING, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.LOCAL_SLIDING, sliding_window_size=sliding_window_size, rngs=nnx.Rngs(params=0), ) @@ -272,6 +277,7 @@ def test_block( 1, use_post_attn_norm, use_post_ffw_norm, + 1.0, modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -313,9 +319,10 @@ def test_post_attention_norm( embed_dim, head_dim, 1, - True, - False, # use_post_ffw_norm - modules.AttentionType.GLOBAL, + use_post_attn_norm=True, + use_post_ffw_norm=False, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) unnormed_block = modules.Block( @@ -324,9 +331,10 @@ def test_post_attention_norm( embed_dim, head_dim, 1, - False, - False, # use_post_ffw_norm - modules.AttentionType.GLOBAL, + use_post_attn_norm=False, + use_post_ffw_norm=False, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -343,7 +351,7 @@ def test_post_attention_norm( all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking - self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all()) + self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) @parameterized.parameters( dict( @@ -371,9 +379,10 @@ def test_post_ffw_norm( embed_dim, head_dim, 1, - True, - True, # use_post_ffw_norm - modules.AttentionType.GLOBAL, + use_post_attn_norm=False, + use_post_ffw_norm=True, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) unnormed_block = modules.Block( @@ -382,9 +391,10 @@ def test_post_ffw_norm( embed_dim, head_dim, 1, - False, - False, # use_post_ffw_norm - 
modules.AttentionType.GLOBAL, + use_post_attn_norm=False, + use_post_ffw_norm=False, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -401,7 +411,7 @@ def test_post_ffw_norm( all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking - self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all()) + self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) if __name__ == '__main__': diff --git a/examples/gemma/sow_lib.py b/examples/gemma/sow_lib.py index 2a9c13018..9cb408c10 100644 --- a/examples/gemma/sow_lib.py +++ b/examples/gemma/sow_lib.py @@ -41,7 +41,7 @@ def merge(self, decoding_step, layer: nnx.Module): value = getattr(self, field.name) if value is None: continue - # We but mlp and attn intermediates into this class without any further + # We put mlp and attn intermediates into this class without any further # nesting. So we have to retrieve the intermediates from the correct # sub-module. try: diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py index cdf607c1a..f52ef9478 100644 --- a/examples/gemma/transformer.py +++ b/examples/gemma/transformer.py @@ -18,6 +18,8 @@ from collections.abc import Iterable import dataclasses +import enum +import functools from typing import Any from flax import nnx @@ -32,6 +34,19 @@ Cache = dict[str, modules.LayerCache] +class QueryPreAttentionNormalisation(enum.Enum): + """Initialization strategy.""" + + # Whether to scale the query by 1/sqrt(head_dim) + BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto() + + # Whether to scale the query by `embed_dim // num_heads` + BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto() + + # Whether to scale the query by `1/sqrt(embed_dim // num_heads)` + BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto() + + @dataclasses.dataclass(frozen=True) class TransformerConfig: """Configuration for the gemma transformer.""" @@ -47,9 +62,25 @@ class TransformerConfig: use_post_attn_norm: bool use_post_ffw_norm: bool attention_types: Iterable[modules.AttentionType] + query_pre_attn_norm: QueryPreAttentionNormalisation = ( + QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM + ) attn_logits_soft_cap: float | None = None + transpose_gating_einsum: bool = False + local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY + global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY sliding_window_size: int | None = None + def query_pre_attn_scalar(self) -> float: + """Returns the scalar to multiply the query by before attention.""" + match self.query_pre_attn_norm: + case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS: + return self.embed_dim // self.num_heads + case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: # pylint: disable=line-too-long + return (self.embed_dim // self.num_heads) ** -0.5 + case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _: + return self.head_dim**-0.5 + @classmethod def from_path(cls, path: str) -> TransformerConfig: """Creates a TransformerConfig from loaded parameters.""" @@ -175,6 +206,7 @@ def gemma_9b(cls): def _map_linen_var_names(key: tuple[str, ...]) -> tuple[str | int, ...]: + """Maps linen variable names to nnx variable names.""" new_key = [] for k in key: if k.startswith('layer_'): @@ -198,8 +230,12 @@ def _assign_linen_params_to_nnx_state( state: dict[tuple[str, ...], Any], mapped_path: tuple[str | int, ...], val: Any, + transpose_gating_einsum: bool, ) -> dict[tuple[str, ...], Any]: + """Splits 
and maybe transposes gate_proj.""" if 'gate_proj' in mapped_path: + if transpose_gating_einsum: + val = jnp.swapaxes(val, 1, 2) state[mapped_path].value = val[0] state[mapped_path[:-2] + ('up_proj', 'kernel')].value = val[1] else: @@ -212,15 +248,24 @@ class Transformer(nnx.Module): @classmethod def from_params( - cls, params: params_lib.Params, config: None | TransformerConfig = None + cls, + params: params_lib.Params, + config: None | TransformerConfig = None, + sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ) -> Transformer: if config is None: config = TransformerConfig.from_params(params) + assign_val_fn = functools.partial( + _assign_linen_params_to_nnx_state, + transpose_gating_einsum=config.transpose_gating_einsum, + ) return helpers.module_from_linen_variables( - module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)), + module_factory=lambda: cls( + config, rngs=nnx.Rngs(params=0), sow_config=sow_config + ), variables=params['transformer'], map_key_fn=_map_linen_var_names, - assign_val_fn=_assign_linen_params_to_nnx_state, + assign_val_fn=assign_val_fn, ) def __init__( @@ -247,7 +292,11 @@ def __init__( use_post_ffw_norm=config.use_post_ffw_norm, attn_logits_soft_cap=config.attn_logits_soft_cap, attn_type=attn_type, + query_pre_attn_scalar=config.query_pre_attn_scalar(), rngs=rngs, + rope_base_frequency=config.local_base_frequency + if attn_type == modules.AttentionType.LOCAL_SLIDING + else config.global_base_frequency, sow_config=sow_config, ) for _, attn_type in zip( diff --git a/examples/gemma/transformer_test.py b/examples/gemma/transformer_test.py index c637a33e5..420d23169 100644 --- a/examples/gemma/transformer_test.py +++ b/examples/gemma/transformer_test.py @@ -73,7 +73,7 @@ def nested_defaultdict(): )) if config.use_post_attn_norm: - params[f'layer_{layer_idx}']['post_attn_norm']['scale'] = jnp.ones(( + params[f'layer_{layer_idx}']['post_attention_norm']['scale'] = jnp.ones(( config.embed_dim, )) if config.use_post_ffw_norm:
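
Usage sketch for the new from_params signature and the query_pre_attn_scalar() helper (assumptions: the examples/gemma modules are importable under the aliases below, and the params loader and checkpoint path are illustrative placeholders; the rest uses only names visible in this diff):

    import params as params_lib             # examples/gemma/params.py
    import sow_lib                          # examples/gemma/sow_lib.py
    import transformer as transformer_lib   # examples/gemma/transformer.py

    params = params_lib.load_and_format_params('/path/to/checkpoint')  # assumed loader and path
    config = transformer_lib.TransformerConfig.from_params(params)

    # Default strategy is BY_ONE_OVER_SQRT_HEAD_DIM, so this returns head_dim ** -0.5.
    scale = config.query_pre_attn_scalar()

    # A SowConfig can now be passed straight through from_params to the module constructors.
    model = transformer_lib.Transformer.from_params(
        params,
        config=config,
        sow_config=sow_lib.SowConfig(),  # default config, identical to the constructors' default
    )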