Skip to content

Commit

Permalink
Merge pull request #4161 from IvyZX:bdg-logic
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 673097754
  • Loading branch information
Flax Authors committed Sep 10, 2024
2 parents 1b72435 + a9cb80b commit f948154
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 53 deletions.
14 changes: 14 additions & 0 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import abc
import dataclasses
import functools
from typing import Any, Generic, TypeVar
from collections.abc import Callable
Expand Down Expand Up @@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
"""Returns the ``NamedSharding`` for this partitioned value."""
return jax.sharding.NamedSharding(mesh, self.get_partition_spec())

def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
metadata['sharding'] = metadata.pop('names')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
metadata['names'] = metadata.pop('sharding')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})


def with_partitioning(
fn: Callable[..., Any],
Expand Down
15 changes: 15 additions & 0 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any:
else:
return self.value

def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
metadata['sharding'] = metadata.pop('names')
metadata['sharding_rules'] = metadata.pop('rules')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
metadata['names'] = metadata.pop('sharding')
metadata['rules'] = metadata.pop('sharding_rules')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})


def with_logical_partitioning(
fn: Callable[..., Any],
Expand Down
30 changes: 15 additions & 15 deletions flax/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:


def register_variable_name_type_pair(name, typ, overwrite = False):
"""Register a pair of variable type name (like Linen collections) and its NNX type."""
"""Register a pair of Linen collection name and its NNX type."""
if not overwrite and name in VariableTypeCache:
raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
'To overwrite, call with `overwrite=True`.')
'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
VariableTypeCache[name] = typ


Expand All @@ -85,8 +85,7 @@ def _variable_parents_count(t: type):


class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
"""Default Flax metadata class for `nnx.VariableState`.
"""
"""Default Flax metadata class for `nnx.VariableState`."""

var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
value: Any = struct.field(pytree_node=True)
Expand All @@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
metadata = vs.get_metadata()
if 'linen_meta_type' in metadata:
if metadata['linen_meta_type'] is not meta.Partitioned:
raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
return NNXMeta(vs.type, vs.value, vs.get_metadata())
linen_type = metadata['linen_meta_type']
if hasattr(linen_type, 'from_nnx_metadata'):
return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
return linen_type(vs.value, **metadata)
return NNXMeta(vs.type, vs.value, metadata)


def get_col_name(keypath: tp.Sequence[Any]) -> str:
Expand All @@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:


