@@ -33,6 +33,7 @@
 Shape = Sequence[Union[int, Any]]

 K_MASK = -2.3819763e38  # Set to a large negative number.
+DEFAULT_ROPE_BASE_FREQUENCY = 10_000


 class AttentionType(enum.Enum):
@@ -84,6 +85,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -103,6 +105,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.rope_base_frequency = rope_base_frequency
     self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config

@@ -150,12 +153,14 @@ def __call__(
         query_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
     query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )

     # Cache is left aligned.
@@ -310,6 +315,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -323,6 +329,7 @@ def __init__(
         head_dim=head_dim,
         query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
+        rope_base_frequency=rope_base_frequency,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
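
For context on the knob this commit adds: `rope_base_frequency` is forwarded to `positional_embeddings.apply_rope` as `max_wavelength`, the base of the geometric progression of rotation timescales. Below is a minimal sketch of the standard RoPE rotation, assuming the usual sin/cos pairing over the last axis; the repo's actual `apply_rope` may differ in detail.

```python
import jax.numpy as jnp

def apply_rope_sketch(x, segment_pos, head_dim, max_wavelength=10_000):
  # x: [batch, seq, num_heads, head_dim]; segment_pos: [batch, seq].
  # Timescales run from 1 up to max_wavelength, so a larger base
  # slows the rotation of the low-frequency dimensions.
  fraction = 2 * jnp.arange(head_dim // 2) / head_dim
  timescale = max_wavelength**fraction
  angles = segment_pos[..., None] / timescale   # [batch, seq, head_dim // 2]
  sin = jnp.sin(angles)[..., None, :]           # broadcast over heads
  cos = jnp.cos(angles)[..., None, :]
  first, second = jnp.split(x, 2, axis=-1)
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1
  ).astype(x.dtype)
```

Raising `max_wavelength` stretches the longest rotation wavelength, which is the usual lever for long-context model variants.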
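A hypothetical construction showing the new argument in use; parameter names are taken from the diff where visible, while the concrete values and `AttentionType.GLOBAL` are illustrative guesses, not from the diff.

```python
from flax import nnx

attn = Attention(
    num_heads=8,
    features=2048,
    head_dim=256,
    query_pre_attn_scalar=256**-0.5,
    attn_type=AttentionType.GLOBAL,  # assumed enum member, not shown in the diff
    rngs=nnx.Rngs(params=0),
    rope_base_frequency=1_000_000,   # overrides DEFAULT_ROPE_BASE_FREQUENCY (10_000)
)
```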