diff --git a/flax/nnx/bridge/__init__.py b/flax/nnx/bridge/__init__.py
index 880dc1bf..cb5d8fb6 100644
--- a/flax/nnx/bridge/__init__.py
+++ b/flax/nnx/bridge/__init__.py
@@ -26,5 +26,6 @@
 from .module import compact as compact
 from .module import current_context as current_context
 from .module import current_module as current_module
-from .interop import wrap_nnx_mdl as wrap_nnx_mdl
+from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl
+from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl
 from flax.nnx.nn import initializers as initializers
diff --git a/flax/nnx/bridge/interop.py b/flax/nnx/bridge/interop.py
index 7b1695a2..4029b0bd 100644
--- a/flax/nnx/bridge/interop.py
+++ b/flax/nnx/bridge/interop.py
@@ -14,18 +14,32 @@
 
 import typing as tp
 
+from flax.linen import module as nn_module
 from flax.nnx import graph, rnglib
+from flax.nnx.bridge import wrappers
 from flax.nnx.bridge import module as bdg_module
 import flax.nnx.module as nnx_module
 from flax.nnx.transforms.transforms import eval_shape as nnx_eval_shape
 from flax.nnx.transforms.compilation import jit as nnx_jit
 
 
-def wrap_nnx_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
-                 name: str | None = None):
-  """Create module at init time, or make abstract module and let parent bind it with its state. Use current bridge module scope for RNG generation."""
+def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
+                      name: str | None = None) -> nnx_module.Module:
+  """Make a pure NNX module a submodule of a bridge module.
+
+  Creates the module at init time, or makes an abstract module and lets the
+  parent bind it with its state.
+  Uses the current bridge module scope for RNG generation.
+
+  Args:
+    factory: a function that takes an `nnx.Rngs` arg and returns an NNX module.
+    name: the name of the module. Only used inside `bridge.compact` functions;
+      in setup(), the user assigns the module to an attribute explicitly.
+  Returns:
+    A submodule (`nnx.Module`) of the bridge module.
+  """
   parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
-  assert parent_ctx is not None and parent is not None, 'wrap_nnx_mdl only needed inside bridge Module'
+  assert parent_ctx is not None and parent is not None, 'nnx_in_bridge_mdl() only needed inside bridge Module'
   parent = parent_ctx.module
   assert parent.scope is not None
 
@@ -50,3 +64,26 @@ def rng_state(rngs):
     name = bdg_module._auto_submodule_name(parent_ctx, type(module))
   setattr(parent, name, module)
   return module
+
+
+def linen_in_bridge_mdl(linen_module: nn_module.Module,
+                        name: str | None = None) -> nnx_module.Module:
+  """Make a Linen module a submodule of a bridge module using `wrappers.ToNNX`.
+
+  Args:
+    linen_module: the underlying Linen module instance.
+    name: the name of the module. Only used inside `bridge.compact` functions;
+      in setup(), the user assigns the module to an attribute explicitly.
+  Returns:
+    A submodule (`nnx.Module`) of the bridge module.
+  """
+  parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
+  assert parent_ctx is not None and parent is not None, 'linen_in_bridge_mdl() only needed inside bridge Module'
+  assert parent.scope is not None
+  module = wrappers.ToNNX(linen_module, parent.scope.rngs)
+  wrappers._set_initializing(module, parent.is_initializing())
+  if parent_ctx.in_compact:
+    if name is None:
+      name = bdg_module._auto_submodule_name(parent_ctx, type(linen_module))
+    setattr(parent, name, module)
+  return module
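For reference, a minimal usage sketch of the two interop helpers above. The `Classifier` module and its dimensions are hypothetical; the pattern mirrors the tests further down in this diff:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class Classifier(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # Pure NNX submodule, built from a factory that takes `nnx.Rngs`.
    x = bridge.nnx_in_bridge_mdl(lambda r: nnx.Linear(4, 8, rngs=r))(x)
    # Linen submodule, wrapped via ToNNX() under the hood.
    x = bridge.linen_in_bridge_mdl(nn.Dense(2), name='head')(x)
    return x

model = Classifier()
variables = model.init(jax.random.key(0), jnp.ones((1, 4)))
y = model.apply(variables, jnp.ones((1, 4)), rngs=jax.random.key(1))
```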
+ """ + parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module() + assert parent_ctx is not None and parent is not None, 'linen_in_bridge_mdl() only needed inside bridge Module' + assert parent.scope is not None + module = wrappers.ToNNX(linen_module, parent.scope.rngs) + wrappers._set_initializing(module, parent.is_initializing()) + if parent_ctx.in_compact: + if name is None: + name = bdg_module._auto_submodule_name(parent_ctx, type(linen_module)) + setattr(parent, name, module) + return module diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 3f9b5d32..772fd5db 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -25,6 +25,7 @@ from flax import errors from flax.core import meta +from flax.core.scope import CollectionFilter from flax.core.frozen_dict import FrozenDict from flax.nnx import graph, rnglib, statelib, traversals from flax.nnx import variablelib @@ -63,11 +64,12 @@ class ModuleState(statelib.State): class Scope(Object): - def __init__(self, rngs: rnglib.Rngs): + def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter): self.rngs = rngs + self.mutable = mutable def copy(self): - return Scope(self.rngs) + return Scope(self.rngs, self.mutable) class _HasSetup(tp.Protocol): @@ -365,7 +367,7 @@ def apply( *args, rngs: int | jax.Array | dict[str, jax.Array] | rnglib.Rngs | None = None, method: tp.Callable[..., tp.Any] | str = '__call__', - mutable: tp.Any = False, + mutable: CollectionFilter = False, _initialize: bool = False, **kwargs, ) -> tp.Any: @@ -422,7 +424,7 @@ def to_variable(value): if isinstance(value, Object): value._object__state._initializing = _initialize if isinstance(value, Module): - value.scope = Scope(rngs) + value.scope = Scope(rngs, mutable) _maybe_call_setup(value) MODULE_CONTEXT.module_stack.append( @@ -517,3 +519,4 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable: raise errors.ApplyModuleInvalidMethodError(method_or_fn) return method_or_fn + diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index 7fd65106..597a112b 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -23,6 +23,7 @@ from flax.nnx import graph from flax.nnx import variablelib from flax.nnx.bridge import variables as bv +from flax.nnx.bridge import module as bdg_module from flax.nnx.module import Module from flax.nnx.object import Object from flax.nnx.rnglib import Rngs @@ -124,7 +125,6 @@ def __init__( ): self.module = module self.rngs = rngs - self.linen_attributes: tuple[str, ...] 
diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py
index 7fd65106..597a112b 100644
--- a/flax/nnx/bridge/wrappers.py
+++ b/flax/nnx/bridge/wrappers.py
@@ -23,6 +23,7 @@
 from flax.nnx import graph
 from flax.nnx import variablelib
 from flax.nnx.bridge import variables as bv
+from flax.nnx.bridge import module as bdg_module
 from flax.nnx.module import Module
 from flax.nnx.object import Object
 from flax.nnx.rnglib import Rngs
@@ -124,7 +125,6 @@ def __init__(
   ):
     self.module = module
     self.rngs = rngs
-    self.linen_attributes: tuple[str, ...] = ()
 
   def lazy_init(self, *args, **kwargs):
     """A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
@@ -140,9 +140,7 @@ def __call__(
     rngs = self.rngs
     if self._object__state.initializing:
       _rngs = (
-        {name: stream.key.raw_value for name, stream in rngs.items()}
-        if rngs
-        else {}
+        {name: stream() for name, stream in rngs.items()} if rngs else {}
       )
       # rename default to params
       if 'params' not in _rngs and 'default' in _rngs:
@@ -150,36 +148,41 @@ def __call__(
       out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
       nnx_attrs = bv.linen_vars_to_nnx_attrs(variables)
-      linen_attributes = set(self.linen_attributes)
       for attr_name, value in nnx_attrs.items():
         setattr(self, attr_name, value)
-        linen_attributes.add(attr_name)
-      self.linen_attributes = tuple(linen_attributes)  # make it hashable
 
     else:
-      nnx_attrs = {name: getattr(self, name) for name in self.linen_attributes}
+      nnx_attrs = {k: v for k, v in vars(self).items()
+                   if k not in ['module', 'rngs', '_object__state']}
       variables = bv.nnx_attrs_to_linen_vars(nnx_attrs)
       _rngs = (
         {name: stream() for name, stream in rngs.items()} if rngs else {}
       )
+
+      # Get `mutable` from the top-level bridge.Module context, if any
+      if (m := bdg_module.current_module()) is not None:
+        assert m.scope is not None
+        mutable = m.scope.mutable
+        if 'mutable' in kwargs and kwargs['mutable'] != mutable:
+          raise ValueError(
+            f"Multiple `mutable` arguments detected: {mutable} at top level vs "
+            f"{kwargs['mutable']} in ToNNX() call")
+        kwargs['mutable'] = mutable
+
       out = self.module.apply(variables, *args, rngs=_rngs, method=method, **kwargs)
 
       # Split out the updates if `mutable` is passed into the Flax module
       if kwargs.get('mutable', False) != False:
         out, updates = out
         nnx_attrs = bv.linen_vars_to_nnx_attrs(updates)
-        linen_attributes = set(self.linen_attributes)
         for attr_name, value in nnx_attrs.items():
-          linen_attributes.add(attr_name)
           if hasattr(self, attr_name) and isinstance(value, dict):
             original_tree = getattr(self, attr_name)
             setattr(self, attr_name, original_tree | value)
           else:
             setattr(self, attr_name, value)
-        self.linen_attributes = tuple(linen_attributes)  # make it hashable
-
     return out
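With this change, a `ToNNX`-wrapped Linen submodule running under a `bridge.Module` inherits the top-level `mutable` automatically, and a conflicting per-call value fails fast. A hypothetical conflict case (the `Conflicting` module is illustrative only, reusing the imports from the sketch above):

```python
class Conflicting(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # Hard-coded per-call `mutable` that disagrees with the top-level filter:
    return bridge.linen_in_bridge_mdl(nn.Dense(2))(x, mutable=['params'])

m = Conflicting()
vs = m.init(jax.random.key(0), jnp.ones((1, 4)))
try:
  m.apply(vs, jnp.ones((1, 4)), mutable=True)
except ValueError as e:
  print(e)  # Multiple `mutable` arguments detected: True at top level vs ...
```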
diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py
index 01170a09..a52b8cbb 100644
--- a/tests/nnx/bridge/module_test.py
+++ b/tests/nnx/bridge/module_test.py
@@ -319,8 +319,8 @@ def __call__(self, x):
 
     class BridgeMLP(bridge.Module):
       @bridge.compact
       def __call__(self, x):
-        x = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
-        x = nnx.bridge.wrap_nnx_mdl(
+        x = bridge.nnx_in_bridge_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
+        x = bridge.nnx_in_bridge_mdl(
           lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x)
         return x
@@ -345,7 +345,8 @@ def __call__(self, x):
 
     class BridgeMLPSetup(bridge.Module):
       def setup(self):
-        self.layer = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))
+        self.layer = bridge.nnx_in_bridge_mdl(
+          lambda r: NNXLayer(8, 0.3, rngs=r))
 
       def __call__(self, x):
         return self.layer(x)
@@ -371,13 +372,67 @@ def generate_weights(r):
     class BridgeFoo(bridge.Module):
       @bridge.compact
       def __call__(self, x):
-        x = nnx.bridge.wrap_nnx_mdl(lambda r: FooStack(4, r.default()))(x)
+        x = bridge.nnx_in_bridge_mdl(lambda r: FooStack(4, r.default()))(x)
         return x
 
     model = BridgeFoo()
     v = model.init(jax.random.key(1), jnp.ones((1, 4)))
     y = model.apply(v, jnp.ones((1, 4)), rngs=jax.random.key(1))
 
+  def test_linen_submodule(self):
+    class LinenLayer(nn.Module):
+      dim: int
+      dropout_rate: float
+      def setup(self):
+        self.linear = nn.Dense(self.dim, use_bias=False)
+        self.dropout = nn.Dropout(self.dropout_rate, deterministic=False)
+
+      def __call__(self, x):
+        if not self.is_initializing():
+          self.sow('intermediates', 'count', 1,
+                   init_fn=lambda: 0, reduce_fn=lambda a, b: a + b)
+        x = self.linear(x)
+        x = self.dropout(x)
+        return x
+
+    class BridgeMLP(bridge.Module):
+      @bridge.compact
+      def __call__(self, x):
+        x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))(x)
+        x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3), name='another')(x)
+        return x
+
+    model = BridgeMLP()
+    x = jax.random.normal(jax.random.key(0), (4, 8))
+    variables = model.init(jax.random.key(1), x)
+    self.assertFalse(jnp.array_equal(
+      variables['params']['LinenLayer_0']['linear']['kernel'],
+      variables['params']['another']['linear']['kernel'],
+    ))
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
+                             mutable=True)
+    self.assertEqual(updates['intermediates']['LinenLayer_0']['count'], 1)
+
+    class BridgeMLPSetup(bridge.Module):
+      def setup(self):
+        self.layer = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))
+      def __call__(self, x):
+        return self.layer(x)
+
+    model = BridgeMLPSetup()
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(), ['layer'])
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+
 
 if __name__ == '__main__':
   absltest.main()
diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py
index d142d839..cbe4b619 100644
--- a/tests/nnx/bridge/wrappers_test.py
+++ b/tests/nnx/bridge/wrappers_test.py
@@ -54,9 +54,10 @@ def test_linen_to_nnx(self):
     assert y.shape == (1, 64)
     self.assertIsInstance(model.kernel, nnx.Variable)
     # NNX automatically adds metadata box regardless of original Linen module.
-    linen_vars = linen_module.init(jax.random.key(0), x)
-    np.testing.assert_array_equal(linen_vars['params']['kernel'],
-                                  model.kernel.value)
+    linen_vars = {'params': {'kernel': model.kernel.value,
+                             'bias': model.bias.value}}
+    linen_y = linen_module.apply(linen_vars, x)
+    np.testing.assert_array_equal(y, linen_y)
 
   def test_linen_to_nnx_submodule(self):
     class NNXOuter(nnx.Module):
@@ -468,10 +469,11 @@ def __call__(self, x):
     # Test the RNG
     model = bridge.lazy_init(
       NNXOuter(dout=3, dropout_rate=0.5, rngs=nnx.Rngs(default=1, dropout=2)),
       x)
+    nnx.reseed(model, dropout=2)
     y1, y2 = model(x), model(x)
     # The dropout key of lowest NNX level still changes over stateful calls
     assert not jnp.allclose(y1, y2)
-    # Reseed resets the RNG key back
+    # Another reseed resets the RNG key back
     nnx.reseed(model, dropout=2)
     np.testing.assert_array_equal(y1, model(x))