Add QK Norm.
PiperOrigin-RevId: 733277424
Flax Team committed Mar 4, 2025
1 parent 45a8f84 commit ebd5365
Showing 2 changed files with 13 additions and 0 deletions.
examples/gemma/modules.py (11 additions, 0 deletions)

@@ -85,6 +85,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None:
@@ -100,6 +101,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config
 
     if num_heads == num_kv_heads:
@@ -119,6 +121,9 @@ def __init__(
           shape=(2, num_kv_heads, features, head_dim),
           rngs=rngs,
       )
+    if self.use_qk_norm:
+      self._query_norm = layers.RMSNorm(head_dim, rngs=rngs)
+      self._key_norm = layers.RMSNorm(head_dim, rngs=rngs)
 
   def __call__(
       self,
@@ -135,6 +140,10 @@ def __call__(
       query_proj = self.q_einsum(x)
       key_proj, value_proj = self.kv_einsum(x)
 
+    if self.use_qk_norm:
+      query_proj = self._query_norm(query_proj)
+      key_proj = self._key_norm(key_proj)
+
     query_proj = positional_embeddings.apply_rope(
         query_proj,
         segment_pos,
@@ -300,6 +309,7 @@ def __init__(
       rngs: nnx.Rngs,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
+      use_qk_norm: bool = False,
       sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
   ):
     self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)
@@ -312,6 +322,7 @@ def __init__(
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
+        use_qk_norm=use_qk_norm,
         sow_config=sow_config,
     )
     if use_post_attn_norm:
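
For context, the new path applies an RMSNorm over the head dimension to the query and key projections before RoPE. Below is a minimal, runnable sketch of that step; the RMSNorm class is a stand-in written for illustration (the real module is layers.RMSNorm in the Gemma example), and the tensor shapes are assumed from the einsum shapes in the hunks above.

# Illustration only: a stand-in for layers.RMSNorm, assumed to normalize over
# the trailing head_dim axis with a learned scale.
import jax
import jax.numpy as jnp
from flax import nnx

class RMSNorm(nnx.Module):
  def __init__(self, dim: int, *, rngs: nnx.Rngs, eps: float = 1e-6):
    del rngs  # kept only for signature parity with layers.RMSNorm
    self.scale = nnx.Param(jnp.zeros((dim,)))
    self.eps = eps

  def __call__(self, x: jax.Array) -> jax.Array:
    var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(var + self.eps) * (1 + self.scale.value)

# Assumed shapes: q_einsum yields (batch, seq, num_heads, head_dim) and
# kv_einsum yields (batch, seq, num_kv_heads, head_dim) per key/value.
batch, seq, num_heads, num_kv_heads, head_dim = 1, 8, 4, 1, 16
query_proj = jax.random.normal(jax.random.key(0), (batch, seq, num_heads, head_dim))
key_proj = jax.random.normal(jax.random.key(1), (batch, seq, num_kv_heads, head_dim))

rngs = nnx.Rngs(params=0)
query_norm = RMSNorm(head_dim, rngs=rngs)
key_norm = RMSNorm(head_dim, rngs=rngs)

# With use_qk_norm=True, Attention.__call__ normalizes q/k here, before
# positional_embeddings.apply_rope is applied to them.
query_proj = query_norm(query_proj)
key_proj = key_norm(key_proj)
print(query_proj.shape, key_proj.shape)  # (1, 8, 4, 16) (1, 8, 1, 16)

Normalizing queries and keys this way is commonly used to keep attention logits in a stable range; note that the order matches the diff, with the norm applied before the rotary embedding.
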
examples/gemma/transformer.py (2 additions, 0 deletions)

@@ -48,6 +48,7 @@ class TransformerConfig:
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
   attn_logits_soft_cap: float | None = None
+  use_qk_norm: bool = False
   sliding_window_size: int | None = None
 
   @classmethod
@@ -248,6 +249,7 @@ def __init__(
             attn_logits_soft_cap=config.attn_logits_soft_cap,
             attn_type=attn_type,
             rngs=rngs,
+            use_qk_norm=config.use_qk_norm,
             sow_config=sow_config,
         )
         for _, attn_type in zip(
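
The flag threads from the config into each block: Transformer passes config.use_qk_norm to every Block, and Block forwards it to Attention. The snippet below uses a hypothetical stand-in dataclass (not the real TransformerConfig, whose full field list is not part of this diff) just to show a caller flipping the new default.

# Hypothetical sketch: only the TransformerConfig fields visible in the hunk
# above are mirrored here; the real dataclass has many more fields.
import dataclasses

@dataclasses.dataclass
class ConfigSketch:
  attn_logits_soft_cap: float | None = None
  use_qk_norm: bool = False  # new flag introduced by this commit
  sliding_window_size: int | None = None

cfg = ConfigSketch(use_qk_norm=True)
# Per the hunk above, Transformer.__init__ then builds each Block with
# use_qk_norm=config.use_qk_norm, which ultimately creates the
# _query_norm/_key_norm layers in Attention.
print(cfg.use_qk_norm)  # True
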
