Skip to content

Commit

Permalink
Add Sow Config to from_params constructor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 733277422
  • Loading branch information
casaro authored and Flax Authors committed Mar 6, 2025
1 parent b28822f commit 9d7f7e4
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion examples/gemma/sow_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def merge(self, decoding_step, layer: nnx.Module):
value = getattr(self, field.name)
if value is None:
continue
# We but mlp and attn intermediates into this class without any further
# We put mlp and attn intermediates into this class without any further
# nesting. So we have to retrieve the intermediates from the correct
# sub-module.
try:
Expand Down
9 changes: 7 additions & 2 deletions examples/gemma/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,10 @@ class Transformer(nnx.Module):

@classmethod
def from_params(
cls, params: params_lib.Params, config: None | TransformerConfig = None
cls,
params: params_lib.Params,
config: None | TransformerConfig = None,
sow_config: sow_lib.SowConfig = sow_lib.SowConfig(),
) -> Transformer:
if config is None:
config = TransformerConfig.from_params(params)
Expand All @@ -258,7 +261,9 @@ def from_params(
transpose_gating_einsum=config.transpose_gating_einsum,
)
return helpers.module_from_linen_variables(
module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)),
module_factory=lambda: cls(
config, rngs=nnx.Rngs(params=0), sow_config=sow_config
),
variables=params['transformer'],
map_key_fn=_map_linen_var_names,
assign_val_fn=assign_val_fn,
Expand Down

0 comments on commit 9d7f7e4

Please sign in to comment.