Add Sow Config to from_params constructor.
PiperOrigin-RevId: 733277422
Flax Team committed Mar 4, 2025
1 parent a24d790 commit 787a820
Showing 5 changed files with 115 additions and 49 deletions.
61 changes: 34 additions & 27 deletions examples/gemma/modules.py
@@ -33,6 +33,7 @@
Shape = Sequence[Union[int, Any]]

K_MASK = -2.3819763e38 # Set to a large negative number.
DEFAULT_ROPE_BASE_FREQUENCY = 10_000


class AttentionType(enum.Enum):
@@ -80,9 +81,11 @@ def __init__(
num_kv_heads: int,
features: int,
head_dim: int,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
@@ -92,6 +95,7 @@ def __init__(
'`sliding_window_size` must be set if `attn_type` is Local Sliding.'
)

self.query_pre_attn_scalar = query_pre_attn_scalar
self.attn_type = attn_type
self.sliding_window_size = sliding_window_size
self.attn_logits_soft_cap = attn_logits_soft_cap
@@ -100,6 +104,7 @@ def __init__(
shape=(num_heads, head_dim, features),
rngs=rngs,
)
self.rope_base_frequency = rope_base_frequency
self.sow_config = sow_config

if num_heads == num_kv_heads:
@@ -139,12 +144,14 @@ def __call__(
query_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)
query_scaled = query_proj * self.head_dim**-0.5
query_scaled = query_proj * self.query_pre_attn_scalar
key_proj = positional_embeddings.apply_rope(
key_proj,
segment_pos,
head_dim=self.head_dim,
max_wavelength=self.rope_base_frequency,
)

# Cache is left aligned.
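
As an aside on the two new attention knobs above: `rope_base_frequency` is forwarded to the rotary embedding as `max_wavelength`, and `query_pre_attn_scalar` replaces the previously hard-coded `head_dim**-0.5` query scaling. Below is a minimal standalone sketch of that interaction, not the repository's `positional_embeddings.apply_rope`, just an illustration that follows the same signature conventions.

```python
import jax.numpy as jnp

def apply_rope_sketch(x, positions, *, head_dim, max_wavelength=10_000):
  """Illustrative RoPE: x is [batch, seq, heads, head_dim], positions is [batch, seq]."""
  fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
  timescale = max_wavelength**fraction                 # [head_dim // 2]
  angles = positions[..., None, None] / timescale      # [batch, seq, 1, head_dim // 2]
  sin, cos = jnp.sin(angles), jnp.cos(angles)
  first, second = jnp.split(x, 2, axis=-1)
  rotated = jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1)
  return rotated.astype(x.dtype)

# With the change above, query processing becomes (schematically):
#   q = apply_rope(q, segment_pos, head_dim=head_dim, max_wavelength=rope_base_frequency)
#   q = q * query_pre_attn_scalar   # was: q * head_dim**-0.5
# so a model config can choose both the RoPE base frequency and the query scaling.
```
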
@@ -253,21 +260,21 @@ def __init__(
out_features=hidden_dim,
use_bias=False,
rngs=rngs,
kernel_init=nn.initializers.zeros_init(),
kernel_init=nn.initializers.normal(),
)
self.up_proj = nnx.Linear(
in_features=features,
out_features=hidden_dim,
use_bias=False,
rngs=rngs,
kernel_init=nn.initializers.zeros_init(),
kernel_init=nn.initializers.normal(),
)
self.down_proj = nnx.Linear(
in_features=hidden_dim,
out_features=features,
use_bias=False,
rngs=rngs,
kernel_init=nn.initializers.zeros_init(),
kernel_init=nn.initializers.normal(),
)
self.sow_config = sow_config

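A side note on the `kernel_init` change above: the three FeedForward projections move from zero initialization to a normal initializer. A small hedged check of the effect, with hypothetical shapes and using `jax.nn.initializers` directly:

```python
import jax
import jax.numpy as jnp
from flax import nnx

layer = nnx.Linear(
    in_features=4,
    out_features=8,
    use_bias=False,
    kernel_init=jax.nn.initializers.normal(),  # previously zeros_init(): all-zero weights
    rngs=nnx.Rngs(params=0),
)
# With a normal initializer the gate/up/down projections start with non-zero weights.
assert not jnp.allclose(layer.kernel.value, 0.0)
```
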
@@ -295,9 +302,11 @@ def __init__(
hidden_dim: int,
use_post_attn_norm: bool,
use_post_ffw_norm: bool,
query_pre_attn_scalar: float,
attn_type: AttentionType,
*,
rngs: nnx.Rngs,
rope_base_frequency: int = DEFAULT_ROPE_BASE_FREQUENCY,
attn_logits_soft_cap: float | None = None,
sliding_window_size: int | None = None,
sow_config: sow_lib.SowConfig = sow_lib.SowConfig()
@@ -308,14 +317,17 @@ def __init__(
num_kv_heads=num_kv_heads,
features=embed_dim,
head_dim=head_dim,
query_pre_attn_scalar=query_pre_attn_scalar,
attn_type=attn_type,
rope_base_frequency=rope_base_frequency,
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
rngs=rngs,
sow_config=sow_config,
)
if use_post_attn_norm:
self.post_attn_norm = layers.RMSNorm(embed_dim, rngs=rngs)
self.use_post_attn_norm = use_post_attn_norm
if self.use_post_attn_norm:
self.post_attention_norm = layers.RMSNorm(embed_dim, rngs=rngs)

self.pre_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs)
self.mlp = FeedForward(
@@ -324,7 +336,8 @@ def __init__(
rngs=rngs,
sow_config=sow_config,
)
if use_post_ffw_norm:
self.use_post_ffw_norm = use_post_ffw_norm
if self.use_post_ffw_norm:
self.post_ffw_norm = layers.RMSNorm(embed_dim, rngs=rngs)
self.sow_config = sow_config

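For clarity, a hedged construction example tying the new `Block.__init__` arguments together; the dimensions and the leading parameter names are illustrative, inferred from the test calls further down.

```python
from flax import nnx

block = modules.Block(
    num_heads=2,
    num_kv_heads=2,
    embed_dim=16,
    head_dim=8,
    hidden_dim=32,
    use_post_attn_norm=True,
    use_post_ffw_norm=True,
    query_pre_attn_scalar=8**-0.5,           # e.g. the old fixed head_dim**-0.5 scaling
    attn_type=modules.AttentionType.GLOBAL,
    rngs=nnx.Rngs(params=0),
    rope_base_frequency=10_000,              # new keyword; defaults to DEFAULT_ROPE_BASE_FREQUENCY
)
```
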
@@ -335,35 +348,29 @@ def __call__(
cache: LayerCache | None,
attn_mask: jax.Array,
) -> tuple[LayerCache | None, jax.Array]:
inputs_normalized = self.pre_attention_norm(x)

# Attention.
attn_inputs = self.pre_attention_norm(x)
cache, attn_output = self.attn(
inputs_normalized,
attn_inputs,
segment_pos,
cache,
attn_mask,
)
attn_output += x
residual = attn_output
attn_output = self.pre_ffw_norm(attn_output)

if self.use_post_attn_norm:
attn_output = self.post_attn_norm(attn_output)
self.sow_config.maybe_sow_rs_after_attention(attn_output, self)
attn_output = self.post_attention_norm(attn_output)
x += attn_output
self.sow_config.maybe_sow_rs_after_attention(x, self)

outputs = self.mlp(attn_output)
# Feed forward.
ffw_inputs = self.pre_ffw_norm(x)
ffw_outputs = self.mlp(ffw_inputs)
if self.use_post_ffw_norm:
outputs = self.post_ffw_norm(outputs)
outputs = residual + outputs
self.sow_config.maybe_sow_rs_after_ffw(outputs, self)
return cache, outputs
ffw_outputs = self.post_ffw_norm(ffw_outputs)
x += ffw_outputs
self.sow_config.maybe_sow_rs_after_ffw(x, self)

@property
def use_post_attn_norm(self):
return hasattr(self, 'post_attn_norm') and self.post_attn_norm is not None

@property
def use_post_ffw_norm(self):
return hasattr(self, 'post_ffw_norm') and self.post_ffw_norm is not None
return cache, x

def init_cache(
self,
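
The `Block.__call__` hunk above interleaves the old and new lines, so for readability here is the new body collapsed into a single listing, reconstructed from the added lines: the `hasattr`-based `use_post_attn_norm`/`use_post_ffw_norm` properties are gone, replaced by flags stored in `__init__`.

```python
def __call__(self, x, segment_pos, cache, attn_mask):
  # Attention sub-block with residual connection.
  attn_inputs = self.pre_attention_norm(x)
  cache, attn_output = self.attn(attn_inputs, segment_pos, cache, attn_mask)
  if self.use_post_attn_norm:
    attn_output = self.post_attention_norm(attn_output)
  x += attn_output
  self.sow_config.maybe_sow_rs_after_attention(x, self)

  # Feed-forward sub-block with residual connection.
  ffw_inputs = self.pre_ffw_norm(x)
  ffw_outputs = self.mlp(ffw_inputs)
  if self.use_post_ffw_norm:
    ffw_outputs = self.post_ffw_norm(ffw_outputs)
  x += ffw_outputs
  self.sow_config.maybe_sow_rs_after_ffw(x, self)

  return cache, x
```
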
44 changes: 27 additions & 17 deletions examples/gemma/modules_test.py
@@ -78,6 +78,7 @@ def test_head_dim(self, head_dim):
num_kv_heads=4,
features=5,
head_dim=head_dim,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -107,6 +108,7 @@ def test_use_qkv_einsum(
num_kv_heads=num_kv_heads,
features=5,
head_dim=8,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -144,7 +146,8 @@ def test_attention(
num_heads,
features,
head_dim,
modules.AttentionType.GLOBAL,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -177,7 +180,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
modules.AttentionType.GLOBAL,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
cache = attn.init_cache(
@@ -191,7 +195,8 @@ def test_sliding_window(self, sliding_window_size):
num_heads,
features,
head_dim,
modules.AttentionType.LOCAL_SLIDING,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.LOCAL_SLIDING,
sliding_window_size=sliding_window_size,
rngs=nnx.Rngs(params=0),
)
@@ -272,6 +277,7 @@ def test_block(
1,
use_post_attn_norm,
use_post_ffw_norm,
1.0,
modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
@@ -313,9 +319,10 @@ def test_post_attention_norm(
embed_dim,
head_dim,
1,
True,
False, # use_post_ffw_norm
modules.AttentionType.GLOBAL,
use_post_attn_norm=True,
use_post_ffw_norm=False,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
unnormed_block = modules.Block(
@@ -324,9 +331,10 @@ def test_post_attention_norm(
embed_dim,
head_dim,
1,
False,
False, # use_post_ffw_norm
modules.AttentionType.GLOBAL,
use_post_attn_norm=False,
use_post_ffw_norm=False,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)

@@ -343,7 +351,7 @@ def test_post_attention_norm(
all_outputs.append(outputs)

normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking
self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all())
self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all())

@parameterized.parameters(
dict(
@@ -371,9 +379,10 @@ def test_post_ffw_norm(
embed_dim,
head_dim,
1,
True,
True, # use_post_ffw_norm
modules.AttentionType.GLOBAL,
use_post_attn_norm=False,
use_post_ffw_norm=True,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)
unnormed_block = modules.Block(
Expand All @@ -382,9 +391,10 @@ def test_post_ffw_norm(
embed_dim,
head_dim,
1,
False,
False, # use_post_ffw_norm
modules.AttentionType.GLOBAL,
use_post_attn_norm=False,
use_post_ffw_norm=False,
query_pre_attn_scalar=1.0,
attn_type=modules.AttentionType.GLOBAL,
rngs=nnx.Rngs(params=0),
)

Expand All @@ -401,7 +411,7 @@ def test_post_ffw_norm(
all_outputs.append(outputs)

normed_output, unnormed_output = all_outputs # pylint: disable=unbalanced-tuple-unpacking
self.assertFalse(jnp.not_equal(normed_output, unnormed_output).all())
self.assertTrue(jnp.not_equal(normed_output, unnormed_output).all())


if __name__ == '__main__':
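
The test updates above show the new calling convention: `attn_type` moves from a positional to a keyword argument and `query_pre_attn_scalar` becomes required. A hedged usage example, with illustrative values and import path:

```python
from flax import nnx
import modules  # examples/gemma/modules.py; import path is illustrative

attn = modules.Attention(
    num_heads=2,
    num_kv_heads=4,
    features=5,
    head_dim=8,
    query_pre_attn_scalar=1.0,                # new required argument
    attn_type=modules.AttentionType.GLOBAL,   # now passed by keyword in the tests
    rngs=nnx.Rngs(params=0),
    # rope_base_frequency is optional and defaults to DEFAULT_ROPE_BASE_FREQUENCY (10_000).
)
```
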
2 changes: 1 addition & 1 deletion examples/gemma/sow_lib.py
@@ -41,7 +41,7 @@ def merge(self, decoding_step, layer: nnx.Module):
value = getattr(self, field.name)
if value is None:
continue
# We but mlp and attn intermediates into this class without any further
# We put mlp and attn intermediates into this class without any further
# nesting. So we have to retrieve the intermediates from the correct
# sub-module.
try: