Skip to content

Commit

Permalink
[bridge] Set _initializing correctly and avoid return RNG states
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 731073396
  • Loading branch information
IvyZX authored and Flax Authors committed Feb 26, 2025
1 parent d96be6c commit 5413850
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def _setattr(self, name: str, value: tp.Any) -> None:
graph.update(value, state)
for leaf in jax.tree.leaves(value):
if isinstance(leaf, Module):
leaf._object__state._initializing = self.is_initializing()
_bind_module(self, leaf)
super()._setattr(name, value)

Expand Down Expand Up @@ -308,6 +309,11 @@ def _get_variables(self) -> tp.Mapping:

variable_state: variablelib.VariableState
for path, variable_state in statelib.to_flat_state(state):

if issubclass(variable_state.type, rnglib.RngState):
# Don't return RNG states, since Linen doesn't have them.
continue

try:
collection = variablelib.variable_name_from_type(variable_state.type)
except ValueError:
Expand Down

0 comments on commit 5413850

Please sign in to comment.