Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QK Norm. #4594

Merged
merged 1 commit into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions examples/gemma/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
rngs: nnx.Rngs,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
):
if attn_type == AttentionType.LOCAL_SLIDING and sliding_window_size is None:
Expand All @@ -100,6 +101,7 @@ def __init__(
shape=(num_heads, head_dim, features),
rngs=rngs,
)
self.use_qk_norm = use_qk_norm
self.sow_config = sow_config

if num_heads == num_kv_heads:
Expand All @@ -119,6 +121,9 @@ def __init__(
shape=(2, num_kv_heads, features, head_dim),
rngs=rngs,
)
if self.use_qk_norm:
self._query_norm = layers.RMSNorm(head_dim, rngs=rngs)
self._key_norm = layers.RMSNorm(head_dim, rngs=rngs)

def __call__(
self,
Expand All @@ -135,6 +140,10 @@ def __call__(
query_proj = self.q_einsum(x)
key_proj, value_proj = self.kv_einsum(x)

if self.use_qk_norm:
query_proj = self._query_norm(query_proj)
key_proj = self._key_norm(key_proj)

query_proj = positional_embeddings.apply_rope(
query_proj,
segment_pos,
Expand Down Expand Up @@ -300,6 +309,7 @@ def __init__(
rngs: nnx.Rngs,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
):
self.pre_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)
Expand All @@ -312,6 +322,7 @@ def __init__(
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
rngs=rngs,
use_qk_norm=use_qk_norm,
sow_config=sow_config,
)
if use_post_attn_norm:
Expand Down
2 changes: 2 additions & 0 deletions examples/gemma/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class TransformerConfig:
use_post_ffw_norm: bool
attention_types: Iterable[modules.AttentionType]
attn_logits_soft_cap: float | None = None
use_qk_norm: bool = False
sliding_window_size: int | None = None

@classmethod
Expand Down Expand Up @@ -248,6 +249,7 @@ def __init__(
attn_logits_soft_cap=config.attn_logits_soft_cap,
attn_type=attn_type,
rngs=rngs,
use_qk_norm=config.use_qk_norm,
sow_config=sow_config,
)
for _, attn_type in zip(
Expand Down
Loading