def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
"""Convert a Linen variable to an NNX variable.
This process needs the collection name,
"""
"""Convert a Linen variable to an NNX variable."""
vtype = variable_type(col)
if isinstance(x, NNXMeta):
assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
return x.var_type(x.value, **x.metadata)
if isinstance(x, meta.AxisMetadata):
if isinstance(x, meta.Partitioned):
return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
return vtype(x)
x_metadata = vars(x)
if hasattr(x, 'to_nnx_metadata'):
x_metadata = x.to_nnx_metadata()
assert hasattr(x, 'value')
return vtype(**x_metadata, linen_meta_type=type(x))
return vtype(x)
53 changes: 27 additions & 26 deletions flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
module = fn
assert callable(fn)
else:
if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module):
raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
module = fn.__self__
_set_initializing(module, True)
Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(
self.linen_collections: tuple[str, ...] = ()

def lazy_init(self, *args, **kwargs):
"""A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
return lazy_init(self, *args, **kwargs)

def __call__(
Expand Down Expand Up @@ -224,28 +225,6 @@ class ToLinen(linen.Module):
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

def update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)

@linen.compact
def __call__(self, *args, **kwargs):
# init codepath
Expand All @@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs):
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
# TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
self.update_variables(module)
self._update_variables(module)
return module(*args, **kwargs)

# apply codepath
Expand All @@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs):
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
self._update_variables(module)
return out

def _update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)


def to_linen(nnx_class: tp.Callable[..., Module], *args,
name: str | None = None, **kwargs):
"""Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields."""
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
6 changes: 6 additions & 0 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,15 @@ def _maybe_replicate(x):
else:
return None

def from_rules(sharding, sharding_rules):
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
return (rules[s] if s in rules else s for s in sharding)

def f(x):
if isinstance(x, (variables.VariableState, variables.Variable)):
if hasattr(x, 'sharding') and x.sharding:
if hasattr(x, 'sharding_rules') and x.sharding_rules:
return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
return x.replace(PartitionSpec(*x.sharding))
else:
return x.replace(_maybe_replicate(x.value))
Expand Down
47 changes: 36 additions & 11 deletions tests/nnx/bridge/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=4'

from absl.testing import absltest
import flax
Expand All @@ -24,6 +26,12 @@


class TestCompatibility(absltest.TestCase):
def setUp(self):
super().setUp()
dim1 = max(jax.device_count() // 2, 1)
device_mesh = np.array(jax.devices()).reshape(dim1, jax.device_count() // dim1)
self.mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=('in', 'out'))

def test_functional(self):
# Functional API for NNX Modules
functional = bridge.functional(nnx.Linear)(32, 64)
Expand Down Expand Up @@ -135,21 +143,35 @@ def vmap_fn(inner, x):
def test_linen_to_nnx_metadata(self):
linen_module = nn.Dense(
features=64,
kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')))
kernel_init=nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')),
bias_init=nn.with_logical_partitioning(nn.initializers.zeros_init(), ('out-alias',),
rules=(('out-alias', 'out'),)),
)
x = jax.numpy.ones((1, 32))
linen_vars = linen_module.init(jax.random.key(0), x)
nnx_model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
# nn.Partitioned metadata box is translated into a valid nnx.Variable / VariableState box.

@nnx.jit
def create_sharded_nnx_module(x):
model = bridge.lazy_init(bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)), x)
state = nnx.state(model)
sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))
nnx.update(model, sharded_state)
return model
with self.mesh:
nnx_model = create_sharded_nnx_module(x)

# nn.Partitioned metadata boxes translated into valid nnx.Variable boxes.
self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned)
self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned)
self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable)
np.testing.assert_array_equal(linen_vars['params']['kernel'].value,
nnx_model.params['kernel'].value)
assert nnx_model.params['kernel'].sharding == ('in', 'out')
_, nnx_state = nnx.split(nnx_model)
self.assertIsInstance(nnx_state['params']['kernel'], nnx.VariableState)
np.testing.assert_array_equal(linen_vars['params']['kernel'].value,
nnx_state['params']['kernel'].value)
assert nnx_state['params']['kernel'].sharding == ('in', 'out')
assert nnx_model.params['kernel'].value.sharding.is_equivalent_to(
jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('in', 'out')), ndim=2)

assert nnx_model.params['bias'].sharding == ('out-alias',)
assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),)
assert nnx_model.params['bias'].value.sharding.is_equivalent_to(
jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out',)), ndim=1)


##################
Expand Down Expand Up @@ -306,7 +328,9 @@ class LinenMiddle(nn.Module):
@nn.compact
def __call__(self, x):
dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, self.dropout_rate, name='dot')
b = self.param('b', nn.initializers.lecun_normal(), (1, self.dout))
logical_init = nn.with_logical_partitioning(
nn.initializers.lecun_normal(), ('out-alias',), rules=(('out-alias', 'out')))
b = self.param('b', logical_init, (1, self.dout))
return dot(x) + b

class NNXOuter(nnx.Module):
Expand Down Expand Up @@ -335,6 +359,7 @@ def __call__(self, x):
self.assertIsInstance(w, nnx.Param)
np.testing.assert_allclose(model(x), x @ w + b)
assert hasattr(w, 'sharding') and w.sharding == ('in', 'out')
assert hasattr(b, 'sharding') and b.sharding == ('out-alias', )

def test_linen_nnx_linen(self):
# TODO: add when we can safely `lazy_init` the NNX module inside `ToLinen` without
Expand Down
2 changes: 1 addition & 1 deletion tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def f(m: Foo):

def test_apply_shardings(self):
n_devices = max(jax.local_device_count() // 2, 1)
devices = mesh_utils.create_device_mesh((n_devices, n_devices))
devices = mesh_utils.create_device_mesh((n_devices, jax.local_device_count() // n_devices))
mesh = jax.sharding.Mesh(devices, ('a', 'b'))

def sharding(*args):
Expand Down

0 comments on commit f948154

Please sign in to comment.