
Commit 23a2c18

bridge module with nnx submodules

1 parent 6af8fcb

File tree: 3 files changed (+71 −1 lines)


flax/nnx/bridge/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -23,4 +23,6 @@
 from .variables import with_partitioning as with_partitioning
 from .module import Module as Module
 from .module import Scope as Scope
-from .module import compact as compact
+from .module import compact as compact
+from .module import current_module as current_module
+from .module import wrap_nnx_module as wrap_nnx_module
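
The two new exports are meant to be used together: `wrap_nnx_module` builds a pure NNX submodule under the scope of the enclosing bridge Module, and `current_module` retrieves that enclosing module. A minimal usage sketch, distilled from the test added in this commit (the `Block` module, shapes, and keys are illustrative, not part of the change):

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import bridge

class Block(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # Build a pure NNX submodule under the bridge module's RNG scope.
    layer = bridge.wrap_nnx_module(lambda rngs: nnx.Linear(8, 8, rngs=rngs))
    return layer(x)

x = jnp.ones((4, 8))
block = Block()
variables = block.init(jax.random.key(0), x)  # Linen-style init
y = block.apply(variables, x)                 # Linen-style apply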

flax/nnx/bridge/module.py

Lines changed: 26 additions & 0 deletions

@@ -32,6 +32,7 @@
 from flax.nnx.object import Object
 from flax.nnx import variablelib
 from flax.nnx.bridge import variables as bridge_variables
+from flax.nnx.transforms import transforms
 import jax.numpy as jnp

 A = tp.TypeVar('A')
@@ -117,6 +118,14 @@ def _object_meta_construct(cls, self, *args, **kwargs):
     super()._object_meta_construct(self, *args, **kwargs)


+def current_module() -> tp.Optional[Module]:
+  """A quick util to get the current bridge module."""
+  ctx = MODULE_CONTEXT.module_stack[-1]
+  if ctx is None:
+    return None
+  return ctx.module
+
+
 def _module_meta_call(cls: type[M], *args, **kwargs) -> M:
   # compact behavior
   parent_ctx = MODULE_CONTEXT.module_stack[-1]
@@ -502,3 +511,20 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
     raise errors.ApplyModuleInvalidMethodError(method_or_fn)

   return method_or_fn
+
+
+def wrap_nnx_module(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module]):
+  """Create the module at init time, or make an abstract module and let the parent bind it with its state. Uses the current bridge module scope for RNG generation."""
+  parent = current_module()
+  assert parent is not None, 'wrap_nnx_module only needed inside bridge Module'
+  if parent.is_initializing():
+    return factory(parent.scope.rngs)
+  rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7)  # dummy
+  abs_module = transforms.eval_shape(lambda: factory(rngs))
+  # Make sure the internal RNG state is not abstract; everything else stays abstract.
+  for _, stream in graph.iter_graph(abs_module):
+    if isinstance(stream, rnglib.RngStream):
+      parent_stream = rngs[stream.key.tag]
+      stream.key.value = parent_stream.key.value
+      stream.count.value = parent_stream.count.value
+  return abs_module
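
`wrap_nnx_module` has two code paths. During `init()` the factory runs eagerly, so real parameters are created from the parent's RNGs. During `apply()` only the module structure is needed (the parent binds the real state into it afterwards), so the factory runs under `eval_shape` and every array leaf comes back abstract; the loop then patches concrete key/count values into the RNG streams, since an abstract key cannot generate randomness at call time (the `Rngs(7)` fallback is only a dummy for when the parent scope carries no RNGs). A small sketch of the abstract-module behavior, assuming the public `nnx.eval_shape`/`nnx.iter_graph` wrappers around the same machinery; the names are illustrative:

import jax
from flax import nnx

# Under eval_shape, every array in the module -- parameters *and* RNG
# key/count state -- is an abstract jax.ShapeDtypeStruct, not a real array.
abs_module = nnx.eval_shape(lambda: nnx.Dropout(0.5, rngs=nnx.Rngs(0)))

for _, node in nnx.iter_graph(abs_module):
  if isinstance(node, nnx.RngStream):
    # Abstract before patching; wrap_nnx_module overwrites these with the
    # parent stream's concrete key and count.
    print(node.key.value)  # jax.ShapeDtypeStruct(...)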

tests/nnx/bridge/module_test.py

Lines changed: 42 additions & 0 deletions

@@ -302,6 +302,48 @@ def __call__(self, x):
     y: jax.Array = foo.apply(variables, x)
     self.assertEqual(y.shape, (3, 5))

+  def test_with_pure_nnx(self):
+    class NNXLayer(nnx.Module):
+      def __init__(self, dim, dropout, rngs):
+        self.linear = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
+        self.dropout = nnx.Dropout(dropout, rngs=rngs)
+        self.count = nnx.Intermediate(jnp.array([0.]))
+      def __call__(self, x):
+        # Required check to avoid state update in `init()`. Can this be avoided?
+        if not bridge.current_module().is_initializing():
+          self.count.value = self.count.value + 1
+        x = self.linear(x)
+        x = self.dropout(x)
+        return x
+
+    class BridgeMLP(bridge.Module):
+      num_layers: int
+      @bridge.compact
+      def __call__(self, x):
+        for i in range(self.num_layers):
+          layer = nnx.bridge.wrap_nnx_module(lambda r: NNXLayer(8, 0.3, rngs=r))
+          setattr(self, f'layer{i}', layer)
+          x = layer(x)
+        return x
+
+    model = BridgeMLP(2)
+    x = jax.random.normal(jax.random.key(0), (4, 8))
+    variables = model.init(jax.random.key(1), x)
+    self.assertSameElements(variables['params'].keys(), ['layer0', 'layer1'])
+    self.assertFalse(jnp.array_equal(
+        variables['params']['layer0']['linear']['kernel'],
+        variables['params']['layer1']['linear']['kernel']))
+    self.assertEqual(variables['intermediates']['layer1']['count'], 0)
+
+    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
+    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
+    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
+    assert not jnp.array_equal(y1, y2)
+
+    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
+                             mutable=True)
+    self.assertEqual(updates['intermediates']['layer1']['count'], 1)
+

 if __name__ == '__main__':
   absltest.main()
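
One detail worth noting at the end of the test: the wrapped NNX submodule writes to the `intermediates` collection, so reading the updated `count` back requires a mutable `apply`, exactly as in Linen. A sketch of threading that state across calls, reusing the names from the test above (the merge line follows the usual Linen variable-dict pattern and is an assumption, not part of this commit):

# apply with mutable=True returns (output, mutated_collections).
y, updates = model.apply(variables, x,
                         rngs={'params': k1, 'dropout': k2},
                         mutable=True)

# Merge the mutated collections back in before the next call so the
# per-layer `count` intermediates keep accumulating.
variables = {**variables, **updates}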
