From 646354575bc7c08f16f95803f225adb17b670172 Mon Sep 17 00:00:00 2001 From: Sascha Rothe Date: Thu, 6 Mar 2025 12:37:11 -0800 Subject: [PATCH] Add Sow Config to from_params constructor. PiperOrigin-RevId: 734247678 --- examples/gemma/sow_lib.py | 2 +- examples/gemma/transformer.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/gemma/sow_lib.py b/examples/gemma/sow_lib.py index 2a9c13018..9cb408c10 100644 --- a/examples/gemma/sow_lib.py +++ b/examples/gemma/sow_lib.py @@ -41,7 +41,7 @@ def merge(self, decoding_step, layer: nnx.Module): value = getattr(self, field.name) if value is None: continue - # We but mlp and attn intermediates into this class without any further + # We put mlp and attn intermediates into this class without any further # nesting. So we have to retrieve the intermediates from the correct # sub-module. try: diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py index ce9eed319..667280a49 100644 --- a/examples/gemma/transformer.py +++ b/examples/gemma/transformer.py @@ -249,7 +249,10 @@ class Transformer(nnx.Module): @classmethod def from_params( - cls, params: params_lib.Params, config: None | TransformerConfig = None + cls, + params: params_lib.Params, + config: None | TransformerConfig = None, + sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ) -> Transformer: if config is None: config = TransformerConfig.from_params(params) @@ -258,7 +261,9 @@ def from_params( transpose_gating_einsum=config.transpose_gating_einsum, ) return helpers.module_from_linen_variables( - module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)), + module_factory=lambda: cls( + config, rngs=nnx.Rngs(params=0), sow_config=sow_config + ), variables=params['transformer'], map_key_fn=_map_linen_var_names, assign_val_fn=assign_val_fn,