From b17442728c42ce2f7ae65340abae9cd77c00eecd Mon Sep 17 00:00:00 2001
From: IvyZX
Date: Tue, 4 Mar 2025 18:05:28 -0800
Subject: [PATCH] bridge module with nnx submodules

---
 flax/nnx/bridge/__init__.py     |  5 ++-
 flax/nnx/bridge/module.py       | 54 +++++++++++++++++++++++++++++---
 tests/nnx/bridge/module_test.py | 55 +++++++++++++++++++++++++++++++++
 3 files changed, 109 insertions(+), 5 deletions(-)

diff --git a/flax/nnx/bridge/__init__.py b/flax/nnx/bridge/__init__.py
index 4f722b89..60db4e70 100644
--- a/flax/nnx/bridge/__init__.py
+++ b/flax/nnx/bridge/__init__.py
@@ -24,4 +24,7 @@
 from .module import Module as Module
 from .module import Scope as Scope
 from .module import compact as compact
-from flax.nnx.nn import initializers as initializers
\ No newline at end of file
+from .module import current_context as current_context
+from .module import current_module as current_module
+from .module import wrap_nnx_mdl as wrap_nnx_mdl
+from flax.nnx.nn import initializers as initializers
diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py
index 6be25b07..72000e11 100644
--- a/flax/nnx/bridge/module.py
+++ b/flax/nnx/bridge/module.py
@@ -32,6 +32,7 @@
 from flax.nnx.object import Object
 from flax.nnx import variablelib
 from flax.nnx.bridge import variables as bridge_variables
+from flax.nnx.transforms import transforms
 import numpy as np
 
 A = tp.TypeVar('A')
@@ -106,6 +107,25 @@
   return module
 
 
+def current_context() -> tp.Optional[ModuleStackEntry]:
+  return MODULE_CONTEXT.module_stack[-1]
+
+
+def current_module() -> tp.Optional[Module]:
+  """A quick util to get the current bridge module."""
+  ctx = current_context()
+  if ctx is None:
+    return None
+  return ctx.module
+
+
+def _auto_submodule_name(parent_ctx, cls):
+  """Increment the type counter and generate a new submodule name."""
+  type_index = parent_ctx.type_counter[cls]
+  parent_ctx.type_counter[cls] += 1
+  return f'{cls.__name__}_{type_index}'
+
+
 class ModuleMeta(nnx_module.ModuleMeta):
 
   def _object_meta_construct(cls, self, *args, **kwargs):
@@ -134,15 +154,12 @@
         f"'parent' can only be set to None, got {type(parent).__name__}"
       )
   else:
-    type_index = parent_ctx.type_counter[cls]
-    parent_ctx.type_counter[cls] += 1
-
     if 'name' in kwargs:
       name = kwargs.pop('name')
       if not isinstance(name, str):
         raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
     else:
-      name = f'{cls.__name__}_{type_index}'
+      name = _auto_submodule_name(parent_ctx, cls)
     parent = parent_ctx.module
 
   module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs)
@@ -501,3 +518,32 @@
     raise errors.ApplyModuleInvalidMethodError(method_or_fn)
 
   return method_or_fn
+
+
+def wrap_nnx_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
+                 name: str | None = None):
+  """Create the module at init time, or make an abstract module and let the parent bind it with its state.
+  Uses the current bridge module's scope for RNG generation."""
+  parent_ctx, parent = current_context(), current_module()
+  assert parent_ctx is not None and parent_ctx.module is not None, 'wrap_nnx_mdl is only needed inside a bridge Module'
+  parent = parent_ctx.module
+  assert parent.scope is not None
+
+  if parent.is_initializing():
+    module = factory(parent.scope.rngs)
+  else:
+    rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7)  # dummy
+    module = transforms.eval_shape(factory, rngs)
+    # Make sure the internal RNG state is not abstract - other vars shall be
+    if parent.scope.rngs:
+      for _, stream in graph.iter_graph(module):
+        if isinstance(stream, rnglib.RngStream):
+          stream.key.value = rngs[stream.key.tag].key.value
+          stream.count.value = rngs[stream.key.tag].count.value
+
+  # Automatically set the attribute if compact. If setup, the user is
+  # responsible for setting the attribute on the superlayer.
+  if parent_ctx.in_compact:
+    if name is None:
+      name = _auto_submodule_name(parent_ctx, type(module))
+    setattr(parent, name, module)
+  return module
\ No newline at end of file
diff --git a/tests/nnx/bridge/module_test.py b/tests/nnx/bridge/module_test.py
index 3c6670c6..c8c19262 100644
--- a/tests/nnx/bridge/module_test.py
+++ b/tests/nnx/bridge/module_test.py
@@ -302,6 +302,61 @@
     y: jax.Array = foo.apply(variables, x)
     self.assertEqual(y.shape, (3, 5))
 
+  def test_with_pure_nnx(self):
+    class NNXLayer(nnx.Module):
+      def __init__(self, dim, dropout, rngs):
+        self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
+        self.dropout = nnx.Dropout(dropout, rngs=rngs)
+        self.count = nnx.Intermediate(jnp.array([0.]))
+      def __call__(self, x):
+        # Required check to avoid state update in `init()`. Can this be avoided?
+        if not bridge.current_module().is_initializing():
+          self.count.value = self.count.value + 1
+        x = self.linear(x)
+        x = self.dropout(x)
+        return x
+
+    class BridgeMLP(bridge.Module):
+      @bridge.compact
+      def __call__(self, x):
+        x = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
+        x = nnx.bridge.wrap_nnx_mdl(
+          lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x)
+        return x
+
+    model = BridgeMLP()
+    x = jax.random.normal(jax.random.key(0), (4, 8))
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(),
+                            ['NNXLayer_0', 'another'])
+    self.assertFalse(jnp.array_equal(
+      variables['params']['NNXLayer_0']['linear']['kernel'],
+      variables['params']['another']['linear']['kernel'],))
+    self.assertEqual(variables['intermediates']['NNXLayer_0']['count'], 0)
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
+                             mutable=True)
+    self.assertEqual(updates['intermediates']['NNXLayer_0']['count'], 1)
+
+    class BridgeMLPSetup(bridge.Module):
+      def setup(self):
+        self.layer = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))
+      def __call__(self, x):
+        return self.layer(x)
+
+    model = BridgeMLPSetup()
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(), ['layer'])
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
 
 if __name__ == '__main__':
   absltest.main()
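
For reviewers, a minimal usage sketch (not part of the patch): it shows the intended flow of wrap_nnx_mdl at a glance, mirroring the test above but with a single nnx.Linear submodule instead of the test's NNXLayer. The Block class, dimensions, and keys below are illustrative only, and it assumes a Flax checkout with this change applied.

# Illustrative sketch, not part of the patch; assumes this change is applied.
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge


class Block(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # At init time, wrap_nnx_mdl calls the factory with the parent scope's
    # Rngs; at apply time it builds the module abstractly via eval_shape and
    # lets the parent bind the real variables. Under @compact, the submodule
    # attribute is auto-named ('Linear_0' here).
    layer = nnx.bridge.wrap_nnx_mdl(lambda r: nnx.Linear(8, 8, rngs=r))
    return layer(x)


x = jnp.ones((4, 8))
model = Block()
variables = model.init(jax.random.key(0), x)  # params live under 'Linear_0'
y = model.apply(variables, x)                 # y.shape == (4, 8)

Passing name= to wrap_nnx_mdl (as the test's name='another' does) overrides the auto-generated attribute name; under setup, the user assigns the returned module to an attribute themselves and that attribute name is used instead.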