bridge module with nnx submodules
IvyZX committed Feb 27, 2025
1 parent 6af8fcb commit 23a2c18
Showing 3 changed files with 71 additions and 1 deletion.
4 changes: 3 additions & 1 deletion flax/nnx/bridge/__init__.py
@@ -23,4 +23,6 @@
from .variables import with_partitioning as with_partitioning
from .module import Module as Module
from .module import Scope as Scope
from .module import compact as compact
from .module import current_module as current_module
from .module import wrap_nnx_module as wrap_nnx_module
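The two new re-exports put both helpers on the public bridge namespace. A minimal sketch of how downstream code picks them up (import path taken from the package layout above; the variable name is illustrative):

from flax.nnx import bridge

# Returns the innermost bridge Module on the context stack, or None
# when called outside of one (behavior per module.py below).
maybe_parent = bridge.current_module()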
26 changes: 26 additions & 0 deletions flax/nnx/bridge/module.py
@@ -32,6 +32,7 @@
from flax.nnx.object import Object
from flax.nnx import variablelib
from flax.nnx.bridge import variables as bridge_variables
from flax.nnx.transforms import transforms
import jax.numpy as jnp

A = tp.TypeVar('A')
@@ -117,6 +118,14 @@ def _object_meta_construct(cls, self, *args, **kwargs):
    super()._object_meta_construct(self, *args, **kwargs)


def current_module() -> tp.Optional[Module]:
  """Return the current bridge module, or ``None`` if called outside one."""
  ctx = MODULE_CONTEXT.module_stack[-1]
  if ctx is None:
    return None
  return ctx.module


def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
  # compact behavior
  parent_ctx = MODULE_CONTEXT.module_stack[-1]
@@ -502,3 +511,20 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
    raise errors.ApplyModuleInvalidMethodError(method_or_fn)

  return method_or_fn


def wrap_nnx_module(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module]):
  """Create the nnx module at init time, or make an abstract module and let
  the parent bind it with its state. Uses the current bridge module's scope
  for RNG generation."""
  parent = current_module()
  assert parent is not None, 'wrap_nnx_module must be called inside a bridge Module'
  if parent.is_initializing():
    return factory(parent.scope.rngs)
  # Fall back to a dummy seed when the parent scope carries no RNGs.
  rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7)
  abs_module = transforms.eval_shape(lambda: factory(rngs))
  # Make sure the internal RNG state is not abstract; everything else shall be.
  for _, stream in graph.iter_graph(abs_module):
    if isinstance(stream, rnglib.RngStream):
      parent_stream = rngs[stream.key.tag]
      stream.key.value = parent_stream.key.value
      stream.count.value = parent_stream.count.value
  return abs_module
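Taken together, the helper has two code paths: during init() it builds the real nnx module with the parent scope's RNGs; on later applies it builds an abstract module under eval_shape and patches only the RNG streams with concrete keys, leaving everything else for the parent to bind from its variable collections. A minimal usage sketch, modeled on the test below (the Block class and its sizes are illustrative, not part of the commit):

from flax import nnx
from flax.nnx import bridge

class Block(bridge.Module):
  features: int

  @bridge.compact
  def __call__(self, x):
    # The factory receives the parent scope's Rngs; assigning the result
    # to an attribute registers it so its state lands under 'dense'.
    dense = bridge.wrap_nnx_module(
        lambda rngs: nnx.Linear(x.shape[-1], self.features, rngs=rngs))
    self.dense = dense
    return dense(x)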
42 changes: 42 additions & 0 deletions tests/nnx/bridge/module_test.py
@@ -302,6 +302,48 @@ def __call__(self, x):
    y: jax.Array = foo.apply(variables, x)
    self.assertEqual(y.shape, (3, 5))

  def test_with_pure_nnx(self):
    class NNXLayer(nnx.Module):
      def __init__(self, dim, dropout, rngs):
        self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
        self.dropout = nnx.Dropout(dropout, rngs=rngs)
        self.count = nnx.Intermediate(jnp.array([0.]))
      def __call__(self, x):
        # Required check to avoid state update in `init()`. Can this be avoided?
        if not bridge.current_module().is_initializing():
          self.count.value = self.count.value + 1
        x = self.linear(x)
        x = self.dropout(x)
        return x

    class BridgeMLP(bridge.Module):
      num_layers: int

      @bridge.compact
      def __call__(self, x):
        for i in range(self.num_layers):
          layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
          setattr(self, f'layer{i}', layer)
          x = layer(x)
        return x

    model = BridgeMLP(2)
    x = jax.random.normal(jax.random.key(0), (4, 8))
    variables = model.init(jax.random.key(1), x)
    self.assertSameElements(variables['params'].keys(), ['layer0', 'layer1'])
    self.assertFalse(jnp.array_equal(
        variables['params']['layer0']['linear']['kernel'],
        variables['params']['layer1']['linear']['kernel']))
    self.assertEqual(variables['intermediates']['layer1']['count'], 0)

    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
    self.assertFalse(jnp.array_equal(y1, y2))

    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
                             mutable=True)
    self.assertEqual(updates['intermediates']['layer1']['count'], 1)


if __name__ == '__main__':
  absltest.main()
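For reference, the variables pytree the assertions above inspect should look roughly like the following (shapes inferred from NNXLayer(8, 0.3, ...) with use_bias=False; a sketch, not captured output):

{
  'params': {
    'layer0': {'linear': {'kernel': ...}},  # f32[8, 8], no bias term
    'layer1': {'linear': {'kernel': ...}},
  },
  'intermediates': {
    'layer0': {'count': ...},  # starts at jnp.array([0.])
    'layer1': {'count': ...},
  },
}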
