Skip to content

Commit

Permalink
bridge module with nnx submodules
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Mar 5, 2025
1 parent 45a8f84 commit 95601b6
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 5 deletions.
5 changes: 4 additions & 1 deletion flax/nnx/bridge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@
from .module import Module as Module
from .module import Scope as Scope
from .module import compact as compact
from flax.nnx.nn import initializers as initializers
from .module import current_context as current_context
from .module import current_module as current_module
from .module import wrap_nnx_mdl as wrap_nnx_mdl
from flax.nnx.nn import initializers as initializers
57 changes: 53 additions & 4 deletions flax/nnx/bridge/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from flax.nnx.object import Object
from flax.nnx import variablelib
from flax.nnx.bridge import variables as bridge_variables
from flax.nnx.transforms.transforms import eval_shape as nnx_eval_shape
from flax.nnx.transforms.compilation import jit as nnx_jit
import numpy as np

A = tp.TypeVar('A')
Expand Down Expand Up @@ -106,6 +108,25 @@ def _bind_module(parent: Module, module: Module) -> Module:
return module


def current_context() -> ModuleStackEntry | None:
  """Return the entry currently on top of the bridge module stack.

  NOTE(review): the ``| None`` in the return type presumably relies on the
  stack holding a ``None`` placeholder when no bridge module is active --
  confirm against the definition of ``MODULE_CONTEXT``.
  """
  stack = MODULE_CONTEXT.module_stack
  return stack[-1]


def current_module() -> Module | None:
  """Return the bridge module of the current context.

  Returns:
    The ``module`` of the entry given by ``current_context()``, or ``None``
    when ``current_context()`` itself returns ``None``.
  """
  entry = current_context()
  return None if entry is None else entry.module


def _auto_submodule_name(parent_ctx, cls):
  """Generate the next automatic submodule name for ``cls`` under a parent.

  Reads the per-type counter stored on ``parent_ctx.type_counter``, bumps it by
  one, and returns ``'<ClassName>_<index>'`` using the counter value from
  before the bump.

  Args:
    parent_ctx: Context entry of the parent; must expose a ``type_counter``
      mapping indexed by class.
    cls: The submodule class being named.

  Returns:
    The generated name, e.g. ``'Dense_0'`` for the first ``Dense``.
  """
  counter = parent_ctx.type_counter
  index = counter[cls]
  counter[cls] = index + 1
  return f'{cls.__name__}_{index}'


class ModuleMeta(nnx_module.ModuleMeta):

def _object_meta_construct(cls, self, *args, **kwargs):
Expand Down Expand Up @@ -134,15 +155,12 @@ def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
f"'parent' can only be set to None, got {type(parent).__name__}"
)
else:
type_index = parent_ctx.type_counter[cls]
parent_ctx.type_counter[cls] += 1

if 'name' in kwargs:
name = kwargs.pop('name')
if not isinstance(name, str):
raise ValueError(f"'name' must be a 'str', got {type(name).__name__}")
else:
name = f'{cls.__name__}_{type_index}'
name = _auto_submodule_name(parent_ctx, cls)
parent = parent_ctx.module

module = nnx_module.ModuleMeta.__call__(cls, *args, **kwargs)
Expand Down Expand Up @@ -501,3 +519,34 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
raise errors.ApplyModuleInvalidMethodError(method_or_fn)

return method_or_fn


def wrap_nnx_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
                 name: str | None = None):
  """Create a pure NNX module inside the current bridge module.

  While the parent bridge module is initializing, ``factory`` is called with
  the parent scope's RNGs and the resulting module is returned as is.
  Otherwise the module is built abstractly with ``nnx_eval_shape`` so the
  parent can bind it with its state; only its ``RngState`` is then replaced by
  concrete values generated from the parent scope's RNGs (when it has any).

  Args:
    factory: Callable that receives an ``Rngs`` and returns the NNX module.
    name: Attribute name under which the module is stored on the parent. Only
      used when the parent is inside a compact method; if ``None`` there, a
      name is auto-generated from the module's type.

  Returns:
    The module produced by ``factory`` (or its abstract counterpart).

  Raises:
    AssertionError: If there is no current bridge module, or the current
      bridge module has no scope.
  """
  parent_ctx = current_context()
  # Raised explicitly instead of via `assert` so the check still runs under
  # `python -O`; the exception type is the same as before.
  if parent_ctx is None or parent_ctx.module is None:
    raise AssertionError('wrap_nnx_mdl only needed inside bridge Module')
  parent = parent_ctx.module
  if parent.scope is None:
    raise AssertionError(
      'wrap_nnx_mdl requires the parent bridge Module to have a scope')

  if parent.is_initializing():
    module = factory(parent.scope.rngs)
  else:
    # Fall back to a dummy Rngs when the scope provides none.
    rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7)  # dummy
    module = nnx_eval_shape(factory, rngs)

    # Make sure the internal rng state is not abstract - other vars shall be
    if parent.scope.rngs:

      @nnx_jit
      def rng_state(rngs):
        return graph.state(factory(rngs), rnglib.RngState)

      graph.update(module, rng_state(parent.scope.rngs))

  # Automatically set the attribute if compact. If setup, user is responsible
  # for setting the attribute of the superlayer.
  if parent_ctx.in_compact:
    if name is None:
      name = _auto_submodule_name(parent_ctx, type(module))
    setattr(parent, name, module)
  return module
