
Commit 324bbe0

casaro authored and Flax Authors committed
Make RoPE Base Frequency configurable.
PiperOrigin-RevId: 734217570
1 parent fd32d45 commit 324bbe0
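
For context on what the new knob controls: in rotary position embeddings, the base frequency (wired through below as max_wavelength) sets the geometric spacing of the per-dimension rotation timescales, so a larger base makes the slow-rotating dimension pairs rotate even more slowly and keeps positions distinguishable over longer sequences. The sketch below is illustrative only: the function name and exact broadcasting are assumptions, not the example's positional_embeddings.apply_rope verbatim.

import jax.numpy as jnp

def apply_rope_sketch(x, positions, *, head_dim, max_wavelength=10_000):
  # x: [batch, seq, num_heads, head_dim]; positions: [batch, seq].
  # max_wavelength is the RoPE base frequency this commit makes configurable.
  fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
  timescale = max_wavelength**fraction            # [head_dim // 2]
  angles = positions[..., None] / timescale       # [batch, seq, head_dim // 2]
  angles = angles[..., None, :]                   # broadcast over heads
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  first, second = jnp.split(x, 2, axis=-1)
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1
  ).astype(x.dtype)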

File tree

examples/gemma/modules.py
examples/gemma/transformer.py

2 files changed: +12 −0 lines changed

examples/gemma/modules.py

Lines changed: 7 additions & 0 deletions
@@ -33,6 +33,7 @@
 Shape = Sequence[Union[int, Any]]

 K_MASK = -2.3819763e38  # Set to a large negative number.
+DEFAULT_ROPE_BASE_FREQUENCY = 10_000


 class AttentionType(enum.Enum):
@@ -84,6 +85,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -103,6 +105,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.rope_base_frequency = rope_base_frequency
     self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config

@@ -150,12 +153,14 @@ def __call__(
         query_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
     query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )

     # Cache is left aligned.
@@ -310,6 +315,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -323,6 +329,7 @@ def __init__(
         head_dim=head_dim,
         query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
+        rope_base_frequency=rope_base_frequency,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,

examples/gemma/transformer.py

Lines changed: 5 additions & 0 deletions
@@ -65,6 +65,8 @@ class TransformerConfig:
       QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
   )
   attn_logits_soft_cap: float | None = None
+  local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
+  global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
   use_qk_norm: bool = False
   sliding_window_size: int | None = None

@@ -277,6 +279,9 @@ def __init__(
           attn_type=attn_type,
           query_pre_attn_scalar=config.query_pre_attn_scalar(),
           rngs=rngs,
+          rope_base_frequency=config.local_base_frequency
+          if attn_type == modules.AttentionType.LOCAL_SLIDING
+          else config.global_base_frequency,
           use_qk_norm=config.use_qk_norm,
           sow_config=sow_config,
       )
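
With the two config fields in place, a model can use a different base for local sliding-window layers than for global layers. A minimal sketch of overriding them on an existing config, assuming TransformerConfig is a dataclass (as the field-style definition above suggests) and that you already have an instance to start from; the 1_000_000 value is illustrative:

import dataclasses

def with_rope_bases(base_config, *, local=10_000, global_=1_000_000):
  # Returns a copy of an existing TransformerConfig with only the two new
  # RoPE base-frequency fields overridden; all other fields are untouched.
  return dataclasses.replace(
      base_config,
      local_base_frequency=local,     # picked up by LOCAL_SLIDING attention layers
      global_base_frequency=global_,  # picked up by the remaining (global) layers
  )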
