Add configurable Query Pre Attention scalar.
PiperOrigin-RevId: 733277426
casaro authored and Flax Authors committed Mar 6, 2025
1 parent 8254dd0 commit dade45c
Showing 3 changed files with 46 additions and 4 deletions.
6 changes: 5 additions & 1 deletion examples/gemma/modules.py
@@ -80,6 +80,7 @@ def __init__(
num_kv_heads: int,
features: int,
head_dim: int,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
@@ -93,6 +94,7 @@ def __init__(
'`sliding_window_size` must be set if `attn_type` is Local Sliding.'
)

self.query_pre_attn_scalar = query_pre_attn_scalar
self.attn_type = attn_type
self.sliding_window_size = sliding_window_size
self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -149,7 +151,7 @@ def __call__(
segment_pos,
head_dim=self.head_dim,
)
query_scaled = query_proj * self.head_dim**-0.5
query_scaled = query_proj * self.query_pre_attn_scalar
key_proj = positional_embeddings.apply_rope(
key_proj,
segment_pos,
@@ -304,6 +306,7 @@ def __init__(
hidden_dim: int,
use_post_attn_norm: bool,
use_post_ffw_norm: bool,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
@@ -318,6 +321,7 @@ def __init__(
num_kv_heads=num_kv_heads,
features=embed_dim,
head_dim=head_dim,
query_pre_attn_scalar=query_pre_attn_scalar,
attn_type=attn_type,
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
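
The core change above swaps the hard-coded `head_dim**-0.5` query scaling for the caller-supplied `query_pre_attn_scalar`. Below is a minimal sketch of that single step in isolation (the shapes and values are illustrative, not the module's real inputs); with the default normalisation option the new code reproduces the old behaviour exactly:

import jax.numpy as jnp

head_dim = 256
# Illustrative query projection of shape [batch, seq, num_heads, head_dim].
query_proj = jnp.ones((1, 4, 8, head_dim))

# Old behaviour: fixed scale baked into the module.
old_scaled = query_proj * head_dim**-0.5

# New behaviour: the scale is whatever the caller passes in; the default
# normalisation option computes the same value, so results are unchanged.
query_pre_attn_scalar = head_dim**-0.5
new_scaled = query_proj * query_pre_attn_scalar

assert jnp.allclose(old_scaled, new_scaled)
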
16 changes: 13 additions & 3 deletions examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
num_kv_heads=4,
features=5,
head_dim=head_dim,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
num_kv_heads=num_kv_heads,
features=5,
head_dim=8,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -144,7 +146,8 @@ def test_attention(
num_heads,
features,
head_dim,
modules.AttentionType.GLOBAL,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
modules.AttentionType.GLOBAL,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
modules.AttentionType.LOCAL_SLIDING,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.LOCAL_SLIDING,
sliding_window_size=sliding_window_size,
rngs=nnx.Rngs(params=0),
)
@@ -272,6 +277,7 @@ def test_block(
1,
use_post_attn_norm,
use_post_ffw_norm,
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -315,6 +321,7 @@ def test_post_attention_norm(
1,
True,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -326,6 +333,7 @@ def test_post_attention_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
1,
True,
True, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
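
The test updates show the new argument threading through the public constructors: every `modules.Attention` and `modules.Block` construction now supplies `query_pre_attn_scalar` (the tests use 1.0, which leaves the query unscaled). A hedged sketch of an updated caller, using only keyword names that appear in the tests above; the import path is an assumption and may differ in your setup:

from flax import nnx

import modules  # i.e. examples/gemma/modules.py; adjust to your import layout

attn = modules.Attention(
    num_heads=4,
    num_kv_heads=4,
    features=5,
    head_dim=8,
    query_pre_attn_scalar=1.0,  # new required argument
    attn_type=modules.AttentionType.GLOBAL,
    rngs=nnx.Rngs(params=0),
)
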
28 changes: 28 additions & 0 deletions examples/gemma/transformer.py
@@ -18,6 +18,7 @@

from collections.abc import Iterable
import dataclasses
import enum
from typing import Any

from flax import nnx
@@ -32,6 +33,19 @@
Cache = dict[str, modules.LayerCache]


class QueryPreAttentionNormalisation(enum.Enum):
"""Initialization strategy."""

# Whether to scale the query by 1/sqrt(head_dim)
BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()

# Whether to scale the query by `embed_dim // num_heads`
BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()

# Whether to scale the query by `1/sqrt(embed_dim // num_heads)`
BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()


@dataclasses.dataclass(frozen=True)
class TransformerConfig:
"""Configuration for the gemma transformer."""
@@ -47,10 +61,23 @@ class TransformerConfig:
use_post_attn_norm: bool
use_post_ffw_norm: bool
attention_types: Iterable[modules.AttentionType]
query_pre_attn_norm: QueryPreAttentionNormalisation = (
QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
)
attn_logits_soft_cap: float | None = None
use_qk_norm: bool = False
sliding_window_size: int | None = None

def query_pre_attn_scalar(self) -> float:
"""Returns the scalar to multiply the query by before attention."""
match self.query_pre_attn_norm:
case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
return self.embed_dim // self.num_heads
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: # pylint: disable=line-too-long
return (self.embed_dim // self.num_heads) ** -0.5
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
return self.head_dim**-0.5

@classmethod
def from_path(cls, path: str) -> TransformerConfig:
"""Creates a TransformerConfig from loaded parameters."""
@@ -248,6 +275,7 @@ def __init__(
use_post_ffw_norm=config.use_post_ffw_norm,
attn_logits_soft_cap=config.attn_logits_soft_cap,
attn_type=attn_type,
query_pre_attn_scalar=config.query_pre_attn_scalar(),
rngs=rngs,
use_qk_norm=config.use_qk_norm,
sow_config=sow_config,
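
`TransformerConfig.query_pre_attn_scalar()` turns the chosen `QueryPreAttentionNormalisation` option into the multiplier handed to each attention layer. A worked example of the three options with illustrative dimensions (embed_dim=2048, num_heads=8, head_dim=256 are placeholders, not values from any particular checkpoint):

embed_dim, num_heads, head_dim = 2048, 8, 256

# BY_ONE_OVER_SQRT_HEAD_DIM (the default): 1/sqrt(head_dim)
print(head_dim**-0.5)                    # 0.0625

# BY_EMBED_DIM_DIV_NUM_HEADS: embed_dim // num_heads
print(embed_dim // num_heads)            # 256

# BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: 1/sqrt(embed_dim // num_heads)
print((embed_dim // num_heads) ** -0.5)  # 0.0625

Because the default is BY_ONE_OVER_SQRT_HEAD_DIM, configs that do not set `query_pre_attn_norm` keep the previous hard-coded `head_dim**-0.5` scaling.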
