Make RoPE Base Frequency configurable.
PiperOrigin-RevId: 733277425
casaro authored and Flax Authors committed Mar 5, 2025
1 parent 0769411 commit 3059d70
Showing 3 changed files with 58 additions and 4 deletions.
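At a glance, the commit threads a new `rope_base_frequency` argument through `modules.Attention` and `modules.Block`, exposes it on `TransformerConfig` as separate `local_base_frequency` and `global_base_frequency` fields, and replaces the hard-coded `head_dim**-0.5` query scaling with a configurable `QueryPreAttentionNormalisation`. Since `TransformerConfig` is a frozen dataclass, one way a caller could opt into the new fields is with `dataclasses.replace`. The sketch below is illustrative only: `base_config` stands in for any fully populated config, and the 1,000,000 global base is a made-up value, not something this commit sets.

# Illustrative sketch, not part of the commit. `base_config` is a hypothetical,
# fully populated TransformerConfig; the bare import assumes examples/gemma is
# on sys.path.
import dataclasses

import transformer

config = dataclasses.replace(
    base_config,
    local_base_frequency=10_000,      # RoPE base for LOCAL_SLIDING attention layers
    global_base_frequency=1_000_000,  # RoPE base for GLOBAL attention layers (made-up value)
    query_pre_attn_norm=transformer.QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM,
)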
13 changes: 12 additions & 1 deletion examples/gemma/modules.py
@@ -33,6 +33,7 @@
Shape = Sequence[Union[int, Any]]

K_MASK = -2.3819763e38 # Set to a large negative number.
DEFAULT_ROPE_BASE_FREQUENCY = 10_000


class AttentionType(enum.Enum):
@@ -80,9 +81,11 @@ def __init__(
num_kv_heads: int,
features: int,
head_dim: int,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
@@ -93,6 +96,7 @@ def __init__(
'`sliding_window_size` must be set if `attn_type` is Local Sliding.'
)

self.query_pre_attn_scalar = query_pre_attn_scalar
self.attn_type = attn_type
self.sliding_window_size = sliding_window_size
self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -101,6 +105,7 @@ def __init__(
shape=(num_heads, head_dim, features),
rngs=rngs,
)
self.rope_base_frequency = rope_base_frequency
self.use_qk_norm = use_qk_norm
self.sow_config = sow_config

@@ -148,12 +153,14 @@ def __call__(
query_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)
query_scaled = query_proj * self.head_dim**-0.5
query_scaled = query_proj * self.query_pre_attn_scalar
key_proj = positional_embeddings.apply_rope(
key_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)

# Cache is left aligned.
@@ -304,9 +311,11 @@ def __init__(
hidden_dim: int,
use_post_attn_norm: bool,
use_post_ffw_norm: bool,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
use_qk_norm: bool = False,
@@ -318,7 +327,9 @@ def __init__(
num_kv_heads=num_kv_heads,
features=embed_dim,
head_dim=head_dim,
query_pre_attn_scalar=query_pre_attn_scalar,
attn_type=attn_type,
rope_base_frequency=rope_base_frequency,
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
rngs=rngs,
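The `max_wavelength` value now forwarded to `positional_embeddings.apply_rope` is the RoPE base frequency. As a reminder of what that knob controls, here is a minimal sketch of the conventional rotary embedding (this is not the repository's `apply_rope`, whose body is outside this diff): the base sets the geometric spread of the rotation timescales, so a larger base gives the slowest-rotating dimension pairs longer wavelengths.

# Minimal RoPE sketch under the conventional formulation; illustrative only.
import jax.numpy as jnp


def rope_sketch(x, positions, *, base_frequency=10_000):
  """x: [..., seq_len, head_dim]; positions: [..., seq_len] token positions."""
  head_dim = x.shape[-1]
  # Timescales grow geometrically from 1 towards base_frequency across the
  # head_dim // 2 dimension pairs, so base_frequency sets the longest wavelength.
  fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
  timescale = base_frequency ** fraction
  angle = positions[..., None] / timescale          # [..., seq_len, head_dim // 2]
  sin, cos = jnp.sin(angle), jnp.cos(angle)
  first, second = jnp.split(x, 2, axis=-1)
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1)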
16 changes: 13 additions & 3 deletions examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
num_kv_heads=4,
features=5,
head_dim=head_dim,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
num_kv_heads=num_kv_heads,
features=5,
head_dim=8,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -144,7 +146,8 @@ def test_attention(
num_heads,
features,
head_dim,
modules.AttentionType.GLOBAL,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
modules.AttentionType.GLOBAL,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
modules.AttentionType.LOCAL_SLIDING,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.LOCAL_SLIDING,
sliding_window_size=sliding_window_size,
rngs=nnx.Rngs(params=0),
)
@@ -272,6 +277,7 @@ def test_block(
1,
use_post_attn_norm,
use_post_ffw_norm,
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -315,6 +321,7 @@ def test_post_attention_norm(
1,
True,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -326,6 +333,7 @@ def test_post_attention_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
1,
True,
True, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
1,
False,
False, # use_post_ffw_norm
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
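The test edits above are mechanical: `query_pre_attn_scalar` now sits before `attn_type` in the `Attention` signature, so the call sites pass both by keyword. A construction sketch with made-up sizes (the new keyword-only `rope_base_frequency` is optional and defaults to `DEFAULT_ROPE_BASE_FREQUENCY`):

# Construction sketch mirroring the updated test call sites; sizes are made up
# and the bare `modules` import assumes examples/gemma is on sys.path.
from flax import nnx

import modules

attn = modules.Attention(
    num_heads=2,
    num_kv_heads=2,
    features=5,
    head_dim=8,
    query_pre_attn_scalar=8 ** -0.5,  # e.g. 1/sqrt(head_dim)
    attn_type=modules.AttentionType.GLOBAL,
    rngs=nnx.Rngs(params=0),
    rope_base_frequency=1_000_000,    # optional; defaults to 10_000
)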
33 changes: 33 additions & 0 deletions examples/gemma/transformer.py
@@ -18,6 +18,7 @@

from collections.abc import Iterable
import dataclasses
import enum
from typing import Any

from flax import nnx
@@ -32,6 +33,19 @@
Cache = dict[str, modules.LayerCache]


class QueryPreAttentionNormalisation(enum.Enum):
"""Initialization strategy."""

# Whether to scale the query by 1/sqrt(head_dim)
BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()

# Whether to scale the query by `embed_dim // num_heads`
BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()

# Whether to scale the query by `1/sqrt(embed_dim // num_heads)`
BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()


@dataclasses.dataclass(frozen=True)
class TransformerConfig:
"""Configuration for the gemma transformer."""
@@ -47,10 +61,25 @@ class TransformerConfig:
use_post_attn_norm: bool
use_post_ffw_norm: bool
attention_types: Iterable[modules.AttentionType]
query_pre_attn_norm: QueryPreAttentionNormalisation = (
QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
)
attn_logits_soft_cap: float | None = None
local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
use_qk_norm: bool = False
sliding_window_size: int | None = None

def query_pre_attn_scalar(self) -> float:
"""Returns the scalar to multiply the query by before attention."""
match self.query_pre_attn_norm:
case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
return self.embed_dim // self.num_heads
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS: # pylint: disable=line-too-long
return (self.embed_dim // self.num_heads) ** -0.5
case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
return self.head_dim**-0.5

@classmethod
def from_path(cls, path: str) -> TransformerConfig:
"""Creates a TransformerConfig from loaded parameters."""
@@ -248,7 +277,11 @@ def __init__(
use_post_ffw_norm=config.use_post_ffw_norm,
attn_logits_soft_cap=config.attn_logits_soft_cap,
attn_type=attn_type,
query_pre_attn_scalar=config.query_pre_attn_scalar(),
rngs=rngs,
rope_base_frequency=config.local_base_frequency
if attn_type == modules.AttentionType.LOCAL_SLIDING
else config.global_base_frequency,
use_qk_norm=config.use_qk_norm,
sow_config=sow_config,
)
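To make the new `QueryPreAttentionNormalisation` options concrete, here is a small worked example of `query_pre_attn_scalar()` with made-up dimensions (not taken from any real Gemma config):

# Worked example; embed_dim, num_heads and head_dim are made-up values.
embed_dim, num_heads, head_dim = 2048, 8, 256

by_one_over_sqrt_head_dim = head_dim ** -0.5                                  # 0.0625 (default)
by_embed_dim_div_num_heads = embed_dim // num_heads                           # 256
by_one_over_sqrt_embed_dim_div_num_heads = (embed_dim // num_heads) ** -0.5   # 0.0625

# Whichever scalar is selected multiplies the query projection before the
# attention dot-product, replacing the previously hard-coded head_dim ** -0.5.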
