From 50f6c9b387d50c3bce84983e505afd940235de5c Mon Sep 17 00:00:00 2001
From: Flax Team
Date: Tue, 4 Mar 2025 04:37:20 -0800
Subject: [PATCH] Add QK Norm.

PiperOrigin-RevId: 733277424
---
 examples/gemma/modules.py     | 11 +++++++++++
 examples/gemma/transformer.py |  2 ++
 2 files changed, 13 insertions(+)

diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index 5eb6fa57..c7572e05 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -85,6 +85,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None:
@@ -100,6 +101,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config
 
     if num_heads == num_kv_heads:
@@ -119,6 +121,9 @@ def __init__(
           shape=(2, num_kv_heads, features, head_dim),
           rngs=rngs,
       )
+    if self.use_qk_norm:
+      self._query_norm = layers.RMSNorm(head_dim, rngs=rngs)
+      self._key_norm = layers.RMSNorm(head_dim, rngs=rngs)
 
   def __call__(
       self,
@@ -135,6 +140,10 @@ def __call__(
       query_proj = self.q_einsum(x)
       key_proj, value_proj = self.kv_einsum(x)
 
+    if self.use_qk_norm:
+      query_proj = self._query_norm(query_proj)
+      key_proj = self._key_norm(key_proj)
+
     query_proj = positional_embeddings.apply_rope(
         query_proj,
         segment_pos,
@@ -300,6 +309,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)
@@ -312,6 +322,7 @@ def __init__(
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
+        use_qk_norm=use_qk_norm,
         sow_config=sow_config,
     )
     if use_post_attn_norm:
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index cdf607c1..e253c408 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -48,6 +48,7 @@ class TransformerConfig:
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
   attn_logits_soft_cap: float | None = None
+  use_qk_norm: bool = False
   sliding_window_size: int | None = None
 
   @classmethod
@@ -248,6 +249,7 @@ def __init__(
             attn_logits_soft_cap=config.attn_logits_soft_cap,
             attn_type=attn_type,
             rngs=rngs,
+            use_qk_norm=config.use_qk_norm,
             sow_config=sow_config,
         )
         for _, attn_type in zip(
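
For reference, below is a minimal plain-JAX sketch of the step this patch inserts into Attention.__call__ when use_qk_norm=True: the query and key projections are RMS-normalized over head_dim before apply_rope runs. The rms_norm helper, the eps value, the (1 + scale) convention, and the tensor shapes here are illustrative assumptions, not the example's actual layers.RMSNorm implementation.

import jax
import jax.numpy as jnp

def rms_norm(x, scale, eps=1e-6):
  # Normalize over the last axis (head_dim) and apply a learned per-feature
  # scale; the zero-initialized (1 + scale) convention is an assumption.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * (1.0 + scale)

# Illustrative shapes: [batch, seq, heads, head_dim] projections.
batch, seq_len, num_heads, num_kv_heads, head_dim = 2, 8, 4, 1, 16
q_key, k_key = jax.random.split(jax.random.PRNGKey(0))
query_proj = jax.random.normal(q_key, (batch, seq_len, num_heads, head_dim))
key_proj = jax.random.normal(k_key, (batch, seq_len, num_kv_heads, head_dim))

# Stand-ins for the parameters an RMSNorm layer over head_dim would hold.
q_scale = jnp.zeros((head_dim,))
k_scale = jnp.zeros((head_dim,))

# With use_qk_norm=True, queries and keys are normalized before RoPE.
query_proj = rms_norm(query_proj, q_scale)
key_proj = rms_norm(key_proj, k_scale)

In the patch itself this step is carried out by the two layers.RMSNorm(head_dim, rngs=rngs) submodules (_query_norm and _key_norm) created in Attention.__init__, and the normalized projections then flow into positional_embeddings.apply_rope exactly as before; the flag is threaded from TransformerConfig.use_qk_norm through Block into Attention.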