From 2e3d283df38fccd815540457511e3fea6194bdda Mon Sep 17 00:00:00 2001
From: Sascha Rothe
Date: Tue, 4 Mar 2025 04:37:20 -0800
Subject: [PATCH] Add configurable Query Pre Attention scalar.

PiperOrigin-RevId: 733277426
---
 examples/gemma/modules.py      |  6 +++++-
 examples/gemma/modules_test.py | 16 +++++++++++++---
 examples/gemma/transformer.py  | 28 ++++++++++++++++++++++++++++
 3 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/examples/gemma/modules.py b/examples/gemma/modules.py
index c7572e055..eb7ff2166 100644
--- a/examples/gemma/modules.py
+++ b/examples/gemma/modules.py
@@ -80,6 +80,7 @@ def __init__(
       num_kv_heads: int,
       features: int,
       head_dim: int,
+      query_pre_attn_scalar: float,
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
@@ -93,6 +94,7 @@ def __init__(
           '`sliding_window_size` must be set if `attn_type` is Local Sliding.'
       )
 
+    self.query_pre_attn_scalar = query_pre_attn_scalar
     self.attn_type = attn_type
     self.sliding_window_size = sliding_window_size
     self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -149,7 +151,7 @@ def __call__(
         segment_pos,
         head_dim=self.head_dim,
     )
-    query_scaled = query_proj * self.head_dim**-0.5
+    query_scaled = query_proj * self.query_pre_attn_scalar
     key_proj = positional_embeddings.apply_rope(
         key_proj,
         segment_pos,
@@ -304,6 +306,7 @@ def __init__(
       hidden_dim: int,
       use_post_attn_norm: bool,
       use_post_ffw_norm: bool,
+      query_pre_attn_scalar: float,
       attn_type: AttentionType,
       *,
       rngs: nnx.Rngs,
@@ -318,6 +321,7 @@ def __init__(
         num_kv_heads=num_kv_heads,
         features=embed_dim,
         head_dim=head_dim,
+        query_pre_attn_scalar=query_pre_attn_scalar,
         attn_type=attn_type,
         attn_logits_soft_cap=attn_logits_soft_cap,
         sliding_window_size=sliding_window_size,
diff --git a/examples/gemma/modules_test.py b/examples/gemma/modules_test.py
index 7439585cc..0f94a5e4f 100644
--- a/examples/gemma/modules_test.py
+++ b/examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
         num_kv_heads=4,
         features=5,
         head_dim=head_dim,
+        query_pre_attn_scalar=1.0,
         attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
         num_kv_heads=num_kv_heads,
         features=5,
         head_dim=8,
+        query_pre_attn_scalar=1.0,
         attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -144,7 +146,8 @@ def test_attention(
         num_heads,
         features,
         head_dim,
-        modules.AttentionType.GLOBAL,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
     cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
         num_heads,
         features,
         head_dim,
-        modules.AttentionType.GLOBAL,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
     cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
         num_heads,
         features,
         head_dim,
-        modules.AttentionType.LOCAL_SLIDING,
+        query_pre_attn_scalar=1.0,
+        attn_type=modules.AttentionType.LOCAL_SLIDING,
         sliding_window_size=sliding_window_size,
         rngs=nnx.Rngs(params=0),
     )
@@ -272,6 +277,7 @@ def test_block(
         1,
         use_post_attn_norm,
         use_post_ffw_norm,
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -315,6 +321,7 @@ def test_post_attention_norm(
         1,
         True,
         False,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -326,6 +333,7 @@ def test_post_attention_norm(
         1,
         False,
         False,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -373,6 +381,7 @@ def test_post_ffw_norm(
         1,
         True,
         True,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
@@ -384,6 +393,7 @@ def test_post_ffw_norm(
         1,
         False,
         False,  # use_post_ffw_norm
+        1.0,
         modules.AttentionType.GLOBAL,
         rngs=nnx.Rngs(params=0),
     )
diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py
index e253c4082..8f9999807 100644
--- a/examples/gemma/transformer.py
+++ b/examples/gemma/transformer.py
@@ -18,6 +18,7 @@
 
 from collections.abc import Iterable
 import dataclasses
+import enum
 from typing import Any
 
 from flax import nnx
@@ -32,6 +33,19 @@
 Cache = dict[str, modules.LayerCache]
 
 
+class QueryPreAttentionNormalisation(enum.Enum):
+  """Initialization strategy."""
+
+  # Whether to scale the query by 1/sqrt(head_dim)
+  BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()
+
+  # Whether to scale the query by `embed_dim // num_heads`
+  BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
+
+  # Whether to scale the query by `1/sqrt(embed_dim // num_heads)`
+  BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
+
+
 @dataclasses.dataclass(frozen=True)
 class TransformerConfig:
   """Configuration for the gemma transformer."""
@@ -47,10 +61,23 @@ class TransformerConfig:
   use_post_attn_norm: bool
   use_post_ffw_norm: bool
   attention_types: Iterable[modules.AttentionType]
+  query_pre_attn_norm: QueryPreAttentionNormalisation = (
+      QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM
+  )
   attn_logits_soft_cap: float | None = None
   use_qk_norm: bool = False
   sliding_window_size: int | None = None
 
+  def query_pre_attn_scalar(self) -> float:
+    """Returns the scalar to multiply the query by before attention."""
+    match self.query_pre_attn_norm:
+      case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
+        return self.embed_dim // self.num_heads
+      case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS:  # pylint: disable=line-too-long
+        return (self.embed_dim // self.num_heads) ** -0.5
+      case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
+        return self.head_dim**-0.5
+
   @classmethod
   def from_path(cls, path: str) -> TransformerConfig:
     """Creates a TransformerConfig from loaded parameters."""
@@ -248,6 +275,7 @@ def __init__(
             use_post_ffw_norm=config.use_post_ffw_norm,
             attn_logits_soft_cap=config.attn_logits_soft_cap,
             attn_type=attn_type,
+            query_pre_attn_scalar=config.query_pre_attn_scalar(),
             rngs=rngs,
             use_qk_norm=config.use_qk_norm,
             sow_config=sow_config,
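
For context, a minimal standalone sketch of the arithmetic behind the new TransformerConfig.query_pre_attn_scalar(): it mirrors the match dispatch added above and shows what multiplier each QueryPreAttentionNormalisation option produces. The dimension values below are hypothetical, chosen only to make the three options distinguishable; they are not taken from any particular Gemma configuration.

import enum


class QueryPreAttentionNormalisation(enum.Enum):
  """Mirror of the enum added to examples/gemma/transformer.py."""

  BY_ONE_OVER_SQRT_HEAD_DIM = enum.auto()
  BY_EMBED_DIM_DIV_NUM_HEADS = enum.auto()
  BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS = enum.auto()


def query_pre_attn_scalar(
    norm: QueryPreAttentionNormalisation,
    embed_dim: int,
    num_heads: int,
    head_dim: int,
) -> float:
  """Same dispatch as TransformerConfig.query_pre_attn_scalar() in the patch."""
  match norm:
    case QueryPreAttentionNormalisation.BY_EMBED_DIM_DIV_NUM_HEADS:
      return embed_dim // num_heads
    case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS:
      return (embed_dim // num_heads) ** -0.5
    case QueryPreAttentionNormalisation.BY_ONE_OVER_SQRT_HEAD_DIM | _:
      return head_dim**-0.5


# Hypothetical dimensions, for illustration only.
embed_dim, num_heads, head_dim = 2304, 8, 256
for norm in QueryPreAttentionNormalisation:
  print(norm.name, query_pre_attn_scalar(norm, embed_dim, num_heads, head_dim))
# BY_ONE_OVER_SQRT_HEAD_DIM                -> 256 ** -0.5        == 0.0625
# BY_EMBED_DIM_DIV_NUM_HEADS               -> 2304 // 8          == 288
# BY_ONE_OVER_SQRT_EMBED_DIM_DIV_NUM_HEADS -> (2304 // 8) ** -0.5 ~= 0.0589

In modules.Attention, the returned value is what multiplies query_proj before the attention logits are computed, replacing the previously hard-coded head_dim**-0.5; the default option, BY_ONE_OVER_SQRT_HEAD_DIM, reproduces that old behaviour.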