From 541385015d1925f50e47530ed838a485c315d58d Mon Sep 17 00:00:00 2001 From: Ivy Zheng Date: Tue, 25 Feb 2025 16:33:53 -0800 Subject: [PATCH] [bridge] Set _initializing correctly and avoid return RNG states PiperOrigin-RevId: 731073396 --- flax/nnx/bridge/module.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 377d9f29..13d0ac34 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -190,6 +190,7 @@ def _setattr(self, name: str, value: tp.Any) -> None: graph.update(value, state) for leaf in jax.tree.leaves(value): if isinstance(leaf, Module): + leaf._object__state._initializing = self.is_initializing() _bind_module(self, leaf) super()._setattr(name, value) @@ -308,6 +309,11 @@ def _get_variables(self) -> tp.Mapping: variable_state: variablelib.VariableState for path, variable_state in statelib.to_flat_state(state): + + if issubclass(variable_state.type, rnglib.RngState): + # Don't return RNG states, since Linen doesn't have them. + continue + try: collection = variablelib.variable_name_from_type(variable_state.type) except ValueError: