Make RoPE Base Frequency configurable.
PiperOrigin-RevId: 734217570
casaro authored and Flax Authors committed Mar 6, 2025
1 parent fd32d45 · commit 324bbe0
Showing 2 changed files with 12 additions and 0 deletions.
examples/gemma/modules.py (7 additions, 0 deletions)

@@ -33,6 +33,7 @@
 Shape = Sequence[Union[int, Any]]
 
 K_MASK = -2.3819763e38  # Set to a large negative number.
+DEFAULT_ROPE_BASE_FREQUENCY = 10_000
 
 
 class AttentionType(enum.Enum):
@@ -84,6 +85,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -103,6 +105,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.rope_base_frequency = rope_base_frequency
     self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config

@@ -150,12 +153,14 @@ def __call__(
         query_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
     query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
 
     # Cache is left aligned.
@@ -310,6 +315,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -323,6 +329,7 @@ def __init__(
         head_dim=head_dim,
         query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
+        rope_base_frequency=rope_base_frequency,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
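The new max_wavelength arguments are how the base frequency reaches positional_embeddings.apply_rope; that module itself is not touched by this commit. For reference, a minimal sketch of the standard rotary-embedding computation, assuming the conventional RoPE formulation and the argument names seen at the call sites above — this is an illustrative sketch, not the repository's implementation:

import jax.numpy as jnp


def apply_rope(inputs, positions, *, head_dim, max_wavelength=10_000):
  # inputs: [batch, seq, num_heads, head_dim]; positions: [batch, seq].
  # Channel pair i rotates with timescale max_wavelength**(2i/head_dim),
  # so a larger base frequency slows the low-frequency channels and keeps
  # positional phases distinguishable over longer contexts.
  fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
  timescale = max_wavelength**fraction                    # [head_dim // 2]
  sinusoid_inp = positions[..., jnp.newaxis] / timescale  # [batch, seq, head_dim // 2]
  sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]        # broadcast over heads
  sin, cos = jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp)
  first_half, second_half = jnp.split(inputs, 2, axis=-1)
  first_part = first_half * cos - second_half * sin
  second_part = second_half * cos + first_half * sin
  return jnp.concatenate([first_part, second_part], axis=-1).astype(inputs.dtype)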
examples/gemma/transformer.py (5 additions, 0 deletions)

@@ -65,6 +65,8 @@ class TransformerConfig:
       QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
   )
   attn_logits_soft_cap: float | None = None
+  local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
+  global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
   use_qk_norm: bool = False
   sliding_window_size: int | None = None

@@ -277,6 +279,9 @@ def __init__(
           attn_type=attn_type,
           query_pre_attn_scalar=config.query_pre_attn_scalar(),
           rngs=rngs,
+          rope_base_frequency=config.local_base_frequency
+          if attn_type == modules.AttentionType.LOCAL_SLIDING
+          else config.global_base_frequency,
           use_qk_norm=config.use_qk_norm,
           sow_config=sow_config,
       )
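Each Block thus receives config.local_base_frequency when its attention type is LOCAL_SLIDING and config.global_base_frequency otherwise. A hypothetical usage sketch, assuming TransformerConfig is a plain dataclass (as the field syntax suggests) and that a config instance already exists; the values are illustrative, not defaults introduced by this commit:

import dataclasses

# Override only the two new fields on an existing config. A larger global
# base is a common long-context recipe, shown here purely as an example.
config = dataclasses.replace(
    config,
    local_base_frequency=10_000,      # RoPE base for LOCAL_SLIDING layers
    global_base_frequency=1_000_000,  # RoPE base for global-attention layers
)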
