Merge pull request #4185 from google:nnx-transform-metadata-issue
PiperOrigin-RevId: 673008669
IvyZX committed Sep 10, 2024
2 parents 671130b + 93c0627 commit 2f6ff41
Showing 5 changed files with 97 additions and 54 deletions.
19 changes: 19 additions & 0 deletions flax/errors.py
@@ -64,6 +64,25 @@ def __reduce__(self):
    return (FlaxError, (str(self),))


#################################################
#                  NNX errors                   #
#################################################


class TraceContextError(FlaxError):
  pass


class AxisNameMissingError(FlaxError):
  def __init__(self, x_sharding):
    super().__init__(
      'You are trying to modify param dimension via transforms like `nnx.vmap` '
      f'or `nnx.scan`, while the param is partition-annotated as: {x_sharding}. '
      'You need to provide the axis name of the transform via the extra '
      'argument: transform_metadata={nnx.PARTITION_NAME: "your_axis_name"}'
    )


#################################################
#              lazy_init.py errors              #
#################################################
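For illustration, a minimal sketch of the failure this error reports, assuming the same vmap setup as the new test at the bottom of this diff but without `transform_metadata`:

```python
from flax import nnx

# Hedged sketch: the kernel is partition-annotated as ('a', 'b'), but no
# transform_metadata={nnx.spmd.PARTITION_NAME: ...} is supplied, so adding
# the vmap axis to the sharding annotation raises AxisNameMissingError.
@nnx.vmap(in_axes=(None,), out_axes=0, axis_size=5)
def create_block(rngs: nnx.Rngs):
  return nnx.Linear(
    16, 32, rngs=rngs,
    kernel_init=nnx.with_partitioning(
      nnx.initializers.lecun_normal(), ('a', 'b')
    ),
  )

create_block(nnx.Rngs(0))  # raises flax.errors.AxisNameMissingError
```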
16 changes: 9 additions & 7 deletions flax/nnx/spmd.py
@@ -26,6 +26,7 @@
  PartitionSpecPytree,  # pylint: disable=invalid-name
  Sharding,
)
+from flax import errors

A = tp.TypeVar('A')
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
@@ -38,13 +39,15 @@ def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:
  def _add_axis(x: tp.Any):
    if isinstance(x, variables.VariableState):
      if hasattr(x, 'sharding') and x.sharding is not None:
+        if axis_name is None:
+          raise errors.AxisNameMissingError(x.sharding)
        sharding: list[str | None] = list(x.sharding)
        while len(sharding) < index:
          sharding.append(None)
        sharding.insert(index, axis_name)
        x.sharding = tuple(sharding)  # type: ignore

-      x.add_axis(axis_name, index)
+      x.add_axis(index, axis_name)
    return x

  return jax.tree.map(
@@ -58,10 +61,12 @@ def remove_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:
  def _remove_axis(x: tp.Any):
    if isinstance(x, variables.VariableState):
      if hasattr(x, 'sharding') and x.sharding is not None:
+        if axis_name is None:
+          raise errors.AxisNameMissingError(x.sharding)
        sharding = list(x.sharding)
        assert sharding.pop(index) == axis_name
        x.sharding = tuple(sharding)
-      x.remove_axis(axis_name, index)
+      x.remove_axis(index, axis_name)
    return x

  return jax.tree.map(
@@ -71,12 +76,9 @@ def _remove_axis(x: tp.Any):
  )


-def _get_partition_name(params: tp.Mapping[tp.Any, tp.Any]) -> str:
+def _get_partition_name(params: tp.Mapping[tp.Any, tp.Any]) -> str | None:
  if PARTITION_NAME not in params:
-    raise ValueError(
-      'Trying to transform a Partitioned variable but "partition_name" '
-      f'is not specified in scan_metadata: {params}'
-    )
+    return None
  return params[PARTITION_NAME]


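As a worked example of the annotation bookkeeping in `_add_axis` above, independent of flax (pure list manipulation):

```python
# Inserting axis name 'c' at index 0 into an existing annotation ('a', 'b')
# yields ('c', 'a', 'b'); the while-loop pads with None when the existing
# annotation is shorter than the insertion index.
sharding: list[str | None] = list(('a', 'b'))
index, axis_name = 0, 'c'
while len(sharding) < index:
  sharding.append(None)
sharding.insert(index, axis_name)
assert tuple(sharding) == ('c', 'a', 'b')
```

`remove_axis` is the inverse: it pops the entry at `index` and asserts that it equals the transform's axis name.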
75 changes: 39 additions & 36 deletions flax/nnx/transforms/iteration.py
@@ -107,17 +107,25 @@ def _update_variable_sharding_metadata(
):
  def _update_axes_fn(tree_node):
    if isinstance(tree_node, extract.TreeNode) and isinstance(
-      tree_node.metatata, StateAxes
+      tree_node.metatata, (StateAxes, int)
    ):
-      graphdef_states_out: list[extract.GraphDefState] = []
-      for graphdef_state, axis in zip(
-        tree_node.graphdef_states, tree_node.metatata.axes
-      ):
-        assert isinstance(graphdef_state, extract.GraphDefState)
-        if isinstance(axis, int):
-          graphdef_state = axis_fn(graphdef_state, axis, transform_metadata)
-        graphdef_states_out.append(graphdef_state)
-      return tree_node.replace(graphdef_states=tuple(graphdef_states_out))
+      if isinstance(tree_node.metatata, int):
+        graph_def_state = tree_node.graphdef_states[0]
+        assert isinstance(graph_def_state, extract.GraphDefState)
+        graphdef_state = axis_fn(
+          graph_def_state, tree_node.metatata, transform_metadata
+        )
+        return tree_node.replace(graphdef_states=(graphdef_state,))
+      else:
+        graphdef_states_out: list[extract.GraphDefState] = []
+        for graphdef_state, axis in zip(
+          tree_node.graphdef_states, tree_node.metatata.axes
+        ):
+          assert isinstance(graphdef_state, extract.GraphDefState)
+          if isinstance(axis, int):
+            graphdef_state = axis_fn(graphdef_state, axis, transform_metadata)
+          graphdef_states_out.append(graphdef_state)
+        return tree_node.replace(graphdef_states=tuple(graphdef_states_out))
    return tree_node

  return jax.tree.map(
@@ -130,7 +138,7 @@ def _vmap_split_fn(ctx: graph.SplitContext, path, prefix, x):
    return extract.TreeNode.from_split(
      *ctx.split(x, *prefix.filters), metadata=prefix
    )
-  return extract.TreeNode.from_split(*ctx.split(x))
+  return extract.TreeNode.from_split(*ctx.split(x), metadata=prefix)


@dataclasses.dataclass(eq=False)
@@ -144,10 +152,10 @@ def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
-    if spmd.PARTITION_NAME in self.transform_metadata:
-      pure_args = _update_variable_sharding_metadata(
-        pure_args, self.transform_metadata, spmd.remove_axis
-      )
+    print(self.transform_metadata)
+    pure_args = _update_variable_sharding_metadata(
+      pure_args, self.transform_metadata, spmd.remove_axis
+    )
    args = extract.from_tree(pure_args, ctxtag='vmap')

    out = self.f(*args)
@@ -159,10 +167,9 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]):
      split_fn=_vmap_split_fn,
      ctxtag='vmap',
    )
-    if spmd.PARTITION_NAME in self.transform_metadata:
-      pure_args_out, pure_out = _update_variable_sharding_metadata(
-        (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
-      )
+    pure_args_out, pure_out = _update_variable_sharding_metadata(
+      (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
+    )
    return pure_args_out, pure_out


@@ -348,10 +355,9 @@ def __post_init__(self):
    functools.update_wrapper(self, self.f)

  def __call__(self, *pure_args: tuple[tp.Any, ...]):
-    if spmd.PARTITION_NAME in self.transform_metadata:
-      pure_args = _update_variable_sharding_metadata(
-        pure_args, self.transform_metadata, spmd.remove_axis
-      )
+    pure_args = _update_variable_sharding_metadata(
+      pure_args, self.transform_metadata, spmd.remove_axis
+    )
    args = extract.from_tree(pure_args, ctxtag='pmap')

    out = self.f(*args)
@@ -363,10 +369,9 @@ def __call__(self, *pure_args: tuple[tp.Any, ...]):
      split_fn=_vmap_split_fn,
      ctxtag='pmap',
    )
-    if spmd.PARTITION_NAME in self.transform_metadata:
-      pure_args_out, pure_out = _update_variable_sharding_metadata(
-        (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
-      )
+    pure_args_out, pure_out = _update_variable_sharding_metadata(
+      (pure_args_out, pure_out), self.transform_metadata, spmd.add_axis
+    )
    return pure_args_out, pure_out


@@ -986,10 +991,9 @@ def __call__(
      assert self.input_carry_argnum is None
      assert pure_carry_arg is None

-    if spmd.PARTITION_NAME in self.transform_metadata:
-      pure_args = _update_variable_sharding_metadata(
-        pure_args, self.transform_metadata, spmd.remove_axis
-      )
+    pure_args = _update_variable_sharding_metadata(
+      pure_args, self.transform_metadata, spmd.remove_axis
+    )

    args: tuple = extract.from_tree(
      pure_args,
@@ -1057,12 +1061,11 @@ def __call__(
      map_non_graph_nodes=True,
      ctxtag='scan',
    )
-    if spmd.PARTITION_NAME in self.transform_metadata:
-      pure_args_out, pure_out = _update_variable_sharding_metadata(
-        (pure_args_out, pure_out),
-        self.transform_metadata,
-        spmd.add_axis,
-      )
+    pure_args_out, pure_out = _update_variable_sharding_metadata(
+      (pure_args_out, pure_out),
+      self.transform_metadata,
+      spmd.add_axis,
+    )

    # extract the pure carry from the pure args
    if self.input_carry_argnum == 'all':
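For context, a hedged usage sketch of the `StateAxes` branch that the new `int` branch mirrors (the `int` case simply applies one axis to the whole node); the axis name 'c' and the filter mapping here are illustrative assumptions:

```python
import jax
from flax import nnx

# Params are vmapped over axis 0, all other state is broadcast; the
# PARTITION_NAME entry lets add_axis/remove_axis keep each param's
# sharding annotation in sync with the mapped axis.
state_axes = nnx.StateAxes({nnx.Param: 0, ...: None})

@nnx.vmap(
  in_axes=(state_axes, 0),
  out_axes=0,
  transform_metadata={nnx.spmd.PARTITION_NAME: 'c'},
)
def forward(model: nnx.Linear, x: jax.Array):
  return model(x)
```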
20 changes: 9 additions & 11 deletions flax/nnx/variables.py
@@ -870,17 +870,15 @@ def get_metadata(self) -> dict[str, tp.Any]:
    del metadata['value']
    return metadata

-  def add_axis(self, axis_name: AxisName, axis_index: AxisIndex):
-    if not hasattr(self, 'add_axis_hooks'):
-      raise ValueError(f'No add_axis_hooks found for VariableState: {self}')
-    for hook in self.add_axis_hooks:
-      hook(self, axis_name, axis_index)
-
-  def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex):
-    if not hasattr(self, 'remove_axis_hooks'):
-      raise ValueError(f'No remove_axis_hooks found for VariableState: {self}')
-    for hook in self.remove_axis_hooks:
-      hook(self, axis_name, axis_index)
+  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None = None):
+    if hasattr(self, 'add_axis_hooks'):
+      for hook in self.add_axis_hooks:
+        hook(self, axis_name, axis_index)
+
+  def remove_axis(self, axis_index: AxisIndex, axis_name: AxisName | None = None):
+    if hasattr(self, 'remove_axis_hooks'):
+      for hook in self.remove_axis_hooks:
+        hook(self, axis_name, axis_index)


def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool):
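A self-contained sketch of the new, more forgiving hook dispatch; `StateSketch` is a hypothetical stand-in for `VariableState`, and hooks still receive `(state, axis_name, axis_index)` in that order:

```python
import typing as tp

AxisIndex = int
AxisName = str

class StateSketch:
  def __init__(self, add_axis_hooks: tp.Sequence[tp.Callable] = ()):
    if add_axis_hooks:
      self.add_axis_hooks = tuple(add_axis_hooks)

  def add_axis(self, axis_index: AxisIndex, axis_name: AxisName | None = None):
    # Missing hooks are now a no-op instead of raising ValueError.
    if hasattr(self, 'add_axis_hooks'):
      for hook in self.add_axis_hooks:
        hook(self, axis_name, axis_index)

StateSketch().add_axis(0)              # no hooks registered: does nothing
StateSketch([print]).add_axis(0, 'c')  # calls print(state, 'c', 0)
```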
21 changes: 21 additions & 0 deletions tests/nnx/transforms_test.py
@@ -2235,6 +2235,27 @@ def forward(model, x):

    self.assertEqual(y.shape, (5, 4, 3))

+  def test_metadata(self):
+    @nnx.vmap(
+      in_axes=(None,),
+      out_axes=0,
+      axis_size=5,
+      transform_metadata={nnx.spmd.PARTITION_NAME: 'c'},
+    )
+    def create_block(rngs: nnx.Rngs):
+      return nnx.Linear(
+        16,
+        32,
+        rngs=rngs,
+        kernel_init=nnx.with_partitioning(
+          nnx.initializers.lecun_normal(), ('a', 'b')
+        ),
+      )
+
+    m = create_block(nnx.Rngs(0))
+    self.assertEqual(m.kernel.value.shape, (5, 16, 32))
+    self.assertEqual(m.kernel.sharding, ('c', 'a', 'b'))


class TestPmap(absltest.TestCase):