75 changes: 75 additions & 0 deletions tests/nnx/bridge/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,81 @@ def __call__(self, x):
y: jax.Array = foo.apply(variables, x)
self.assertEqual(y.shape, (3, 5))

def test_pure_nnx_submodule(self):
  """Pure NNX layers wrapped with `wrap_nnx_mdl` work in bridge modules.

  Covers both a compact and a setup-style bridge module: params are keyed by
  the auto-generated / explicit / attribute name, the two wrapped layers get
  different kernels, outputs differ when only the dropout key differs, and the
  intermediate counter is 0 after `init` and 1 after a mutable `apply`.
  """
  class NNXLayer(nnx.Module):
    def __init__(self, dim, dropout, rngs):
      self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
      self.dropout = nnx.Dropout(dropout, rngs=rngs)
      self.count = nnx.Intermediate(jnp.array([0.]))

    def __call__(self, x):
      # Required check to avoid state update in `init()`. Can this be avoided?
      if not bridge.current_module().is_initializing():
        self.count.value = self.count.value + 1
      x = self.linear(x)
      x = self.dropout(x)
      return x

  class BridgeMLP(bridge.Module):
    @bridge.compact
    def __call__(self, x):
      x = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
      x = nnx.bridge.wrap_nnx_mdl(
        lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x)
      return x

  model = BridgeMLP()
  x = jax.random.normal(jax.random.key(0), (4, 8))
  variables = model.init(jax.random.key(1), x)
  self.assertSameElements(variables['params'].keys(),
                          ['NNXLayer_0', 'another'])
  self.assertFalse(jnp.array_equal(
    variables['params']['NNXLayer_0']['linear']['kernel'],
    variables['params']['another']['linear']['kernel'], ))
  self.assertEqual(variables['intermediates']['NNXLayer_0']['count'], 0)

  # Same params key, different dropout keys: outputs must differ.
  k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
  y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
  y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
  self.assertFalse(jnp.array_equal(y1, y2))

  _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
                           mutable=True)
  self.assertEqual(updates['intermediates']['NNXLayer_0']['count'], 1)

  class BridgeMLPSetup(bridge.Module):
    def setup(self):
      self.layer = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))

    def __call__(self, x):
      return self.layer(x)

  model = BridgeMLPSetup()
  variables = model.init(jax.random.key(1), x)
  self.assertSameElements(variables['params'].keys(), ['layer'])
  y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
  y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
  self.assertFalse(jnp.array_equal(y1, y2))

def test_pure_nnx_submodule_modified_rng(self):
  """Smoke test: a wrapped NNX module that builds its own `Rngs` runs.

  `FooStack` creates an `nnx.Rngs` from split keys in its constructor and
  consumes it under `nnx.vmap` in `__call__`. The test only checks that
  `init` and `apply` complete; it makes no assertions on the output.
  """
  class FooStack(nnx.Module):
    def __init__(self, in_dim, key):
      self.rngs = nnx.Rngs(jax.random.split(key, in_dim))

    def __call__(self, x):
      @nnx.vmap
      def generate_weights(stream):
        return jax.random.normal(stream.default(), (2,))

      return x @ generate_weights(self.rngs)

  class BridgeFoo(bridge.Module):
    @bridge.compact
    def __call__(self, x):
      return nnx.bridge.wrap_nnx_mdl(lambda r: FooStack(4, r.default()))(x)

  inputs = jnp.ones((1, 4))
  model = BridgeFoo()
  variables = model.init(jax.random.key(1), inputs)
  # Result is intentionally unused: the test passes if this does not raise.
  model.apply(variables, inputs, rngs=jax.random.key(1))

if __name__ == '__main__':
absltest.main()
Expand Down

0 comments on commit 95601b6

Please sign in to comment.