Commit

bridge module with linen module
IvyZX committed Mar 6, 2025
1 parent 8254dd0 commit 300494b
Showing 6 changed files with 130 additions and 29 deletions.
3 changes: 2 additions & 1 deletion flax/nnx/bridge/__init__.py
@@ -26,5 +26,6 @@
from .module import compact as compact
from .module import current_context as current_context
from .module import current_module as current_module
-from .interop import wrap_nnx_mdl as wrap_nnx_mdl
+from .interop import nnx_in_bridge_mdl as nnx_in_bridge_mdl
+from .interop import linen_in_bridge_mdl as linen_in_bridge_mdl
from flax.nnx.nn import initializers as initializers
45 changes: 41 additions & 4 deletions flax/nnx/bridge/interop.py
@@ -14,18 +14,32 @@

import typing as tp

from flax.linen import module as nn_module
from flax.nnx import graph, rnglib
from flax.nnx.bridge import wrappers
from flax.nnx.bridge import module as bdg_module
import flax.nnx.module as nnx_module
from flax.nnx.transforms.transforms import eval_shape as nnx_eval_shape
from flax.nnx.transforms.compilation import jit as nnx_jit


-def wrap_nnx_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
-                 name: str | None = None):
-  """Create module at init time, or make abstract module and let parent bind it with its state. Use current bridge module scope for RNG generation."""
+def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
+                      name: str | None = None) -> nnx_module.Module:
+  """Make a pure NNX module a submodule of a bridge module.
+
+  Create the module at init time, or make an abstract module and let the
+  parent bind it with its state. Use the current bridge module scope for
+  RNG generation.
+
+  Args:
+    factory: a function that takes an `nnx.Rngs` arg and returns an NNX module.
+    name: the name of the module. Only used inside `bridge.compact` functions;
+      in `setup()` the user should assign the result to an attribute explicitly.
+
+  Returns:
+    A submodule (`nnx.Module`) of the bridge module.
+  """
  parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
-  assert parent_ctx is not None and parent is not None, 'wrap_nnx_mdl only needed inside bridge Module'
+  assert parent_ctx is not None and parent is not None, 'nnx_in_bridge_mdl() only needed inside bridge Module'
  parent = parent_ctx.module
  assert parent.scope is not None

@@ -50,3 +64,26 @@ def rng_state(rngs):
name = bdg_module._auto_submodule_name(parent_ctx, type(module))
setattr(parent, name, module)
return module


def linen_in_bridge_mdl(linen_module: nn_module.Module,
                        name: str | None = None) -> nnx_module.Module:
  """Make a Linen module a submodule of a bridge module using `wrappers.ToNNX()`.

  Args:
    linen_module: the underlying Linen module instance.
    name: the name of the module. Only used inside `bridge.compact` functions;
      in `setup()` the user should assign the result to an attribute explicitly.

  Returns:
    A submodule (`nnx.Module`) of the bridge module.
  """
  parent_ctx, parent = bdg_module.current_context(), bdg_module.current_module()
  assert parent_ctx is not None and parent is not None, 'linen_in_bridge_mdl() only needed inside bridge Module'
  assert parent.scope is not None
  module = wrappers.ToNNX(linen_module, parent.scope.rngs)
  wrappers._set_initializing(module, parent.is_initializing())
  if parent_ctx.in_compact:
    if name is None:
      name = bdg_module._auto_submodule_name(parent_ctx, type(linen_module))
    setattr(parent, name, module)
  return module
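
For context, a minimal usage sketch of the two helpers above, mirroring the tests added in this commit. `MixedMLP` is a hypothetical name; `nnx.Linear` and `nn.Dense` are the stock NNX/Linen layers; the bridge API is assumed to be the one this commit exports.

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class MixedMLP(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # Pure NNX submodule: the factory receives an nnx.Rngs built from
    # the bridge module's current scope.
    x = bridge.nnx_in_bridge_mdl(lambda r: nnx.Linear(4, 8, rngs=r))(x)
    # Linen submodule, wrapped with wrappers.ToNNX() under the hood.
    x = bridge.linen_in_bridge_mdl(nn.Dense(4), name='proj')(x)
    return x

model = MixedMLP()
x = jnp.ones((1, 4))
variables = model.init(jax.random.key(0), x)
y = model.apply(variables, x, rngs=jax.random.key(1))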
11 changes: 7 additions & 4 deletions flax/nnx/bridge/module.py
@@ -25,6 +25,7 @@

from flax import errors
from flax.core import meta
from flax.core.scope import CollectionFilter
from flax.core.frozen_dict import FrozenDict
from flax.nnx import graph, rnglib, statelib, traversals
from flax.nnx import variablelib
@@ -63,11 +64,12 @@ class ModuleState(statelib.State):


class Scope(Object):
-  def __init__(self, rngs: rnglib.Rngs):
+  def __init__(self, rngs: rnglib.Rngs, mutable: CollectionFilter):
    self.rngs = rngs
+    self.mutable = mutable

  def copy(self):
-    return Scope(self.rngs)
+    return Scope(self.rngs, self.mutable)


class _HasSetup(tp.Protocol):
@@ -365,7 +367,7 @@ def apply(
*args,
rngs: int | jax.Array | dict[str, jax.Array] | rnglib.Rngs | None = None,
method: tp.Callable[..., tp.Any] | str = '__call__',
-mutable: tp.Any = False,
+mutable: CollectionFilter = False,
_initialize: bool = False,
**kwargs,
) -> tp.Any:
@@ -422,7 +424,7 @@ def to_variable(value):
if isinstance(value, Object):
value._object__state._initializing = _initialize
if isinstance(value, Module):
-value.scope = Scope(rngs)
+value.scope = Scope(rngs, mutable)
_maybe_call_setup(value)

MODULE_CONTEXT.module_stack.append(
@@ -517,3 +519,4 @@ def _get_unbound_fn(method_or_fn: tp.Callable) -> tp.Callable:
raise errors.ApplyModuleInvalidMethodError(method_or_fn)

return method_or_fn
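
With `mutable` now stored on `Scope`, one top-level `apply(..., mutable=...)` declaration governs every wrapped Linen submodule in the tree. A short sketch, reusing the hypothetical `MixedMLP` from the sketch above; per Linen semantics, a non-False `mutable` makes `apply` return the output together with the mutated collections:

out, updates = model.apply(variables, x, rngs=jax.random.key(1), mutable=True)
# `updates` holds whatever collections (e.g. 'intermediates') the wrapped
# Linen submodules wrote during this call; ToNNX reads Scope.mutable when
# it builds the inner Linen apply() call.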

27 changes: 15 additions & 12 deletions flax/nnx/bridge/wrappers.py
@@ -23,6 +23,7 @@
from flax.nnx import graph
from flax.nnx import variablelib
from flax.nnx.bridge import variables as bv
from flax.nnx.bridge import module as bdg_module
from flax.nnx.module import Module
from flax.nnx.object import Object
from flax.nnx.rnglib import Rngs
@@ -124,7 +125,6 @@ def __init__(
):
self.module = module
self.rngs = rngs
-self.linen_attributes: tuple[str, ...] = ()

def lazy_init(self, *args, **kwargs):
"""A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
@@ -140,46 +140,49 @@ def __call__(
rngs = self.rngs
if self._object__state.initializing:
_rngs = (
-{name: stream.key.raw_value for name, stream in rngs.items()}
-if rngs
-else {}
+{name: stream() for name, stream in rngs.items()} if rngs else {}
)
# rename default to params
if 'params' not in _rngs and 'default' in _rngs:
_rngs['params'] = _rngs.pop('default')
out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)

nnx_attrs = bv.linen_vars_to_nnx_attrs(variables)
-linen_attributes = set(self.linen_attributes)
for attr_name, value in nnx_attrs.items():
  setattr(self, attr_name, value)
-  linen_attributes.add(attr_name)
-self.linen_attributes = tuple(linen_attributes)  # make it hashable

else:
-nnx_attrs = {name: getattr(self, name) for name in self.linen_attributes}
+nnx_attrs = {k: v for k, v in vars(self).items()
+             if k not in ['module', 'rngs', '_object__state']}
variables = bv.nnx_attrs_to_linen_vars(nnx_attrs)

_rngs = (
{name: stream() for name, stream in rngs.items()} if rngs else {}
)

# Get `mutable` from top level bridge.Module context if any
if (m := bdg_module.current_module()) is not None:
  assert m.scope is not None
  mutable = m.scope.mutable
  if 'mutable' in kwargs and kwargs['mutable'] != mutable:
    raise ValueError(
      f"Multiple `mutable` arguments detected: {mutable} at top level vs "
      f"{kwargs['mutable']} in ToNNX() call")
  kwargs['mutable'] = mutable

out = self.module.apply(variables, *args, rngs=_rngs, method=method, **kwargs)

# Split out the updates if `mutable` is passed into the Flax module
if kwargs.get('mutable', False) != False:
out, updates = out
nnx_attrs = bv.linen_vars_to_nnx_attrs(updates)
-linen_attributes = set(self.linen_attributes)
for attr_name, value in nnx_attrs.items():
-  linen_attributes.add(attr_name)
  if hasattr(self, attr_name) and isinstance(value, dict):
    original_tree = getattr(self, attr_name)
    setattr(self, attr_name, original_tree | value)
  else:
    setattr(self, attr_name, value)
-self.linen_attributes = tuple(linen_attributes)  # make it hashable

return out


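
To illustrate the new `mutable` plumbing in ToNNX, a hedged sketch under the same assumptions as above (`SowCount` is a hypothetical Linen layer; the pattern mirrors test_linen_submodule below):

import jax
import jax.numpy as jnp
import flax.linen as nn
from flax import nnx
from flax.nnx import bridge

class SowCount(nn.Module):
  @nn.compact
  def __call__(self, x):
    if not self.is_initializing():
      # Writes to the 'intermediates' collection, which is only recorded
      # when the call marks it mutable.
      self.sow('intermediates', 'count', 1,
               init_fn=lambda: 0, reduce_fn=lambda a, b: a + b)
    return nn.Dense(4)(x)

class Outer(bridge.Module):
  @bridge.compact
  def __call__(self, x):
    # No `mutable` here: ToNNX picks it up from the top-level apply()
    # through the bridge module's Scope. Passing a conflicting `mutable`
    # directly to this call would now raise the ValueError added above.
    return bridge.linen_in_bridge_mdl(SowCount())(x)

model = Outer()
x = jnp.ones((1, 4))
variables = model.init(jax.random.key(0), x)
y, updates = model.apply(variables, x, mutable=True)
# Expected: updates['intermediates']['SowCount_0']['count'] == 1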
63 changes: 59 additions & 4 deletions tests/nnx/bridge/module_test.py
@@ -319,8 +319,8 @@ def __call__(self, x):
class BridgeMLP(bridge.Module):
@bridge.compact
def __call__(self, x):
-x = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
-x = nnx.bridge.wrap_nnx_mdl(
+x = bridge.nnx_in_bridge_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))(x)
+x = bridge.nnx_in_bridge_mdl(
    lambda r: NNXLayer(8, 0.3, rngs=r), name='another')(x)
return x

@@ -345,7 +345,8 @@ def __call__(self, x):

class BridgeMLPSetup(bridge.Module):
def setup(self):
-self.layer = nnx.bridge.wrap_nnx_mdl(lambda r: NNXLayer(8, 0.3, rngs=r))
+self.layer = bridge.nnx_in_bridge_mdl(
+    lambda r: NNXLayer(8, 0.3, rngs=r))
def __call__(self, x):
return self.layer(x)

@@ -371,13 +372,67 @@ def generate_weights(r):
class BridgeFoo(bridge.Module):
@bridge.compact
def __call__(self, x):
-x = nnx.bridge.wrap_nnx_mdl(lambda r: FooStack(4, r.default()))(x)
+x = bridge.nnx_in_bridge_mdl(lambda r: FooStack(4, r.default()))(x)
return x

model = BridgeFoo()
v = model.init(jax.random.key(1), jnp.ones((1, 4)))
y = model.apply(v, jnp.ones((1, 4)), rngs=jax.random.key(1))

  def test_linen_submodule(self):
    class LinenLayer(nn.Module):
      dim: int
      dropout_rate: float

      def setup(self):
        self.linear = nn.Dense(self.dim, use_bias=False)
        self.dropout = nn.Dropout(self.dropout_rate, deterministic=False)

      def __call__(self, x):
        if not self.is_initializing():
          self.sow('intermediates', 'count', 1,
                   init_fn=lambda: 0, reduce_fn=lambda a, b: a + b)
        x = self.linear(x)
        x = self.dropout(x)
        return x

    class BridgeMLP(bridge.Module):
      @bridge.compact
      def __call__(self, x):
        x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))(x)
        x = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3), name='another')(x)
        return x

    model = BridgeMLP()
    x = jax.random.normal(jax.random.key(0), (4, 8))
    variables = model.init(jax.random.key(1), x)
    self.assertFalse(jnp.array_equal(
      variables['params']['LinenLayer_0']['linear']['kernel'],
      variables['params']['another']['linear']['kernel'],
    ))

    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
    assert not jnp.array_equal(y1, y2)

    _, updates = model.apply(variables, x, rngs={'params': k1, 'dropout': k3},
                             mutable=True)
    self.assertEqual(updates['intermediates']['LinenLayer_0']['count'], 1)

    class BridgeMLPSetup(bridge.Module):
      def setup(self):
        self.layer = bridge.linen_in_bridge_mdl(LinenLayer(8, 0.3))

      def __call__(self, x):
        return self.layer(x)

    model = BridgeMLPSetup()
    variables = model.init(jax.random.key(1), x)
    self.assertSameElements(variables['params'].keys(), ['layer'])
    y1 = model.apply(variables, x, rngs={'params': k1, 'dropout': k2})
    y2 = model.apply(variables, x, rngs={'params': k1, 'dropout': k3})
    assert not jnp.array_equal(y1, y2)



if __name__ == '__main__':
absltest.main()

10 changes: 6 additions & 4 deletions tests/nnx/bridge/wrappers_test.py
@@ -54,9 +54,10 @@ def test_linen_to_nnx(self):
assert y.shape == (1, 64)
self.assertIsInstance(model.kernel, nnx.Variable)
# NNX automatically adds metadata box regardless of original Linen module.
-linen_vars = linen_module.init(jax.random.key(0), x)
-np.testing.assert_array_equal(linen_vars['params']['kernel'],
-                              model.kernel.value)
+linen_vars = {'params': {'kernel': model.kernel.value,
+                         'bias': model.bias.value}}
+linen_y = linen_module.apply(linen_vars, x)
+np.testing.assert_array_equal(y, linen_y)

def test_linen_to_nnx_submodule(self):
class NNXOuter(nnx.Module):
@@ -468,10 +469,11 @@ def __call__(self, x):
# Test the RNG
model = bridge.lazy_init(NNXOuter(dout=3, dropout_rate=0.5,
rngs=nnx.Rngs(default=1, dropout=2)), x)
nnx.reseed(model, dropout=2)
y1, y2 = model(x), model(x)
# The dropout key of lowest NNX level still changes over stateful calls
assert not jnp.allclose(y1, y2)
-# Reseed resets the RNG key back
+# Another reseed resets the RNG key back
nnx.reseed(model, dropout=2)
np.testing.assert_array_equal(y1, model(x))
