From 23a2c187e12ef1de06db0dc840ced1b2a0f4461d Mon Sep 17 00:00:00 2001
From: IvyZX
Date: Thu, 27 Feb 2025 14:25:26 -0800
Subject: [PATCH] bridge module with nnx submodules

---
 flax/nnx/bridge/__init__.py     |  4 +++-
 flax/nnx/bridge/module.py       | 26 ++++++++++++++++++++
 tests/nnx/bridge/module_test.py | 42 +++++++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/flax/nnx/bridge/__init__.py b/flax/nnx/bridge/__init__.py
index 867ffae6..fcf05cc8 100644
--- a/flax/nnx/bridge/__init__.py
+++ b/flax/nnx/bridge/__init__.py
@@ -23,4 +23,6 @@
 from .variables import with_partitioning as with_partitioning
 from .module import Module as Module
 from .module import Scope as Scope
-from .module import compact as compact
\ No newline at end of file
+from .module import compact as compact
+from .module import current_module as current_module
+from .module import wrap_nnx_module as wrap_nnx_module
\ No newline at end of file
diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py
index 13d0ac34..2a4c89ee 100644
--- a/flax/nnx/bridge/module.py
+++ b/flax/nnx/bridge/module.py
@@ -32,6 +32,7 @@
 from flax.nnx.object import Object
 from flax.nnx import variablelib
 from flax.nnx.bridge import variables as bridge_variables
+from flax.nnx.transforms import transforms
 import jax.numpy as jnp
 
 A = tp.TypeVar('A')
@@ -117,6 +118,14 @@ def _object_meta_construct(cls, self, *args, **kwargs):
     super()._object_meta_construct(self, *args, **kwargs)
 
 
+def current_module() -> tp.Optional[Module]:
+  """A quick util to get the current bridge module."""
+  ctx = MODULE_CONTEXT.module_stack[-1]
+  if ctx is None:
+    return None
+  return ctx.module
+
+
 def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
   # compact behavior
   parent_ctx = MODULE_CONTEXT.module_stack[-1]
@@ -502,3 +511,20 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
     raise errors.ApplyModuleInvalidMethodError(method_or_fn)
 
   return method_or_fn
+
+
+def wrap_nnx_module(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module]):
+  """Create module at init time, or make abstract module and let parent bind it with its state. Use current bridge module scope for RNG generation."""
+  parent = current_module()
+  assert parent is not None, 'wrap_nnx_module only needed inside bridge Module'
+  if parent.is_initializing():
+    return factory(parent.scope.rngs)
+  rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7)  # dummy
+  abs_module = transforms.eval_shape(lambda: factory(rngs))
+  # Make sure the internal rng state is not abstract - everything else shall be
+  for _, stream in graph.iter_graph(abs_module):
+    if isinstance(stream, rnglib.RngStream):
+      parent_stream = rngs[stream.key.tag]
+      stream.key.value = parent_stream.key.value
+      stream.count.value = parent_stream.count.value
+  return abs_module
\ No newline at end of file
diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py
index 3c6670c6..a96c6046 100644
--- a/tests/nnx/bridge/module_test.py
+++ b/tests/nnx/bridge/module_test.py
@@ -302,6 +302,48 @@ def __call__(self, x):
     y: jax.Array = foo.apply(variables, x)
     self.assertEqual(y.shape, (3, 5))
 
+  def test_with_pure_nnx(self):
+    class NNXLayer(nnx.Module):
+      def __init__(self, dim, dropout, rngs):
+        self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
+        self.dropout = nnx.Dropout(dropout, rngs=rngs)
+        self.count = nnx.Intermediate(jnp.array([0.]))
+      def __call__(self, x):
+        # Required check to avoid state update in `init()`. Can this be avoided?
+        if not bridge.current_module().is_initializing():
+          self.count.value = self.count.value + 1
+        x = self.linear(x)
+        x = self.dropout(x)
+        return x
+
+    class BridgeMLP(bridge.Module):
+      num_layers: int
+      @bridge.compact
+      def __call__(self, x):
+        for i in range(self.num_layers):
+          layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
+          setattr(self, f'layer{i}', layer)
+          x = layer(x)
+        return x
+
+    model = BridgeMLP(2)
+    x = jax.random.normal(jax.random.key(0), (4, 8))
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(), ['layer0', 'layer1'])
+    self.assertFalse(jnp.array_equal(
+      variables['params']['layer0']['linear']['kernel'],
+      variables['params']['layer1']['linear']['kernel'], ))
+    self.assertEqual(variables['intermediates']['layer1']['count'], 0)
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
+                             mutable=True)
+    self.assertEqual(updates['intermediates']['layer1']['count'], 1)
+
 
 if __name__ == '__main__':
   absltest.main()