
Commit

Fix position and name of Post Attention Norm.
PiperOrigin-RevId: 734232594
casaro authored and Flax Authors committed Mar 6, 2025
1 parent a149b6d commit b28822f
Showing 3 changed files with 46 additions and 48 deletions.
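As the modules.py diff below shows, the commit renames `post_attn_norm` to `post_attention_norm` and moves the post-attention norm so it is applied to the attention output itself, before that output is added back to the residual stream (previously it was applied after the residual add, on top of `pre_ffw_norm`). Here is a minimal pure-JAX sketch of the new sub-layer ordering; `rms_norm`, `fake_attention`, and `fake_mlp` are illustrative stand-ins, not the example's real modules:

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
  # Root-mean-square norm without a learned scale, for illustration only.
  return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)


def fake_attention(x):
  return x * 0.5  # stand-in for the self-attention sub-layer


def fake_mlp(x):
  return x * 2.0  # stand-in for the gated feed-forward sub-layer


def block(x, use_post_attn_norm=True, use_post_ffw_norm=True):
  # Attention: pre-norm, attention, optional post-norm, then residual add.
  attn_out = fake_attention(rms_norm(x))
  if use_post_attn_norm:
    attn_out = rms_norm(attn_out)
  x = x + attn_out

  # Feed forward: same pattern.
  ffw_out = fake_mlp(rms_norm(x))
  if use_post_ffw_norm:
    ffw_out = rms_norm(ffw_out)
  return x + ffw_out


y = block(jnp.ones((2, 4, 8)))
print(y.shape)  # (2, 4, 8)
```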
54 changes: 26 additions & 28 deletions examples/gemma/modules.py
@@ -269,21 +269,21 @@ def __init__(
         out_features=hidden_dim,
         use_bias=False,
         rngs=rngs,
-        kernel_init=nn.initializers.zeros_init(),
+        kernel_init=nn.initializers.normal(),
     )
     self.up_proj = nnx.Linear(
         in_features=features,
         out_features=hidden_dim,
         use_bias=False,
         rngs=rngs,
-        kernel_init=nn.initializers.zeros_init(),
+        kernel_init=nn.initializers.normal(),
     )
     self.down_proj = nnx.Linear(
         in_features=hidden_dim,
         out_features=features,
         use_bias=False,
         rngs=rngs,
-        kernel_init=nn.initializers.zeros_init(),
+        kernel_init=nn.initializers.normal(),
     )
     self.sow_config = sow_config
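This hunk switches the three FeedForward projections from zero initialization to a small random normal. A quick standalone comparison of the two initializer factories used above (arbitrary key and shape, for illustration only):

```python
import jax
import flax.linen as nn

key = jax.random.PRNGKey(0)
shape = (4, 8)

zeros = nn.initializers.zeros_init()(key, shape)  # every weight starts at exactly 0.0
normal = nn.initializers.normal()(key, shape)     # zero-mean Gaussian, stddev 1e-2 by default

print(float(zeros.sum()), float(normal.std()))
```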

@@ -337,7 +337,9 @@ def __init__(
         sow_config=sow_config,
     )
     if use_post_attn_norm:
-      self.post_attn_norm = layers.RMSNorm(embed_dim, rngs=rngs)
+      self.post_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)
+    else:
+      self.post_attention_norm = None

     self.pre_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs)
     self.mlp = FeedForward(
@@ -348,6 +350,8 @@
     )
     if use_post_ffw_norm:
       self.post_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs)
+    else:
+      self.post_ffw_norm = None
     self.sow_config = sow_config

   def __call__(
@@ -357,35 +361,29 @@ def __call__(
       cache: LayerCache | None,
       attn_mask: jax.Array,
   ) -> tuple[LayerCache | None, jax.Array]:
-    inputs_normalized = self.pre_attention_norm(x)
+
+    # Attention.
+    attn_inputs = self.pre_attention_norm(x)
     cache, attn_output = self.attn(
-        inputs_normalized,
+        attn_inputs,
         segment_pos,
         cache,
         attn_mask,
     )
-    attn_output += x
-    residual = attn_output
-    attn_output = self.pre_ffw_norm(attn_output)
-
-    if self.use_post_attn_norm:
-      attn_output = self.post_attn_norm(attn_output)
-    self.sow_config.maybe_sow_rs_after_attention(attn_output, self)
-
-    outputs = self.mlp(attn_output)
-    if self.use_post_ffw_norm:
-      outputs = self.post_ffw_norm(outputs)
-    outputs = residual + outputs
-    self.sow_config.maybe_sow_rs_after_ffw(outputs, self)
-    return cache, outputs
-
-  @property
-  def use_post_attn_norm(self):
-    return hasattr(self, 'post_attn_norm') and self.post_attn_norm is not None
-
-  @property
-  def use_post_ffw_norm(self):
-    return hasattr(self, 'post_ffw_norm') and self.post_ffw_norm is not None
+    if self.post_attention_norm is not None:
+      attn_output = self.post_attention_norm(attn_output)
+    x += attn_output
+    self.sow_config.maybe_sow_rs_after_attention(x, self)
+
+    # Feed forward.
+    ffw_inputs = self.pre_ffw_norm(x)
+    ffw_outputs = self.mlp(ffw_inputs)
+    if self.post_ffw_norm is not None:
+      ffw_outputs = self.post_ffw_norm(ffw_outputs)
+    x += ffw_outputs
+    self.sow_config.maybe_sow_rs_after_ffw(x, self)
+
+    return cache, x

   def init_cache(
       self,
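Note that `Block.__init__` now always assigns the optional norms, using `None` to mean "disabled", which is what lets `__call__` check `is not None` instead of the removed `hasattr`-based properties. A toy sketch of that pattern; `ToyLayer` is hypothetical and only assumes flax.nnx's standard `Linear` and `RMSNorm` constructors:

```python
import jax.numpy as jnp
from flax import nnx


class ToyLayer(nnx.Module):

  def __init__(self, features: int, use_norm: bool, *, rngs: nnx.Rngs):
    self.dense = nnx.Linear(features, features, rngs=rngs)
    # Always assign the attribute; None means "norm disabled".
    self.norm = nnx.RMSNorm(features, rngs=rngs) if use_norm else None

  def __call__(self, x):
    y = self.dense(x)
    if self.norm is not None:
      y = self.norm(y)
    return y


layer = ToyLayer(8, use_norm=True, rngs=nnx.Rngs(params=0))
print(layer(jnp.ones((1, 8))).shape)  # (1, 8)
```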
38 changes: 19 additions & 19 deletions examples/gemma/modules_test.py
@@ -289,7 +289,7 @@ def test_block(

     new_cache, outputs = block(inputs, jnp.array([[0]]), cache, attn_mask)

-    self.assertEqual(block.use_post_attn_norm, use_post_attn_norm)
+    self.assertEqual(block.post_attention_norm is not None, use_post_attn_norm)
     self.assertEqual(new_cache['k'].shape, expected_cache_shape)
     self.assertEqual(outputs.shape, expected_output_shape)

@@ -319,10 +319,10 @@ def test_post_attention_norm(
         embed_dim,
         head_dim,
         1,
-        True,
-        False,  # use_post_ffw_norm
-        1.0,
-        modules.AttentionType.GLOBAL,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=False,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
     unnormed_block = modules.Block(
@@ ... @@
         embed_dim,
         head_dim,
         1,
-        False,
-        False,  # use_post_ffw_norm
-        1.0,
-        modules.AttentionType.GLOBAL,
+        use_post_attn_norm=False,
+        use_post_ffw_norm=False,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )

@@ ... @@
       all_outputs.append(outputs)

     normed_output, unnormed_output = all_outputs  # pylint: disable=unbalanced-tuple-unpacking
-    self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all())
+    self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all())

   @parameterized.parameters(
       dict(
@@ -379,10 +379,10 @@ def test_post_ffw_norm(
         embed_dim,
         head_dim,
         1,
-        True,
-        True,  # use_post_ffw_norm
-        1.0,
-        modules.AttentionType.GLOBAL,
+        use_post_attn_norm=False,
+        use_post_ffw_norm=True,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
     unnormed_block = modules.Block(
@@ ... @@
         embed_dim,
         head_dim,
         1,
-        False,
-        False,  # use_post_ffw_norm
-        1.0,
-        modules.AttentionType.GLOBAL,
+        use_post_attn_norm=False,
+        use_post_ffw_norm=False,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )

@@ ... @@
       all_outputs.append(outputs)

     normed_output, unnormed_output = all_outputs  # pylint: disable=unbalanced-tuple-unpacking
-    self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all())
+    self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all())


 if __name__ == '__main__':
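The flipped assertions in both tests now require the normed and un-normed blocks to differ in every output element: `jnp.not_equal(a, b).all()` is true only when no element matches. A tiny illustration with made-up arrays:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([1.5, 2.5, 3.5])
c = jnp.array([1.0, 2.5, 3.5])

print(jnp.not_equal(a, b).all())  # True: every element differs
print(jnp.not_equal(a, c).all())  # False: the first elements are equal
```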
2 changes: 1 addition & 1 deletion examples/gemma/transformer_test.py
@@ -73,7 +73,7 @@ def nested_defaultdict():
     ))

     if config.use_post_attn_norm:
-      params[f'layer_{layer_idx}']['post_attn_norm']['scale'] = jnp.ones((
+      params[f'layer_{layer_idx}']['post_attention_norm']['scale'] = jnp.ones((
           config.embed_dim,
       ))
     if config.use_post_ffw_norm:
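The rename also has to show up in the expected parameter tree, since flax.nnx keys state by attribute name: the norm's `scale` now lives under `post_attention_norm` instead of `post_attn_norm`. A minimal illustration with a hypothetical `Tiny` module:

```python
from flax import nnx


class Tiny(nnx.Module):

  def __init__(self, *, rngs: nnx.Rngs):
    self.post_attention_norm = nnx.RMSNorm(4, rngs=rngs)


state = nnx.state(Tiny(rngs=nnx.Rngs(params=0)))
print(state)  # the 'scale' parameter is keyed by the attribute name 'post_attention_norm'
```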
