
Commit 3059d70

casaro authored and Flax Authors committed
Make RoPE Base Frequency configurable.
PiperOrigin-RevId: 733277425
1 parent 0769411 · commit 3059d70

File tree

3 files changed: +58 additions, -4 deletions

  examples/gemma/modules.py
  examples/gemma/modules_test.py
  examples/gemma/transformer.py

examples/gemma/modules.py

Lines changed: 12 additions & 1 deletion

@@ -33,6 +33,7 @@
 Shape = Sequence[Union[int, Any]]

 K_MASK = -2.3819763e38  # Set to a large negative number.
+DEFAULT_ROPE_BASE_FREQUENCY = 10_000


 class AttentionType(enum.Enum):
@@ -80,9 +81,11 @@ def __init__(
       num_kv_heads: int,
       features: int,
       head_dim: int,
+      query_pre_attn_scalar: float,
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -93,6 +96,7 @@ def __init__(
           '`sliding_window_size` must be set if `attn_type` is Local Sliding.'
       )

+    self.query_pre_attn_scalar = query_pre_attn_scalar
     self.attn_type = attn_type
     self.sliding_window_size = sliding_window_size
     self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -101,6 +105,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.rope_base_frequency = rope_base_frequency
     self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config

@@ -148,12 +153,14 @@ def __call__(
         query_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
-    query_scaled = query_proj * self.head_dim**-0.5
+    query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )

     # Cache is left aligned.
@@ -304,9 +311,11 @@ def __init__(
       hidden_dim: int,
       use_post_attn_norm: bool,
       use_post_ffw_norm: bool,
+      query_pre_attn_scalar: float,
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -318,7 +327,9 @@ def __init__(
         num_kv_heads=num_kv_heads,
         features=embed_dim,
         head_dim=head_dim,
+        query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
+        rope_base_frequency=rope_base_frequency,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
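Note: the new `rope_base_frequency` is forwarded to `positional_embeddings.apply_rope` as `max_wavelength`, i.e. the base of the rotary position embedding. A minimal sketch of how that base enters a generic RoPE computation (illustrative only, not the library's exact `apply_rope`; the function name and shapes are assumptions):

import jax.numpy as jnp

def rope_sketch(x, positions, *, head_dim, max_wavelength=10_000):
  # x: [batch, seq, num_heads, head_dim]; positions: [batch, seq].
  # Channel pair i rotates by angle position / max_wavelength**(2 * i / head_dim),
  # so a larger base frequency rotates more slowly (longer wavelengths), the
  # usual lever for stretching positional resolution to longer contexts.
  fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
  timescale = max_wavelength**fraction                # [head_dim // 2]
  angles = positions[:, :, None, None] / timescale    # [batch, seq, 1, head_dim // 2]
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  first, second = jnp.split(x, 2, axis=-1)
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1
  ).astype(x.dtype)

In the updated `__call__`, both the query and key projections receive `max_wavelength=self.rope_base_frequency`, while the query scaling switches from the hard-coded `head_dim**-0.5` to the new `query_pre_attn_scalar` argument.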

examples/gemma/modules_test.py

Lines changed: 13 additions & 3 deletions

@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
        num_kv_heads=4,
        features=5,
        head_dim=head_dim,
+        query_pre_attn_scalar=1.0,
        attn_type=modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
        num_kv_heads=num_kv_heads,
        features=5,
        head_dim=8,
+        query_pre_attn_scalar=1.0,
        attn_type=modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
@@ -144,7 +146,8 @@ def test_attention(
        num_heads,
        features,
        head_dim,
-        modules.AttentionType.GLOBAL,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
    cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
        num_heads,
        features,
        head_dim,
-        modules.AttentionType.GLOBAL,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
    cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
        num_heads,
        features,
        head_dim,
-        modules.AttentionType.LOCAL_SLIDING,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.LOCAL_SLIDING,
        sliding_window_size=sliding_window_size,
        rngs=nnx.Rngs(params=0),
    )
@@ -272,6 +277,7 @@ def test_block(
        1,
        use_post_attn_norm,
        use_post_ffw_norm,
+        1.0,
        modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
@@ -315,6 +321,7 @@ def test_post_attention_norm(
        1,
        True,
        False,  # use_post_ffw_norm
+        1.0,
        modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
@@ -326,6 +333,7 @@ def test_post_attention_norm(
        1,
        False,
        False,  # use_post_ffw_norm
+        1.0,
        modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
        1,
        True,
        True,  # use_post_ffw_norm
+        1.0,
        modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
        1,
        False,
        False,  # use_post_ffw_norm
+        1.0,
        modules.AttentionType.GLOBAL,
        rngs=nnx.Rngs(params=0),
    )

examples/gemma/transformer.py

Lines changed: 33 additions & 0 deletions

@@ -18,6 +18,7 @@

 from collections.abc import Iterable
 import dataclasses
+import enum
 from typing import Any

 from flax import nnx
@@ -32,6 +33,19 @@
 Cache = dict[str, modules.LayerCache]


+class QueryPreAttentionNormalisation(enum.Enum):
+  """Initialization strategy."""
+
+  # Whether to scale the query by 1/sqrt(head_dim)
+  BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()
+
+  # Whether to scale the query by `embed_dim // num_heads`
+  BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
+
+  # Whether to scale the query by `1/sqrt(embed_dim // num_heads)`
+  BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
+
+
 @dataclasses.dataclass(frozen=True)
 class TransformerConfig:
   """Configuration for the gemma transformer."""
@@ -47,10 +61,25 @@ class TransformerConfig:
   use_post_attn_norm: bool
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
+  query_pre_attn_norm: QueryPreAttentionNormalisation = (
+      QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
+  )
   attn_logits_soft_cap: float | None = None
+  local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
+  global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
   use_qk_norm: bool = False
   sliding_window_size: int | None = None

+  def query_pre_attn_scalar(self) -> float:
+    """Returns the scalar to multiply the query by before attention."""
+    match self.query_pre_attn_norm:
+      case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
+        return self.embed_dim // self.num_heads
+      case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS:  # pylint: disable=line-too-long
+        return (self.embed_dim // self.num_heads) ** -0.5
+      case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
+        return self.head_dim**-0.5
+
   @classmethod
   def from_path(cls, path: str) -> TransformerConfig:
     """Creates a TransformerConfig from loaded parameters."""
@@ -248,7 +277,11 @@ def __init__(
            use_post_ffw_norm=config.use_post_ffw_norm,
            attn_logits_soft_cap=config.attn_logits_soft_cap,
            attn_type=attn_type,
+            query_pre_attn_scalar=config.query_pre_attn_scalar(),
            rngs=rngs,
+            rope_base_frequency=config.local_base_frequency
+            if attn_type == modules.AttentionType.LOCAL_SLIDING
+            else config.global_base_frequency,
            use_qk_norm=config.use_qk_norm,
            sow_config=sow_config,
        )
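For concreteness, here is how the new TransformerConfig knobs resolve per layer, using hypothetical Gemma-like dimensions (embed_dim=2048, num_heads=8, head_dim=256); illustrative arithmetic only, mirroring `query_pre_attn_scalar()` and the per-layer base selection above:

embed_dim, num_heads, head_dim = 2048, 8, 256  # hypothetical dimensions

# QueryPreAttentionNormalisation -> scalar multiplied into the query:
print(head_dim**-0.5)                    # BY_ONE_OVER_SQRT_HEAD_DIM (default) -> 0.0625
print(embed_dim // num_heads)            # BY_EMBED_DIM_DIV_NUM_HEADS -> 256
print((embed_dim // num_heads) ** -0.5)  # BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS -> 0.0625

# RoPE base per layer: LOCAL_SLIDING layers get config.local_base_frequency,
# every other attention type gets config.global_base_frequency
# (both default to modules.DEFAULT_ROPE_BASE_FREQUENCY, i.e. 10_000).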
