From b28822f34322167a59efc7ca8040a931e7bd6390 Mon Sep 17 00:00:00 2001 From: Sascha Rothe Date: Thu, 6 Mar 2025 11:56:04 -0800 Subject: [PATCH] Fix position and name of Post Attention Norm. PiperOrigin-RevId: 734232594 --- examples/gemma/modules.py | 54 ++++++++++++++---------------- examples/gemma/modules_test.py | 38 ++++++++++----------- examples/gemma/transformer_test.py | 2 +- 3 files changed, 46 insertions(+), 48 deletions(-) diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py index be152eda..5dc5d0fe 100644 --- a/examples/gemma/modules.py +++ b/examples/gemma/modules.py @@ -269,21 +269,21 @@ def __init__( out_features=hidden_dim, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.initializers.normal(), ) self.up_proj = nnx.Linear( in_features=features, out_features=hidden_dim, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.initializers.normal(), ) self.down_proj = nnx.Linear( in_features=hidden_dim, out_features=features, use_bias=False, rngs=rngs, - kernel_init=nn.initializers.zeros_init(), + kernel_init=nn.initializers.normal(), ) self.sow_config = sow_config @@ -337,7 +337,9 @@ def __init__( sow_config=sow_config, ) if use_post_attn_norm: - self.post_attn_norm = layers.RMSNorm(embed_dim, rngs=rngs) + self.post_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs) + else: + self.post_attention_norm = None self.pre_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs) self.mlp = FeedForward( @@ -348,6 +350,8 @@ def __init__( ) if use_post_ffw_norm: self.post_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs) + else: + self.post_ffw_norm = None self.sow_config = sow_config def __call__( @@ -357,35 +361,29 @@ def __call__( cache: LayerCache | None, attn_mask: jax.Array, ) -> tuple[LayerCache | None, jax.Array]: - inputs_normalized = self.pre_attention_norm(x) + + # Attention. + attn_inputs = self.pre_attention_norm(x) cache, attn_output = self.attn( - inputs_normalized, + attn_inputs, segment_pos, cache, attn_mask, ) - attn_output += x - residual = attn_output - attn_output = self.pre_ffw_norm(attn_output) - - if self.use_post_attn_norm: - attn_output = self.post_attn_norm(attn_output) - self.sow_config.maybe_sow_rs_after_attention(attn_output, self) - - outputs = self.mlp(attn_output) - if self.use_post_ffw_norm: - outputs = self.post_ffw_norm(outputs) - outputs = residual + outputs - self.sow_config.maybe_sow_rs_after_ffw(outputs, self) - return cache, outputs - - @property - def use_post_attn_norm(self): - return hasattr(self, 'post_attn_norm') and self.post_attn_norm is not None - - @property - def use_post_ffw_norm(self): - return hasattr(self, 'post_ffw_norm') and self.post_ffw_norm is not None + if self.post_attention_norm is not None: + attn_output = self.post_attention_norm(attn_output) + x += attn_output + self.sow_config.maybe_sow_rs_after_attention(x, self) + + # Feed forward. 
+ ffw_inputs = self.pre_ffw_norm(x) + ffw_outputs = self.mlp(ffw_inputs) + if self.post_ffw_norm is not None: + ffw_outputs = self.post_ffw_norm(ffw_outputs) + x += ffw_outputs + self.sow_config.maybe_sow_rs_after_ffw(x, self) + + return cache, x def init_cache( self, diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py index 0f94a5e4..6f3140a5 100644 --- a/examples/gemma/modules_test.py +++ b/examples/gemma/modules_test.py @@ -289,7 +289,7 @@ def test_block( new_cache, outputs = block(inputs, jnp.array([[0]]), cache, attn_mask) - self.assertEqual(block.use_post_attn_norm, use_post_attn_norm) + self.assertEqual(block.post_attention_norm is not None, use_post_attn_norm) self.assertEqual(new_cache['k'].shape, expected_cache_shape) self.assertEqual(outputs.shape, expected_output_shape) @@ -319,10 +319,10 @@ def test_post_attention_norm( embed_dim, head_dim, 1, - True, - False, # use_post_ffw_norm - 1.0, - modules.AttentionType.GLOBAL, + use_post_attn_norm=True, + use_post_ffw_norm=False, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) unnormed_block = modules.Block( @@ -331,10 +331,10 @@ def test_post_attention_norm( embed_dim, head_dim, 1, - False, - False, # use_post_ffw_norm - 1.0, - modules.AttentionType.GLOBAL, + use_post_attn_norm=False, + use_post_ffw_norm=False, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -351,7 +351,7 @@ def test_post_attention_norm( all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking - self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all()) + self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) @parameterized.parameters( dict( @@ -379,10 +379,10 @@ def test_post_ffw_norm( embed_dim, head_dim, 1, - True, - True, # use_post_ffw_norm - 1.0, - modules.AttentionType.GLOBAL, + use_post_attn_norm=False, + use_post_ffw_norm=True, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) unnormed_block = modules.Block( @@ -391,10 +391,10 @@ def test_post_ffw_norm( embed_dim, head_dim, 1, - False, - False, # use_post_ffw_norm - 1.0, - modules.AttentionType.GLOBAL, + use_post_attn_norm=False, + use_post_ffw_norm=False, + query_pre_attn_scalar=1.0, + attn_type=modules.AttentionType.GLOBAL, rngs=nnx.Rngs(params=0), ) @@ -411,7 +411,7 @@ def test_post_ffw_norm( all_outputs.append(outputs) normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking - self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all()) + self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all()) if __name__ == '__main__': diff --git a/examples/gemma/transformer_test.py b/examples/gemma/transformer_test.py index c637a33e..420d2316 100644 --- a/examples/gemma/transformer_test.py +++ b/examples/gemma/transformer_test.py @@ -73,7 +73,7 @@ def nested_defaultdict(): )) if config.use_post_attn_norm: - params[f'layer_{layer_idx}']['post_attn_norm']['scale'] = jnp.ones(( + params[f'layer_{layer_idx}']['post_attention_norm']['scale'] = jnp.ones(( config.embed_dim, )) if config.use_post_ffw_norm:
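
Below is a minimal, self-contained sketch of the residual ordering that the patched Block.__call__ follows: each optional post norm is applied to its sub-layer output before the residual add, not after it. This is illustration only, under stated assumptions: rms_norm, block, and the lambda attention/MLP stand-ins are hypothetical helpers, not the nnx modules from examples/gemma.

import jax
import jax.numpy as jnp

def rms_norm(x, eps=1e-6):
  # Unscaled RMSNorm, for illustration only (no learned scale parameter).
  return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)

def block(x, attn, mlp, use_post_attn_norm=True, use_post_ffw_norm=True):
  # Attention: pre-norm, attend, optional post-attention norm, then residual add.
  attn_output = attn(rms_norm(x))
  if use_post_attn_norm:
    attn_output = rms_norm(attn_output)
  x = x + attn_output

  # Feed forward: pre-norm, MLP, optional post-FFW norm, then residual add.
  ffw_outputs = mlp(rms_norm(x))
  if use_post_ffw_norm:
    ffw_outputs = rms_norm(ffw_outputs)
  return x + ffw_outputs

x = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8))
y = block(x, attn=lambda h: 0.5 * h, mlp=jnp.tanh)
print(y.shape)  # (1, 4, 8)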