diff --git a/examples/gemma/sow_lib.py b/examples/gemma/sow_lib.py index 2a9c1301..9cb408c1 100644 --- a/examples/gemma/sow_lib.py +++ b/examples/gemma/sow_lib.py @@ -41,7 +41,7 @@ def merge(self, decoding_step, layer: nnx.Module): value = getattr(self, field.name) if value is None: continue - # We but mlp and attn intermediates into this class without any further + # We put mlp and attn intermediates into this class without any further # nesting. So we have to retrieve the intermediates from the correct # sub-module. try: diff --git a/examples/gemma/transformer.py b/examples/gemma/transformer.py index ce9eed31..667280a4 100644 --- a/examples/gemma/transformer.py +++ b/examples/gemma/transformer.py @@ -249,7 +249,10 @@ class Transformer(nnx.Module): @classmethod def from_params( - cls, params: params_lib.Params, config: None | TransformerConfig = None + cls, + params: params_lib.Params, + config: None | TransformerConfig = None, + sow_config: sow_lib.SowConfig = sow_lib.SowConfig(), ) -> Transformer: if config is None: config = TransformerConfig.from_params(params) @@ -258,7 +261,9 @@ def from_params( transpose_gating_einsum=config.transpose_gating_einsum, ) return helpers.module_from_linen_variables( - module_factory=lambda: cls(config, rngs=nnx.Rngs(params=0)), + module_factory=lambda: cls( + config, rngs=nnx.Rngs(params=0), sow_config=sow_config + ), variables=params['transformer'], map_key_fn=_map_linen_var_names, assign_val_fn=assign_val_fn,