Skip to content

Commit f948154

Browse files
author
Flax Authors
committed
Merge pull request #4161 from IvyZX:bdg-logic
PiperOrigin-RevId: 673097754
2 parents 1b72435 + a9cb80b commit f948154

File tree

7 files changed

+114
-53
lines changed

7 files changed

+114
-53
lines changed

flax/core/meta.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323

2424
import abc
25+
import dataclasses
2526
import functools
2627
from typing import Any, Generic, TypeVar
2728
from collections.abc import Callable
@@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
287288
"""Returns the ``NamedSharding`` for this partitioned value."""
288289
return jax.sharding.NamedSharding(mesh, self.get_partition_spec())
289290

291+
def to_nnx_metadata(self) -> dict[str, Any]:
  """Return a dict of metadata that can translate into an `nnx.Variable`.

  The instance attributes are copied into a new dict in which the ``names``
  entry is exposed under the key ``sharding``. The instance itself is left
  untouched.

  Returns:
    A fresh ``dict`` holding every instance attribute, with ``names``
    renamed to ``sharding``.
  """
  # `vars(self)` is the live instance `__dict__`; copy it first so that the
  # key rename below does not delete the `names` attribute from this object.
  metadata = dict(vars(self))
  metadata['sharding'] = metadata.pop('names')
  return metadata
296+
297+
@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
  """Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`.

  Args:
    metadata: mapping that must contain a ``sharding`` entry; it is passed to
      the constructor as ``names``. Entries that are not dataclass fields of
      ``cls`` are ignored.

  Returns:
    An instance of ``cls`` built from the matching entries.

  Raises:
    KeyError: if ``metadata`` has no ``sharding`` entry.
  """
  # Work on a shallow copy so the caller's dict keeps its `sharding` entry.
  metadata = dict(metadata)
  metadata['names'] = metadata.pop('sharding')
  fields = {x.name for x in dataclasses.fields(cls)}
  return cls(**{k: v for k, v in metadata.items() if k in fields})
303+
290304

