diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 377d9f29..13d0ac34 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -190,6 +190,7 @@ def _setattr(self, name: str, value: tp.Any) -> None: graph.update(value, state) for leaf in jax.tree.leaves(value): if isinstance(leaf, Module): + leaf._object__state._initializing = self.is_initializing() _bind_module(self, leaf) super()._setattr(name, value) @@ -308,6 +309,11 @@ def _get_variables(self) -> tp.Mapping: variable_state: variablelib.VariableState for path, variable_state in statelib.to_flat_state(state): + + if issubclass(variable_state.type, rnglib.RngState): + # Don't return RNG states, since Linen doesn't have them. + continue + try: collection = variablelib.variable_name_from_type(variable_state.type) except ValueError: