Merge pull request #4561 from IvyZX:state-dep-warn

PiperOrigin-RevId: 729703734
Flax Authors committed Feb 22, 2025
2 parents 1ec5ef2 + ed73ba4 commit dd6e595
Showing 32 changed files with 579 additions and 420 deletions.
32 changes: 16 additions & 16 deletions docs_nnx/guides/checkpointing.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions docs_nnx/guides/checkpointing.md
@@ -133,15 +133,15 @@ When interacting with checkpoint libraries (like Orbax), you may prefer to work

```{code-cell} ipython3
# Save as pure dict
pure_dict_state = state.to_pure_dict()
pure_dict_state = nnx.to_pure_dict(state)
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
# Restore as a pure dictionary.
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
```
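The same pure-dict round trip, condensed into a self-contained sketch (a plain `nnx.Linear` stands in for the guide's `TwoLayerMLP`, and the Orbax checkpointer is skipped; the two helpers are the new top-level functions this commit exports):

```python
from flax import nnx
import jax.numpy as jnp

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

# New functional form: State -> plain nested dict of arrays.
pure_dict_state = nnx.to_pure_dict(state)

# Rebuild an abstract model/state and fill it from the pure dict.
abstract_model = nnx.eval_shape(lambda: nnx.Linear(4, 4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
nnx.replace_by_pure_dict(abstract_state, pure_dict_state)
model = nnx.merge(graphdef, abstract_state)

assert model(jnp.ones((3, 4))).shape == (3, 4)
```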
@@ -181,7 +181,7 @@ restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
abstract_state.replace_by_pure_dict(restored_pure_dict)
nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The new model works!
6 changes: 3 additions & 3 deletions docs_nnx/guides/flax_gspmd.ipynb
@@ -415,7 +415,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -501,8 +501,8 @@
")\n",
"loaded_sharded = checkpointer.restore(path / 'checkpoint_name',\n",
" target=abs_state)\n",
"jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)\n",
"jax.debug.visualize_array_sharding(loaded_sharded.w2.value)"
"jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)\n",
"jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)"
]
},
{
4 changes: 2 additions & 2 deletions docs_nnx/guides/flax_gspmd.md
@@ -235,8 +235,8 @@ abs_state = jax.tree.map(
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
target=abs_state)
jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)
jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)
```
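The access-style change above recurs throughout this commit: the docs now treat `nnx.State` as a nested mapping, using item access and membership tests instead of attribute access. A tiny sketch with a stand-alone `nnx.Linear` (not the guide's sharded model):

```python
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
state = nnx.state(model)

assert 'kernel' in state            # membership test instead of hasattr(state, 'kernel')
print(state['kernel'].value.shape)  # item access instead of state.kernel
```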

## Compile the training loop
10 changes: 5 additions & 5 deletions docs_nnx/guides/haiku_to_flax.rst
@@ -201,7 +201,7 @@ The dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
nnx.update(model, nnx.merge_state(params, rest))

.. testcode:: Haiku
:hide:
@@ -378,7 +378,7 @@ The parameter structure is as follows:
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
{
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
@@ -387,7 +387,7 @@ The parameter structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
}
To call those custom methods:
@@ -634,14 +634,14 @@ Now inspect the variable pytree on both sides:
_, params, _ = nnx.split(model, nnx.Param, ...)
params
State({
{
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
}
Top-level Haiku functions vs top-level Flax modules
12 changes: 6 additions & 6 deletions docs_nnx/guides/linen_to_nnx.rst
@@ -184,7 +184,7 @@ Dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
nnx.update(model, nnx.GraphState.merge(params, rest))
nnx.update(model, nnx.merge_state(params, rest))
.. testcode:: Linen
:hide:
@@ -389,7 +389,7 @@ The variable structure is as follows:
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
{
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
@@ -398,7 +398,7 @@ The variable structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
})
}
To call methods other than ``__call__``:

@@ -531,7 +531,7 @@ Scan-over-layers is a technique where you run an input through a sequence of N r
* Up close, in the logic of this model there actually is no need for the ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` sub-``Module`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` times to create a larger array.
* In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the :func:`nnx.vmap<flax.nnx.vmap>` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan<flax.nnx.scan>` transform to run the model input through them.

For more information on Flax NNX transforms, check out the `Transforms guide <https://flax.readthedocs.io/en/latest/guides/transforms.html>`__.
For more information on Flax NNX transforms, check out the `Transforms guide <https://flax.readthedocs.build/en/guides/transforms.html>`__.
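A rough, self-contained sketch of that vmap-to-initialize / scan-to-apply pattern (the guide's full comparison differs in details; `Block`, `dim=64`, and `num_layers=5` here are illustrative assumptions):

```python
import jax.numpy as jnp
from flax import nnx

class Block(nnx.Module):
  def __init__(self, dim: int, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, x):
    return nnx.relu(self.linear(x))

num_layers, dim = 5, 64

# Initialization: "vmap" the Block constructor to create stacked parameters.
@nnx.split_rngs(splits=num_layers)
@nnx.vmap(in_axes=0, out_axes=0)
def create_block(rngs: nnx.Rngs):
  return Block(dim, rngs=rngs)

blocks = create_block(nnx.Rngs(0))

# Forward pass: scan the input through the stacked Blocks.
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
def forward(block: Block, x):
  return block(x)

y = forward(blocks, jnp.ones((3, dim)))
```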

.. codediff::
:title: Linen, NNX
@@ -644,14 +644,14 @@ Now inspect the variable pytree on both sides:
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
State({
{
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
})
}
Using ``TrainState`` in Flax NNX
10 changes: 5 additions & 5 deletions docs_nnx/guides/surgery.ipynb
@@ -73,7 +73,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -95,8 +95,8 @@
"# Variable sharing (weight-tying).\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n",
"assert hasattr(nnx.state(model), 'linear2')\n",
"assert hasattr(nnx.state(model)['linear2'], 'bias')\n",
"assert 'linear2' in nnx.state(model)\n",
"assert 'bias' in nnx.state(model)['linear2']\n",
"assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n",
"\n",
"# Monkey-patching.\n",
@@ -256,7 +256,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -301,7 +301,7 @@
"# Fit it into the model state.\n",
"abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graph_def, state = nnx.split(abs_model)\n",
"state.replace_by_pure_dict(process_raw_dict(raw_dict))\n",
"nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))\n",
"restored_model = nnx.merge(graph_def, state)\n",
"\n",
"np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))"
6 changes: 3 additions & 3 deletions docs_nnx/guides/surgery.md
@@ -78,8 +78,8 @@ np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))
# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate
assert hasattr(nnx.state(model), 'linear2')
assert hasattr(nnx.state(model)['linear2'], 'bias')
assert 'linear2' in nnx.state(model)
assert 'bias' in nnx.state(model)['linear2']
assert not hasattr(nnx.state(model)['linear2'], 'kernel')
# Monkey-patching.
@@ -172,7 +172,7 @@ raw_dict['layer2'] = raw_dict.pop('linear2')
# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
state.replace_by_pure_dict(process_raw_dict(raw_dict))
nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)
np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
68 changes: 32 additions & 36 deletions docs_nnx/nnx_basics.ipynb

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions docs_nnx/nnx_basics.md
@@ -14,10 +14,6 @@ Flax NNX is a new simplified API that is designed to make it easier to create, i

To begin, install Flax with `pip` and import necessary dependencies:

## Setup

Install Flax with `pip` and impost necessary dependencies:

```{code-cell} ipython3
:tags: [skip-execution]
@@ -95,7 +91,7 @@ to handle them, as demonstrated in later sections of this guide.

Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.
The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:

```{code-cell} ipython3
class MLP(nnx.Module):
4 changes: 2 additions & 2 deletions examples/gemma/helpers.py
@@ -79,7 +79,7 @@ def assign_val_fn(

mdl: M = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
state = dict(state.flat_state())
state = dict(nnx.to_flat_state(state))
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
@@ -88,6 +88,6 @@ def assign_val_fn(
f' exist (original path={path}).'
)
state = assign_val_fn(state, mapped_path, val)
state = nnx.State.from_flat_path(state)
state = nnx.from_flat_state(state)

return nnx.merge(graph_def, state)
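For readers unfamiliar with the flat-state helpers this file now uses, a toy sketch of the `nnx.to_flat_state` / `nnx.from_flat_state` round trip (an `nnx.Linear` stands in for the Gemma module factory; the `('bias',)` path is an assumption about that toy model):

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

# Flatten the nested State into {path tuple: variable} entries, edit by path,
# then rebuild the nested State.
flat = dict(nnx.to_flat_state(state))
flat[('bias',)].value = flat[('bias',)].value + 1.0
state = nnx.from_flat_state(flat)

model = nnx.merge(graphdef, state)
```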
2 changes: 1 addition & 1 deletion examples/gemma/sampler.py
@@ -136,7 +136,7 @@ def __init__(
@property
def dtype(self) -> jnp.dtype:
params_state = nnx.state(self.transformer, nnx.Param)
return jax.tree_util.tree_leaves(params_state.flat_state())[0].dtype
return jax.tree_util.tree_leaves(nnx.to_flat_state(params_state))[0].dtype

def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
"""Performs a single sampling step."""
4 changes: 2 additions & 2 deletions examples/lm1b_nnx/models_test.py
@@ -79,7 +79,7 @@ def transfer_params(
params_linen: dict[str, Any],
):
rules = dataclasses.asdict(config.axis_rules)
flat_params_nnx = dict(params_nnx.flat_state())
flat_params_nnx = dict(nnx.to_flat_state(params_nnx))
flat_params_linen = nnx.traversals.flatten_mapping(params_linen, sep='/')

def apply_rules(names: tuple[str, ...]):
@@ -163,7 +163,7 @@ def transfer_cache(
cache_nnx: nnx.State,
cache_linen: dict[str, Any],
):
flat_cache_nnx = dict(cache_nnx.flat_state())
flat_cache_nnx = dict(nnx.to_flat_state(cache_nnx))
flat_cache_linen = nnx.traversals.flatten_mapping(cache_linen, sep='/')

def copy_var(nnx_name: str, linen_name: str):
6 changes: 3 additions & 3 deletions examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -84,7 +84,7 @@ def init_optimizer_state(variable: nnx.Variable):

self.lr = lr
self.params = params
self.momentum = jax.tree.map(init_optimizer_state, self.params)
self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
self.decay = decay

def update(self, grads: nnx.State):
@@ -117,7 +117,7 @@ def get_named_shardings(path: tuple, value: nnx.VariableState):
else:
raise ValueError(f'Unknown path: {path}')

named_shardings = state.map(get_named_shardings)
named_shardings = nnx.map_state(get_named_shardings, state)
sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
nnx.update(optimizer, sharded_state)
return model, optimizer
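A toy sketch of `nnx.map_state` as used above: it applies a `(path, value) -> value` function over every leaf of a `State`, the functional counterpart of the former `state.map(...)` method (the `nnx.Linear` here is a stand-in, not the example's sharded MLP):

```python
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))
_, state = nnx.split(model)

def show(path, value):
  # path is a tuple of keys; value is an nnx.VariableState leaf.
  print('/'.join(map(str, path)), value.value.shape)
  return value

state = nnx.map_state(show, state)
```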
@@ -126,7 +126,7 @@ def get_named_shardings(path: tuple, value: nnx.VariableState):
model, optimizer = create_model()

jax.debug.visualize_array_sharding(model.w1.value)
jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)
jax.debug.visualize_array_sharding(optimizer.momentum['w1'].value)


@nnx.jit
8 changes: 8 additions & 0 deletions flax/nnx/__init__.py
@@ -121,6 +121,14 @@
from .spmd import with_partitioning as with_partitioning
from .spmd import with_sharding_constraint as with_sharding_constraint
from .statelib import State as State
from .statelib import to_flat_state as to_flat_state
from .statelib import from_flat_state as from_flat_state
from .statelib import to_pure_dict as to_pure_dict
from .statelib import replace_by_pure_dict as replace_by_pure_dict
from .statelib import filter_state as filter_state
from .statelib import merge_state as merge_state
from .statelib import split_state as split_state
from .statelib import map_state as map_state
from .training import metrics as metrics
from .variablelib import Param as Param
# this needs to be imported before optimizer to prevent circular import
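Taken together, these exports provide functional, module-level counterparts to the `State` methods used in the guides above. A hedged sketch of how they combine (assuming each function takes the `State` as its first argument and otherwise mirrors the method it replaces; the toy `Model` is illustrative):

```python
from flax import nnx

class Model(nnx.Module):
  def __init__(self, *, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)
    self.bn = nnx.BatchNorm(4, rngs=rngs)

model = Model(rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

params, rest = nnx.split_state(state, nnx.Param, ...)  # was state.split(...)
batch_stats = nnx.filter_state(state, nnx.BatchStat)   # was state.filter(...)
flat = nnx.to_flat_state(state)                        # was state.flat_state()
state = nnx.merge_state(params, rest)                  # was nnx.GraphState.merge(...)

model = nnx.merge(graphdef, state)
```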
4 changes: 2 additions & 2 deletions flax/nnx/bridge/module.py
@@ -307,7 +307,7 @@ def _get_variables(self) -> tp.Mapping:
_variables: dict = {}

variable_state: variablelib.VariableState
for path, variable_state in state.flat_state():
for path, variable_state in statelib.to_flat_state(state):
try:
collection = variablelib.variable_name_from_type(variable_state.type)
except ValueError:
@@ -365,7 +365,7 @@ def to_variable(value):
real_variables[col_name] = linen_collection

states = ({},) if not real_variables else real_variables.values()
state = ModuleState.merge(*states)
state = statelib.merge_state(*states, cls=ModuleState)
graph.update(module, state)

if rngs is None:
2 changes: 1 addition & 1 deletion flax/nnx/bridge/wrappers.py
@@ -251,7 +251,7 @@ def __call__(self, *args, **kwargs):
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
states = [State(v) for v in states.values()]
nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
nnx_state = nnx.merge_state(*states) if states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
(Diffs for the remaining changed files are not rendered here.)