291305
def with_partitioning(
292306
fn: Callable[..., Any],

flax/linen/spmd.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any:
328328
else:
329329
return self.value
330330

331+
def to_nnx_metadata(self) -> dict[str, Any]:
  """Return a dict of metadata that can translate into an `nnx.Variable`.

  The instance attributes are copied into a new dict in which ``names`` is
  exposed as ``sharding`` and ``rules`` as ``sharding_rules``. The instance
  itself is left untouched.

  Returns:
    A fresh ``dict`` holding every instance attribute under its NNX key.
  """
  # `vars(self)` is the live instance `__dict__`; copy it first so that the
  # key renames below do not delete `names` / `rules` from this object.
  metadata = dict(vars(self))
  metadata['sharding'] = metadata.pop('names')
  metadata['sharding_rules'] = metadata.pop('rules')
  return metadata
337+
338+
@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
  """Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`.

  Args:
    metadata: mapping that must contain ``sharding`` and ``sharding_rules``
      entries; they are passed to the constructor as ``names`` and ``rules``.
      Entries that are not dataclass fields of ``cls`` are ignored.

  Returns:
    An instance of ``cls`` built from the matching entries.

  Raises:
    KeyError: if ``sharding`` or ``sharding_rules`` is missing.
  """
  # Work on a shallow copy so the caller's dict is not modified.
  metadata = dict(metadata)
  metadata['names'] = metadata.pop('sharding')
  metadata['rules'] = metadata.pop('sharding_rules')
  fields = {x.name for x in dataclasses.fields(cls)}
  return cls(**{k: v for k, v in metadata.items() if k in fields})
345+
331346

332347
def with_logical_partitioning(
333348
fn: Callable[..., Any],

flax/nnx/bridge/variables.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:
5858

5959

6060
def register_variable_name_type_pair(name, typ, overwrite=False):
  """Register a pair of Linen collection name and its NNX type.

  Stores ``typ`` under ``name`` in ``VariableTypeCache``.

  Raises:
    ValueError: if ``name`` is already registered and ``overwrite`` is falsy.
  """
  if not overwrite:
    if name in VariableTypeCache:
      message = (
        f'Name {name} already mapped to type {VariableTypeCache[name]}. '
        'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.'
      )
      raise ValueError(message)
  VariableTypeCache[name] = typ
6666

6767

@@ -85,8 +85,7 @@ def _variable_parents_count(t: type):
8585

8686

8787
class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
88-
"""Default Flax metadata class for `nnx.VariableState`.
89-
"""
88+
"""Default Flax metadata class for `nnx.VariableState`."""
9089

9190
var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
9291
value: Any = struct.field(pytree_node=True)
@@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
110109
def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
  """Convert an NNX variable state into a Linen metadata box.

  If the state's metadata records a ``linen_meta_type``, that type is rebuilt:
  through its ``from_nnx_metadata`` classmethod when it has one, otherwise by
  calling the type with the value and the remaining metadata as keyword
  arguments. Without a recorded Linen type the state is wrapped in `NNXMeta`.
  """
  metadata = vs.get_metadata()
  if 'linen_meta_type' in metadata:
    linen_type = metadata['linen_meta_type']
    if hasattr(linen_type, 'from_nnx_metadata'):
      return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
    # `linen_meta_type` is bookkeeping added on the NNX side to remember the
    # original box type; it is not a constructor argument of that type, so it
    # must not be forwarded.
    ctor_kwargs = {k: v for k, v in metadata.items() if k != 'linen_meta_type'}
    return linen_type(vs.value, **ctor_kwargs)
  return NNXMeta(vs.type, vs.value, metadata)
117117

118118

119119
def get_col_name(keypath: tp.Sequence[Any]) -> str:
@@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:
124124

125125

126126
def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
  """Convert a Linen variable to an NNX variable.

  The NNX variable type is looked up from the collection name ``col``. An
  `NNXMeta` box is unpacked with its stored metadata, any other
  `meta.AxisMetadata` box is converted through its attributes (or its
  ``to_nnx_metadata()`` when available), and a plain value is wrapped as is.
  """
  vtype = variable_type(col)

  # Boxes that originated from NNX carry their own variable type and metadata.
  if isinstance(x, NNXMeta):
    assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
    return x.var_type(x.value, **x.metadata)

  # Plain (unboxed) values need no metadata translation.
  if not isinstance(x, meta.AxisMetadata):
    return vtype(x)

  # Generic Linen metadata box: remember its type so it can be rebuilt later.
  x_metadata = vars(x)
  if hasattr(x, 'to_nnx_metadata'):
    x_metadata = x.to_nnx_metadata()
  assert hasattr(x, 'value')
  return vtype(**x_metadata, linen_meta_type=type(x))

flax/nnx/bridge/wrappers.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
7474
module = fn
7575
assert callable(fn)
7676
else:
77-
if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
77+
# `fn` must be a bound method whose `__self__` is an NNX Module. The
# parentheses are required: without them `not` applies to `hasattr` only, so
# the check never raises ValueError and `fn.__self__` is evaluated exactly
# when the attribute is missing.
if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
  raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
7979
module = fn.__self__
8080
_set_initializing(module, True)
@@ -124,6 +124,7 @@ def __init__(
124124
self.linen_collections: tuple[str, ...] = ()
125125

126126
def lazy_init(self, *args, **kwargs):
  """Shortcut for calling `nnx.bridge.lazy_init()` on this module.

  All positional and keyword arguments are forwarded unchanged, and the
  result of the module-level function is returned.
  """
  return lazy_init(self, *args, **kwargs)
128129

129130
def __call__(
@@ -224,28 +225,6 @@ class ToLinen(linen.Module):
224225
skip_rng: bool = False
225226
metadata_type: tp.Type = bv.NNXMeta
226227

227-
def update_variables(self, module):
228-
"""Store the NNX module's graph def and state inside Linen module variables."""
229-
gdef, state = nnx.split(module)
230-
# Save the graph def.
231-
if self.is_mutable_collection('nnx'):
232-
self.put_variable('nnx', 'graphdef', gdef)
233-
# Sort all the variable types.
234-
types = set(jax.tree.leaves(
235-
jax.tree.map(lambda x: x.type, state,
236-
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
237-
types = bv.sort_variable_types(types)
238-
_, *state_by_types = nnx.split(module, *types)
239-
# Each variable type goes to its own linen collection, and
240-
# each attribute goes to its own linen variable
241-
for typ, state in zip(types, state_by_types):
242-
collection = bv.variable_type_name(typ)
243-
if self.is_mutable_collection(collection):
244-
for k, v in state.raw_mapping.items():
245-
v = jax.tree.map(bv.to_linen_var, v,
246-
is_leaf=lambda x: isinstance(x, nnx.VariableState))
247-
self.put_variable(collection, k, v)
248-
249228
@linen.compact
250229
def __call__(self, *args, **kwargs):
251230
# init codepath
@@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs):
255234
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
256235
module = self.nnx_class(*self.args, **module_kwargs)
257236
# TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
258-
self.update_variables(module)
237+
self._update_variables(module)
259238
return module(*args, **kwargs)
260239

261240
# apply codepath
@@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs):
270249
module = nnx.merge(gdef, nnx_state)
271250
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
272251
out = module(*args, **kwargs)
273-
self.update_variables(module)
252+
self._update_variables(module)
274253
return out
275254

255+
def _update_variables(self, module):
  """Store the NNX module's graph def and state inside Linen module variables."""
  is_var_state = lambda x: isinstance(x, nnx.VariableState)

  graphdef, full_state = nnx.split(module)
  # The graph def lives in the dedicated 'nnx' collection.
  if self.is_mutable_collection('nnx'):
    self.put_variable('nnx', 'graphdef', graphdef)

  # Collect every variable type present in the state, in sorted order.
  var_types = bv.sort_variable_types(set(jax.tree.leaves(
    jax.tree.map(lambda x: x.type, full_state, is_leaf=is_var_state))))
  _, *states_by_type = nnx.split(module, *var_types)

  # One Linen collection per variable type, one Linen variable per attribute.
  for var_type, typed_state in zip(var_types, states_by_type):
    collection = bv.variable_type_name(var_type)
    if not self.is_mutable_collection(collection):
      continue
    for attr_name, subtree in typed_state.raw_mapping.items():
      linen_value = jax.tree.map(bv.to_linen_var, subtree, is_leaf=is_var_state)
      self.put_variable(collection, attr_name, linen_value)
276+
276277

277278
def to_linen(nnx_class: tp.Callable[..., Module], *args,
             name: str | None = None, **kwargs):
  """Shortcut for `nnx.bridge.ToLinen` when none of its default fields need changing.

  The positional and keyword arguments are stored on the wrapper as its
  ``args`` and ``kwargs`` fields; ``name`` becomes the Linen module name.
  """
  wrapper = ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
  return wrapper

flax/nnx/spmd.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,15 @@ def _maybe_replicate(x):
8989
else:
9090
return None
9191

92+
def from_rules(sharding, sharding_rules):
  """Lazily map each name in ``sharding`` through ``sharding_rules``.

  ``sharding_rules`` is an iterable of ``(alias, on_mesh)`` pairs. A name with
  a matching alias is replaced by its ``on_mesh`` value (the last pair wins
  for duplicates); any other name is yielded unchanged.
  """
  alias_to_mesh = {}
  for alias, on_mesh in sharding_rules:
    alias_to_mesh[alias] = on_mesh
  return (alias_to_mesh.get(name, name) for name in sharding)
95+
9296
def f(x):
9397
if isinstance(x, (variables.VariableState, variables.Variable)):
9498
if hasattr(x, 'sharding') and x.sharding:
99+
if hasattr(x, 'sharding_rules') and x.sharding_rules:
100+
return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
95101
return x.replace(PartitionSpec(*x.sharding))
96102
else:
97103
return x.replace(_maybe_replicate(x.value))

tests/nnx/bridge/wrappers_test.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
16+
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'
1517

1618
from absl.testing import absltest
1719
import flax
@@ -24,6 +26,12 @@
2426

2527

2628
class TestCompatibility(absltest.TestCase):
29+
def setUp(self):
  """Create ``self.mesh``: a 2-D device mesh with axes ('in', 'out')."""
  super().setUp()
  n_devices = jax.device_count()
  # First axis takes half of the devices (at least one); the second axis
  # takes the quotient.
  n_rows = max(n_devices // 2, 1)
  n_cols = n_devices // n_rows
  device_grid = np.array(jax.devices()).reshape(n_rows, n_cols)
  self.mesh = jax.sharding.Mesh(devices=device_grid, axis_names=('in', 'out'))
34+
2735
def test_functional(self):
2836
# Functional API for NNX Modules
2937
functional = bridge.functional(nnx.Linear)(32, 64)
@@ -135,21 +143,35 @@ def vmap_fn(inner, x):
135143
def test_linen_to_nnx_metadata(self):
  """Linen partitioning metadata is carried over to the converted NNX variables."""
  kernel_init = nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out'))
  bias_init = nn.with_logical_partitioning(
    nn.initializers.zeros_init(), ('out-alias',), rules=(('out-alias', 'out'),))
  linen_module = nn.Dense(features=64, kernel_init=kernel_init, bias_init=bias_init)
  x = jax.numpy.ones((1, 32))
  linen_vars = linen_module.init(jax.random.key(0), x)

  @nnx.jit
  def create_sharded_nnx_module(x):
    model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x)
    state = nnx.state(model)
    sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
    nnx.update(model, sharded_state)
    return model

  with self.mesh:
    nnx_model = create_sharded_nnx_module(x)

  # Linen metadata boxes are translated into valid nnx.Variable boxes.
  linen_params = linen_vars['params']
  self.assertIsInstance(linen_params['kernel'], nn.Partitioned)
  self.assertIsInstance(linen_params['bias'], nn.LogicallyPartitioned)
  self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable)

  # The kernel keeps its mesh axis names and is laid out accordingly.
  expected_kernel = jax.sharding.NamedSharding(
    self.mesh, jax.sharding.PartitionSpec('in', 'out'))
  assert nnx_model.params['kernel'].sharding == ('in', 'out')
  assert nnx_model.params['kernel'].value.sharding.is_equivalent_to(expected_kernel, ndim=2)

  # The bias keeps its logical name and rules; the rules resolve it to 'out'.
  expected_bias = jax.sharding.NamedSharding(
    self.mesh, jax.sharding.PartitionSpec('out',))
  assert nnx_model.params['bias'].sharding == ('out-alias',)
  assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),)
  assert nnx_model.params['bias'].value.sharding.is_equivalent_to(expected_bias, ndim=1)
153175

154176

155177
##################
@@ -306,7 +328,9 @@ class LinenMiddle(nn.Module):
306328
@nn.compact
def __call__(self, x):
  """Apply the wrapped NNX inner module to ``x`` and add the parameter ``b``."""
  dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot')
  # `rules` is a sequence of (alias, mesh_axis) pairs; the trailing comma keeps
  # it a one-element tuple of pairs instead of a flat tuple of two strings.
  logical_init = nn.with_logical_partitioning(
    nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out'),))
  b = self.param('b', logical_init, (1, self.dout))
  return dot(x) + b
311335

312336
class NNXOuter(nnx.Module):
@@ -335,6 +359,7 @@ def __call__(self, x):
335359
self.assertIsInstance(w, nnx.Param)
336360
np.testing.assert_allclose(model(x), x @ w + b)
337361
assert hasattr(w, 'sharding') and w.sharding == ('in', 'out')
362+
assert hasattr(b, 'sharding') and b.sharding == ('out-alias', )
338363

339364
def test_linen_nnx_linen(self):
340365
# TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without

tests/nnx/transforms_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def f(m: Foo):
323323

324324
def test_apply_shardings(self):
325325
n_devices = max(jax.local_device_count() // 2, 1)
326-
devices = mesh_utils.create_device_mesh((n_devices, n_devices))
326+
devices = mesh_utils.create_device_mesh((n_devices, jax.local_device_count() // n_devices))
327327
mesh = jax.sharding.Mesh(devices, ('a', 'b'))
328328

329329
def sharding(*args):

0 commit comments

Comments
 (0)