diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index eb7ff216..be152eda 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -33,6 +33,7 @@
 Shape = Sequence[Union[int, Any]]
 
 K_MASK = -2.3819763e38  # Set to a large negative number.
+DEFAULT_ROPE_BASE_FREQUENCY = 10_000
 
 
 class AttentionType(enum.Enum):
@@ -84,6 +85,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -103,6 +105,7 @@ def __init__(
         shape=(num_heads, head_dim, features),
         rngs=rngs,
     )
+    self.rope_base_frequency = rope_base_frequency
     self.use_qk_norm = use_qk_norm
     self.sow_config = sow_config
 
@@ -150,12 +153,14 @@ def __call__(
         query_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
     query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
         head_dim=self.head_dim,
+        max_wavelength=self.rope_base_frequency,
     )
 
     # Cache is left aligned.
@@ -310,6 +315,7 @@ def __init__(
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
+      rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
       attn_logits_soft_cap: float | None = None,
       sliding_window_size: int | None = None,
       use_qk_norm: bool = False,
@@ -323,6 +329,7 @@ def __init__(
         head_dim=head_dim,
         query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
+        rope_base_frequency=rope_base_frequency,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
         rngs=rngs,
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index 8f999980..98b1c8e9 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -65,6 +65,8 @@ class TransformerConfig:
       QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
   )
   attn_logits_soft_cap: float | None = None
+  local_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
+  global_base_frequency: int = modules.DEFAULT_ROPE_BASE_FREQUENCY
   use_qk_norm: bool = False
   sliding_window_size: int | None = None
 
@@ -277,6 +279,9 @@ def __init__(
             attn_type=attn_type,
             query_pre_attn_scalar=config.query_pre_attn_scalar(),
             rngs=rngs,
+            rope_base_frequency=config.local_base_frequency
+            if attn_type == modules.AttentionType.LOCAL_SLIDING
+            else config.global_base_frequency,
             use_qk_norm=config.use_qk_norm,
             sow_config=sow_config,
         )
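
Note on the new knobs: the `max_wavelength` argument that `rope_base_frequency` feeds into `positional_embeddings.apply_rope` is the RoPE base frequency. The sketch below is illustrative only and does not call the gemma modules or this change's API; it recomputes the standard RoPE per-dimension wavelengths to show what raising the base frequency does, which is the usual reason a config would keep `local_base_frequency` at the 10_000 default for LOCAL_SLIDING layers while setting a larger `global_base_frequency` for global layers. The 1_000_000 value used here is an assumed example, not taken from this diff.

    # Illustrative sketch only; not part of this change and not the gemma API.
    # Standard RoPE: dimension pair i rotates with wavelength
    # base_frequency ** (2 * i / head_dim), so a larger base frequency
    # stretches the longest wavelengths used by the rotary embedding.
    import jax.numpy as jnp


    def rope_wavelengths(head_dim: int, base_frequency: float) -> jnp.ndarray:
      fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
      return base_frequency ** fraction


    # Longest wavelength grows with the base frequency (1_000_000 is an
    # assumed example value for a long-context global-attention layer).
    print(rope_wavelengths(head_dim=8, base_frequency=10_000)[-1])     # ~1.0e3
    print(rope_wavelengths(head_dim=8, base_frequency=1_000_000)[-1])  # ~3.2e4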