diff --git a/docs_nnx/guides/checkpointing.ipynb b/docs_nnx/guides/checkpointing.ipynb
index de6c7a279..af32f00a9 100644
--- a/docs_nnx/guides/checkpointing.ipynb
+++ b/docs_nnx/guides/checkpointing.ipynb
@@ -88,7 +88,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -100,7 +100,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -153,7 +153,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -173,14 +173,14 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "/Users/cris/repos/cristian/flax/.venv/lib/python3.10/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py:1136: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
+ "/Users/ivyzheng/envs/flax-head/lib/python3.11/site-packages/orbax/checkpoint/type_handlers.py:1439: UserWarning: Couldn't find sharding info under RestoreArgs. Populating sharding info from sharding file. Please note restoration time will be slightly increased due to reading from file instead of directly from RestoreArgs. Note also that this option is unsafe when restoring on a different topology than the checkpoint was saved with.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -192,7 +192,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -252,13 +252,13 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -270,7 +270,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -289,7 +289,7 @@
],
"source": [
"# Save as pure dict\n",
- "pure_dict_state = state.to_pure_dict()\n",
+ "pure_dict_state = nnx.to_pure_dict(state)\n",
"nnx.display(pure_dict_state)\n",
"checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)\n",
"\n",
@@ -297,7 +297,7 @@
"restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')\n",
"abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graphdef, abstract_state = nnx.split(abstract_model)\n",
- "abstract_state.replace_by_pure_dict(restored_pure_dict)\n",
+ "nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)\n",
"model = nnx.merge(graphdef, abstract_state)\n",
"assert model(x).shape == (3, 4) # The model still works!"
]
@@ -325,7 +325,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -338,7 +338,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -350,7 +350,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -379,7 +379,7 @@
"# Same restore code as above.\n",
"abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graphdef, abstract_state = nnx.split(abstract_model)\n",
- "abstract_state.replace_by_pure_dict(restored_pure_dict)\n",
+ "nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)\n",
"model = nnx.merge(graphdef, abstract_state)\n",
"assert model(x).shape == (3, 4) # The new model works!\n",
"\n",
@@ -440,7 +440,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.16"
+ "version": "3.11.9"
}
},
"nbformat": 4,
diff --git a/docs_nnx/guides/checkpointing.md b/docs_nnx/guides/checkpointing.md
index fa98e6db8..cc0101c25 100644
--- a/docs_nnx/guides/checkpointing.md
+++ b/docs_nnx/guides/checkpointing.md
@@ -133,7 +133,7 @@ When interacting with checkpoint libraries (like Orbax), you may prefer to work
```{code-cell} ipython3
# Save as pure dict
-pure_dict_state = state.to_pure_dict()
+pure_dict_state = nnx.to_pure_dict(state)
nnx.display(pure_dict_state)
checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
@@ -141,7 +141,7 @@ checkpointer.save(ckpt_dir / 'pure_dict', pure_dict_state)
restored_pure_dict = checkpointer.restore(ckpt_dir / 'pure_dict')
abstract_model = nnx.eval_shape(lambda: TwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
-abstract_state.replace_by_pure_dict(restored_pure_dict)
+nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
```
@@ -181,7 +181,7 @@ restored_pure_dict['linear2']['bias'] = jnp.zeros((4,))
# Same restore code as above.
abstract_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
-abstract_state.replace_by_pure_dict(restored_pure_dict)
+nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The new model works!
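For readers updating existing checkpointing code, a minimal sketch (not part of the guide itself) of how the deprecated `State` methods map onto the module-level helpers used above; the `nnx.Linear` model is only a stand-in:

```python
from flax import nnx

model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

# Previously: state.to_pure_dict() / state.replace_by_pure_dict(...),
# which now emit a DeprecationWarning.
pure = nnx.to_pure_dict(state)           # nested dict of raw array values
nnx.replace_by_pure_dict(state, pure)    # write values back into the state
model = nnx.merge(graphdef, state)
```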
diff --git a/docs_nnx/guides/flax_gspmd.ipynb b/docs_nnx/guides/flax_gspmd.ipynb
index 44dcfd513..b428c0ac3 100644
--- a/docs_nnx/guides/flax_gspmd.ipynb
+++ b/docs_nnx/guides/flax_gspmd.ipynb
@@ -415,7 +415,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -501,8 +501,8 @@
")\n",
"loaded_sharded = checkpointer.restore(path / 'checkpoint_name',\n",
" target=abs_state)\n",
- "jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)\n",
- "jax.debug.visualize_array_sharding(loaded_sharded.w2.value)"
+ "jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)\n",
+ "jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)"
]
},
{
diff --git a/docs_nnx/guides/flax_gspmd.md b/docs_nnx/guides/flax_gspmd.md
index 50441f941..7c3a73cf0 100644
--- a/docs_nnx/guides/flax_gspmd.md
+++ b/docs_nnx/guides/flax_gspmd.md
@@ -235,8 +235,8 @@ abs_state = jax.tree.map(
)
loaded_sharded = checkpointer.restore(path / 'checkpoint_name',
target=abs_state)
-jax.debug.visualize_array_sharding(loaded_sharded.dot1.kernel.value)
-jax.debug.visualize_array_sharding(loaded_sharded.w2.value)
+jax.debug.visualize_array_sharding(loaded_sharded['dot1']['kernel'].value)
+jax.debug.visualize_array_sharding(loaded_sharded['w2'].value)
```
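To inspect more than the two arrays visualized above, a hedged sketch (not part of the guide) that walks the restored state through the flat view introduced by this change; `loaded_sharded` is the object restored in the snippet above:

```python
# Visualize the sharding of every restored array, keyed by its path.
for path, variable_state in nnx.to_flat_state(loaded_sharded):
    print('/'.join(map(str, path)))
    jax.debug.visualize_array_sharding(variable_state.value)
```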
## Compile the training loop
diff --git a/docs_nnx/guides/haiku_to_flax.rst b/docs_nnx/guides/haiku_to_flax.rst
index 8fdb48dbd..5cfb55a57 100644
--- a/docs_nnx/guides/haiku_to_flax.rst
+++ b/docs_nnx/guides/haiku_to_flax.rst
@@ -201,7 +201,7 @@ The dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
- nnx.update(model, nnx.GraphState.merge(params, rest))
+ nnx.update(model, nnx.merge_state(params, rest))
.. testcode:: Haiku
:hide:
@@ -378,7 +378,7 @@ The parameter structure is as follows:
_, params, _ = nnx.split(model, nnx.Param, ...)
params
- State({
+ {
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
@@ -387,7 +387,7 @@ The parameter structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
- })
+ }
To call those custom methods:
@@ -634,14 +634,14 @@ Now inspect the variable pytree on both sides:
_, params, _ = nnx.split(model, nnx.Param, ...)
params
- State({
+ {
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
- })
+ }
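Together with the training-step change above (`nnx.GraphState.merge` → `nnx.merge_state`), a minimal self-contained sketch of the split/update round trip; the `nnx.Linear` model and squared-error loss here are placeholders, not the guide's MLP:

```python
import jax
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

def loss_fn(model):
    x = jax.numpy.ones((1, 4))
    return ((model(x) - x) ** 2).mean()

grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
# `nnx.merge_state` is the functional replacement for `nnx.GraphState.merge`.
nnx.update(model, nnx.merge_state(params, rest))
```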
Top-level Haiku functions vs top-level Flax modules
diff --git a/docs_nnx/guides/linen_to_nnx.rst b/docs_nnx/guides/linen_to_nnx.rst
index 19f25d709..bc497aec4 100644
--- a/docs_nnx/guides/linen_to_nnx.rst
+++ b/docs_nnx/guides/linen_to_nnx.rst
@@ -184,7 +184,7 @@ Dropout behavior:
grads = nnx.grad(loss_fn)(model)
_, params, rest = nnx.split(model, nnx.Param, ...)
params = jax.tree.map(lambda p, g: p - 0.1 * g, params, grads)
- nnx.update(model, nnx.GraphState.merge(params, rest))
+ nnx.update(model, nnx.merge_state(params, rest))
.. testcode:: Linen
:hide:
@@ -389,7 +389,7 @@ The variable structure is as follows:
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
- State({
+ {
'decoder': {
'bias': VariableState(type=Param, value=(784,)),
'kernel': VariableState(type=Param, value=(256, 784))
@@ -398,7 +398,7 @@ The variable structure is as follows:
'bias': VariableState(type=Param, value=(256,)),
'kernel': VariableState(type=Param, value=(784, 256))
}
- })
+ }
To call methods other than ``__call__``:
@@ -531,7 +531,7 @@ Scan-over-layers is a technique where you run an input through a sequence of N r
* Up close, in the logic of this model there actually is no need for the ``jax.lax.scan`` operation at initialization time. What happens there is more like a ``jax.vmap`` operation - you are given a ``Block`` sub-``Module`` that accepts ``(in_dim, out_dim)``, and you "vmap" it over ``num_layers`` of times to create a larger array.
* In Flax NNX, you take advantage of the fact that model initialization and running code are completely decoupled, and instead use the :func:`nnx.vmap` transform to initialize the underlying ``Block`` parameters, and the :func:`nnx.scan` transform to run the model input through them.
-For more information on Flax NNX transforms, check out the `Transforms guide `__.
+For more information on Flax NNX transforms, check out the `Transforms guide `__.
.. codediff::
:title: Linen, NNX
@@ -644,14 +644,14 @@ Now inspect the variable pytree on both sides:
# _, params, _ = nnx.split(model, nnx.Param, ...)
# params
- State({
+ {
'blocks': {
'linear': {
'bias': VariableState(type=Param, value=(5, 64)),
'kernel': VariableState(type=Param, value=(5, 64, 64))
}
}
- })
+ }
Using ``TrainState`` in Flax NNX
diff --git a/docs_nnx/guides/surgery.ipynb b/docs_nnx/guides/surgery.ipynb
index b179f6811..edbe22975 100644
--- a/docs_nnx/guides/surgery.ipynb
+++ b/docs_nnx/guides/surgery.ipynb
@@ -73,7 +73,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -95,8 +95,8 @@
"# Variable sharing (weight-tying).\n",
"model = TwoLayerMLP(4, rngs=nnx.Rngs(0))\n",
"model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate\n",
- "assert hasattr(nnx.state(model), 'linear2')\n",
- "assert hasattr(nnx.state(model)['linear2'], 'bias')\n",
+ "assert 'linear2' in nnx.state(model)\n",
+ "assert 'bias' in nnx.state(model)['linear2']\n",
"assert not hasattr(nnx.state(model)['linear2'], 'kernel')\n",
"\n",
"# Monkey-patching.\n",
@@ -256,7 +256,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [
{
@@ -301,7 +301,7 @@
"# Fit it into the model state.\n",
"abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))\n",
"graph_def, state = nnx.split(abs_model)\n",
- "state.replace_by_pure_dict(process_raw_dict(raw_dict))\n",
+ "nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))\n",
"restored_model = nnx.merge(graph_def, state)\n",
"\n",
"np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))"
diff --git a/docs_nnx/guides/surgery.md b/docs_nnx/guides/surgery.md
index 904eb7cf1..1df1ce596 100644
--- a/docs_nnx/guides/surgery.md
+++ b/docs_nnx/guides/surgery.md
@@ -78,8 +78,8 @@ np.testing.assert_allclose(model(x), model.linear1(model.linear1(x)))
# Variable sharing (weight-tying).
model = TwoLayerMLP(4, rngs=nnx.Rngs(0))
model.linear1.kernel = model.linear2.kernel # the bias parameter is kept separate
-assert hasattr(nnx.state(model), 'linear2')
-assert hasattr(nnx.state(model)['linear2'], 'bias')
+assert 'linear2' in nnx.state(model)
+assert 'bias' in nnx.state(model)['linear2']
assert not hasattr(nnx.state(model)['linear2'], 'kernel')
# Monkey-patching.
@@ -172,7 +172,7 @@ raw_dict['layer2'] = raw_dict.pop('linear2')
# Fit it into the model state.
abs_model = nnx.eval_shape(lambda: ModifiedTwoLayerMLP(4, rngs=nnx.Rngs(0)))
graph_def, state = nnx.split(abs_model)
-state.replace_by_pure_dict(process_raw_dict(raw_dict))
+nnx.replace_by_pure_dict(state, process_raw_dict(raw_dict))
restored_model = nnx.merge(graph_def, state)
np.testing.assert_allclose(restored_model(jnp.ones((3, 4))), old_model(jnp.ones((3, 4))))
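As the weight-tying asserts above show, states are now addressed by key rather than attribute. A small hedged sketch (names follow the weight-tying example earlier in this guide) of listing what actually remains in the state after sharing the kernel:

```python
state = nnx.state(model)
assert 'linear2' in state          # key membership instead of hasattr
assert 'bias' in state['linear2']
# The shared kernel is stored only once, under 'linear1'.
assert 'kernel' in state['linear1']

# Enumerate every parameter path via the flat view.
for path, variable_state in nnx.to_flat_state(state):
    print(path, variable_state.value.shape)
```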
diff --git a/docs_nnx/nnx_basics.ipynb b/docs_nnx/nnx_basics.ipynb
index bf040b98d..57dd3b869 100644
--- a/docs_nnx/nnx_basics.ipynb
+++ b/docs_nnx/nnx_basics.ipynb
@@ -8,11 +8,7 @@
"\n",
"Flax NNX is a new simplified API that is designed to make it easier to create, inspect, debug, and analyze neural networks in [JAX](https://jax.readthedocs.io/). It achieves this by adding first class support for Python reference semantics. This allows users to express their models using regular Python objects, which are modeled as PyGraphs (instead of pytrees), enabling reference sharing and mutability. Such API design should make PyTorch or Keras users feel at home.\n",
"\n",
- "To begin, install Flax with `pip` and import necessary dependencies:\n",
- "\n",
- "## Setup\n",
- "\n",
- "Install Flax with `pip` and impost necessary dependencies:"
+ "To begin, install Flax with `pip` and import necessary dependencies:"
]
},
{
@@ -30,7 +26,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -52,7 +48,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -79,7 +75,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -92,7 +88,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -104,7 +100,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -141,7 +137,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -185,18 +181,18 @@
"\n",
"Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.\n",
"\n",
- "The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer."
+ "The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -208,7 +204,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -257,13 +253,13 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -275,7 +271,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -324,7 +320,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -386,7 +382,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -399,7 +395,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -411,7 +407,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -464,13 +460,13 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -482,7 +478,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -525,13 +521,13 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -543,7 +539,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -555,7 +551,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -586,7 +582,7 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -650,13 +646,13 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -668,7 +664,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -680,7 +676,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -692,7 +688,7 @@
{
"data": {
"text/html": [
- "
"
+ "
"
],
"text/plain": [
""
@@ -720,7 +716,7 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
diff --git a/docs_nnx/nnx_basics.md b/docs_nnx/nnx_basics.md
index 51e0cda53..fbf9be0a2 100644
--- a/docs_nnx/nnx_basics.md
+++ b/docs_nnx/nnx_basics.md
@@ -14,10 +14,6 @@ Flax NNX is a new simplified API that is designed to make it easier to create, i
To begin, install Flax with `pip` and import necessary dependencies:
-## Setup
-
-Install Flax with `pip` and impost necessary dependencies:
-
```{code-cell} ipython3
:tags: [skip-execution]
@@ -95,7 +91,7 @@ to handle them, as demonstrated in later sections of this guide.
Flax [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html)s can be used to compose other `Module`s in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.
-The example below shows how to define a simple `MLP` Module consisting of two `Linear` layers, a [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer.
+The example below shows how to define a simple `MLP` by subclassing [`nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html). The model consists of two `Linear` layers, an [`nnx.Dropout`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/stochastic.html#flax.nnx.Dropout) layer, and an [`nnx.BatchNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.BatchNorm) layer:
```{code-cell} ipython3
class MLP(nnx.Module):
diff --git a/examples/gemma/helpers.py b/examples/gemma/helpers.py
index 7743563c0..f74845bed 100644
--- a/examples/gemma/helpers.py
+++ b/examples/gemma/helpers.py
@@ -79,7 +79,7 @@ def assign_val_fn(
mdl: M = nnx.eval_shape(module_factory)
graph_def, state = nnx.split(mdl)
- state = dict(state.flat_state())
+ state = dict(nnx.to_flat_state(state))
for path, val in flax.traverse_util.flatten_dict(variables).items():
mapped_path = map_key_fn(path)
if mapped_path not in state:
@@ -88,6 +88,6 @@ def assign_val_fn(
f' exist (original path={path}).'
)
state = assign_val_fn(state, mapped_path, val)
- state = nnx.State.from_flat_path(state)
+ state = nnx.from_flat_state(state)
return nnx.merge(graph_def, state)
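The helper above is essentially a flat-state round trip; a minimal standalone sketch of the same pattern (with `nnx.Linear` as a stand-in module) may make the data flow clearer:

```python
from flax import nnx

module = nnx.Linear(3, 3, rngs=nnx.Rngs(0))
graph_def, state = nnx.split(module)

# Flatten to {path-tuple: VariableState}, edit a leaf, then rebuild.
flat = dict(nnx.to_flat_state(state))
flat[('bias',)] = flat[('bias',)].replace(flat[('bias',)].value + 1.0)
state = nnx.from_flat_state(flat)
module = nnx.merge(graph_def, state)
```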
diff --git a/examples/gemma/sampler.py b/examples/gemma/sampler.py
index f221fbb93..5efb5c698 100644
--- a/examples/gemma/sampler.py
+++ b/examples/gemma/sampler.py
@@ -136,7 +136,7 @@ def __init__(
@property
def dtype(self) -> jnp.dtype:
params_state = nnx.state(self.transformer, nnx.Param)
- return jax.tree_util.tree_leaves(params_state.flat_state())[0].dtype
+ return jax.tree_util.tree_leaves(nnx.to_flat_state(params_state))[0].dtype
def _sample_step(self, sampler_state: _SamplingState) -> _SamplingState:
"""Performs a single sampling step."""
diff --git a/examples/lm1b_nnx/models_test.py b/examples/lm1b_nnx/models_test.py
index d2d0ce03d..5527c0190 100644
--- a/examples/lm1b_nnx/models_test.py
+++ b/examples/lm1b_nnx/models_test.py
@@ -79,7 +79,7 @@ def transfer_params(
params_linen: dict[str, Any],
):
rules = dataclasses.asdict(config.axis_rules)
- flat_params_nnx = dict(params_nnx.flat_state())
+ flat_params_nnx = dict(nnx.to_flat_state(params_nnx))
flat_params_linen = nnx.traversals.flatten_mapping(params_linen, sep='/')
def apply_rules(names: tuple[str, ...]):
@@ -163,7 +163,7 @@ def transfer_cache(
cache_nnx: nnx.State,
cache_linen: dict[str, Any],
):
- flat_cache_nnx = dict(cache_nnx.flat_state())
+ flat_cache_nnx = dict(nnx.to_flat_state(cache_nnx))
flat_cache_linen = nnx.traversals.flatten_mapping(cache_linen, sep='/')
def copy_var(nnx_name: str, linen_name: str):
diff --git a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py
index f5cf8002b..fab6ea1f2 100644
--- a/examples/nnx_toy_examples/10_fsdp_and_optimizer.py
+++ b/examples/nnx_toy_examples/10_fsdp_and_optimizer.py
@@ -84,7 +84,7 @@ def init_optimizer_state(variable: nnx.Variable):
self.lr = lr
self.params = params
- self.momentum = jax.tree.map(init_optimizer_state, self.params)
+ self.momentum: nnx.State = jax.tree.map(init_optimizer_state, self.params)
self.decay = decay
def update(self, grads: nnx.State):
@@ -117,7 +117,7 @@ def get_named_shardings(path: tuple, value: nnx.VariableState):
else:
raise ValueError(f'Unknown path: {path}')
- named_shardings = state.map(get_named_shardings)
+ named_shardings = nnx.map_state(get_named_shardings, state)
sharded_state = jax.lax.with_sharding_constraint(state, named_shardings)
nnx.update(optimizer, sharded_state)
return model, optimizer
@@ -126,7 +126,7 @@ def get_named_shardings(path: tuple, value: nnx.VariableState):
model, optimizer = create_model()
jax.debug.visualize_array_sharding(model.w1.value)
-jax.debug.visualize_array_sharding(optimizer.momentum.w1.value)
+jax.debug.visualize_array_sharding(optimizer.momentum['w1'].value)
@nnx.jit
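`nnx.map_state`, used above for the shardings, is the functional counterpart of the former `State.map` method: it applies a `(path, value) -> value` function to every entry. A hedged standalone sketch (the zero-out rule is arbitrary):

```python
from flax import nnx

state = nnx.state(nnx.Linear(2, 2, rngs=nnx.Rngs(0)))

# Zero every bias, leave all other variables untouched.
zeroed = nnx.map_state(
    lambda path, v: v.replace(v.value * 0) if path[-1] == 'bias' else v,
    state,
)
```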
diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index f059358ba..910fb3af2 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -121,6 +121,14 @@
from .spmd import with_partitioning as with_partitioning
from .spmd import with_sharding_constraint as with_sharding_constraint
from .statelib import State as State
+from .statelib import to_flat_state as to_flat_state
+from .statelib import from_flat_state as from_flat_state
+from .statelib import to_pure_dict as to_pure_dict
+from .statelib import replace_by_pure_dict as replace_by_pure_dict
+from .statelib import filter_state as filter_state
+from .statelib import merge_state as merge_state
+from .statelib import split_state as split_state
+from .statelib import map_state as map_state
from .training import metrics as metrics
from .variablelib import Param as Param
# this needs to be imported before optimizer to prevent circular import
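The newly exported helpers are the functional versions of the corresponding `State` methods; a minimal sketch of how they compose (`nnx.BatchNorm` is just a convenient module with both `Param` and `BatchStat` variables):

```python
from flax import nnx

model = nnx.BatchNorm(2, rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)

params, rest = nnx.split_state(state, nnx.Param, ...)  # exhaustive split
batch_stats = nnx.filter_state(state, nnx.BatchStat)   # non-exhaustive filter
state = nnx.merge_state(params, rest)                   # inverse of split_state
model = nnx.merge(graphdef, state)
```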
diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py
index f9b120864..377d9f295 100644
--- a/flax/nnx/bridge/module.py
+++ b/flax/nnx/bridge/module.py
@@ -307,7 +307,7 @@ def _get_variables(self) -> tp.Mapping:
_variables: dict = {}
variable_state: variablelib.VariableState
- for path, variable_state in state.flat_state():
+ for path, variable_state in statelib.to_flat_state(state):
try:
collection = variablelib.variable_name_from_type(variable_state.type)
except ValueError:
@@ -365,7 +365,7 @@ def to_variable(value):
real_variables[col_name] = linen_collection
states = ({},) if not real_variables else real_variables.values()
- state = ModuleState.merge(*states)
+ state = statelib.merge_state(*states, cls=ModuleState)
graph.update(module, state)
if rngs is None:
diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py
index b4b352b67..7fd651062 100644
--- a/flax/nnx/bridge/wrappers.py
+++ b/flax/nnx/bridge/wrappers.py
@@ -251,7 +251,7 @@ def __call__(self, *args, **kwargs):
lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x).to_state(),
variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
states = [State(v) for v in states.values()]
- nnx_state = nnx.GraphState.merge(*states) if states else nnx.GraphState({})
+ nnx_state = nnx.merge_state(*states) if states else nnx.GraphState({})
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
diff --git a/flax/nnx/graph.py b/flax/nnx/graph.py
index f1a245304..6fb73ed4b 100644
--- a/flax/nnx/graph.py
+++ b/flax/nnx/graph.py
@@ -22,6 +22,7 @@
import typing as tp
from flax.nnx import filterlib, reprlib, variablelib
+from flax.nnx import statelib
from flax.nnx.proxy_caller import (
ApplyCaller,
CallableProxy,
@@ -397,7 +398,7 @@ def _apply(
fn = accessor(module)
out = fn(*args, **kwargs)
graphdef, flat_state = flatten(module)
- state_ = State.from_flat_path(flat_state)
+ state_ = statelib.from_flat_state(flat_state)
return out, (graphdef, state_)
return CallableProxy(_apply, accessor) # type: ignore
@@ -1008,7 +1009,7 @@ def graph_pop(
)
_graph_pop(node, id_to_index, path_parts, flat_states, predicates)
return tuple(
- GraphState.from_flat_path(flat_state) for flat_state in flat_states
+ statelib.from_flat_state(flat_state) for flat_state in flat_states
)
@@ -1310,7 +1311,7 @@ def split(
)
flat_states = _split_state(flat_state, filters)
states = tuple(
- State.from_flat_path(flat_state) for flat_state in flat_states
+ statelib.from_flat_state(flat_state) for flat_state in flat_states
)
return graphdef, *states
@@ -1479,7 +1480,7 @@ def merge(
ctx.outer_index_outer_ref if ctx and ctx.outer_index_outer_ref else None
)
- _state = State.merge(state, *states)
+ _state = statelib.merge_state(state, *states)
node = unflatten(
graphdef,
_state,
@@ -1735,7 +1736,7 @@ def split(
node, ref_index=ref_index, ref_outer_index=self.inner_ref_outer_index
)
states = tuple(
- State.from_flat_path(flat_state)
+ statelib.from_flat_state(flat_state)
for flat_state in _split_state(flat_state, filters)
)
assert len(states) >= 1
@@ -1764,7 +1765,7 @@ def merge(
# inner merge (2)
index_ref_cache = None
- state = State.merge(state, *states)
+ state = statelib.merge_state(state, *states)
index_ref: dict[Index, tp.Any] = {}
node = unflatten(
graphdef,
@@ -2040,7 +2041,7 @@ def split(
"""
graphdef, flat_state = flatten(node)
flat_states = _split_state(flat_state, filters)
- states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+ states = tuple(statelib.from_flat_state(flat_state) for flat_state in flat_states)
return graphdef, *states # type: ignore[return-value]
@@ -2093,7 +2094,7 @@ def merge(
Returns:
The merged :class:`flax.nnx.Module`.
"""
- _state = State.merge(state, *states)
+ _state = statelib.merge_state(state, *states)
node = unflatten(graphdef, _state)
return node
@@ -2127,10 +2128,7 @@ def update(
*states: Additional :class:`State` objects.
"""
if states:
- if isinstance(state, State):
- state = type(state).merge(state, *states)
- else:
- state = State.merge(state, *states)
+ state = statelib.merge_state(state, *states)
_graph_update_dynamic(node, state)
@@ -2185,7 +2183,7 @@ def variables(
flat_states = variablelib.split_flat_state(
variables_iterable, (*filters, ...)
)
- states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states)
+ states = tuple(statelib.from_flat_state(flat_state) for flat_state in flat_states)
if num_filters < 2:
return states[0]
return states
@@ -2243,9 +2241,9 @@ def state(
if len(filters) == 0:
states = state
elif len(filters) == 1:
- states = state.filter(filters[0])
+ states = statelib.filter_state(state, filters[0])
else:
- states = state.filter(filters[0], filters[1], *filters[2:])
+ states = statelib.filter_state(state, filters[0], filters[1], *filters[2:])
return states
@@ -2342,7 +2340,7 @@ def pop(
predicates=predicates,
)
states = tuple(
- GraphState.from_flat_path(flat_state) for flat_state in flat_states
+ statelib.from_flat_state(flat_state) for flat_state in flat_states
)
if len(states) == 1:
diff --git a/flax/nnx/rnglib.py b/flax/nnx/rnglib.py
index a7b315089..e231c33ef 100644
--- a/flax/nnx/rnglib.py
+++ b/flax/nnx/rnglib.py
@@ -21,6 +21,7 @@
from flax import struct
from flax.nnx import graph
+from flax.nnx import statelib
from flax.nnx.statelib import State
from flax.nnx.variablelib import Variable
from flax.nnx import filterlib
@@ -260,11 +261,14 @@ def fork(
else:
num_splits = tuple(x if x is not None else 1 for x in split_pattern)
- split_keys, split_counts, broadcast_keys, broadcast_counts = state.split(
- All(split_filter, RngKey),
- All(split_filter, RngCount),
- RngKey,
- RngCount,
+ split_keys, split_counts, broadcast_keys, broadcast_counts = (
+ statelib.split_state(
+ state,
+ All(split_filter, RngKey),
+ All(split_filter, RngCount),
+ RngKey,
+ RngCount,
+ )
)
def split_key(key: tp.Any) -> jax.Array:
diff --git a/flax/nnx/statelib.py b/flax/nnx/statelib.py
index f2e9bd81b..b903de001 100644
--- a/flax/nnx/statelib.py
+++ b/flax/nnx/statelib.py
@@ -17,6 +17,7 @@
import typing as tp
from collections.abc import MutableMapping
from functools import partial
+import warnings
import jax
import jax.tree_util as jtu
@@ -112,7 +113,7 @@ def __iter__(self) -> tp.Iterator[tuple[PathParts, V]]:
return iter(zip(self._keys, self._values))
def to_nested_state(self) -> State[Key, V]:
- return State.from_flat_path(self)
+ return from_flat_state(self)
@tp.overload
def split(self, first: filterlib.Filter, /) -> FlatState[V]: ...
@@ -307,14 +308,20 @@ def __treescope_repr__(self, path, subtree_renderer):
)
def map(self, f: tp.Callable[[tuple, V], V]) -> State[K, V]:
- flat_state = self.flat_state()
- result = [
- (path, f(path, variable_state)) for path, variable_state in flat_state
- ]
- return State.from_flat_path(result)
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.map_state` instead.',
+ DeprecationWarning,
+ )
+ return map_state(f, self)
def flat_state(self) -> FlatState[V]:
- return FlatState(traversals.flatten_to_sequence(self._mapping), sort=True)
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.to_flat_state` instead.',
+ DeprecationWarning,
+ )
+ return to_flat_state(self)
@classmethod
def from_flat_path(
@@ -322,38 +329,33 @@ def from_flat_path(
flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]],
/,
):
- if not isinstance(flat_state, tp.Mapping):
- flat_state = dict(flat_state)
- nested_state = traversals.unflatten_mapping(flat_state)
- return cls(nested_state)
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.from_flat_state` instead.',
+ DeprecationWarning,
+ )
+ return from_flat_state(flat_state, cls=cls)
def to_pure_dict(self,
extract_fn: ExtractValueFn | None = None
) -> dict[str, tp.Any]:
- # Works for nnx.Variable and nnx.VariableState
- if extract_fn is None:
- extract_fn = lambda x: x.value if hasattr(x, 'value') else x
- flat_values = {k: extract_fn(x) for k, x in self.flat_state()}
- return traversals.unflatten_mapping(flat_values)
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.to_pure_dict` instead.',
+ DeprecationWarning,
+ )
+ return to_pure_dict(self, extract_fn)
def replace_by_pure_dict(self,
pure_dict: dict[str, tp.Any],
replace_fn: SetValueFn | None = None):
- def try_convert_int(x):
- try:
- return int(x)
- except ValueError:
- return x
- # Works for nnx.Variable and nnx.VariableState
- if replace_fn is None:
- replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v
- current_flat = dict(self.flat_state())
- for kp, v in traversals.flatten_mapping(pure_dict).items():
- kp = tuple(map(try_convert_int, kp))
- if kp not in current_flat:
- raise ValueError(f'key in pure_dict not available in state: {kp}')
- current_flat[kp] = replace_fn(current_flat[kp], v)
- self.update(traversals.unflatten_mapping(current_flat))
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.replace_by_pure_dict` '
+ 'instead.',
+ DeprecationWarning,
+ )
+ return replace_by_pure_dict(self, pure_dict, replace_fn)
@tp.overload
def split(self, first: filterlib.Filter, /) -> State[K, V]: ...
@@ -375,48 +377,12 @@ def split(
def split( # type: ignore[misc]
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
- """Split a ``State`` into one or more ``State``'s. The
- user must pass at least one ``Filter`` (i.e. :class:`Variable`),
- and the filters must be exhaustive (i.e. they must cover all
- :class:`Variable` types in the ``State``).
-
- Example usage::
-
- >>> from flax import nnx
-
- >>> class Model(nnx.Module):
- ... def __init__(self, rngs):
- ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
- ... self.linear = nnx.Linear(2, 3, rngs=rngs)
- ... def __call__(self, x):
- ... return self.linear(self.batchnorm(x))
-
- >>> model = Model(rngs=nnx.Rngs(0))
- >>> state = nnx.state(model)
- >>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
-
- Arguments:
- first: The first filter
- *filters: The optional, additional filters to group the state into mutually exclusive substates.
- Returns:
- One or more ``States`` equal to the number of filters passed.
- """
- filters = (first, *filters)
- flat_states = _split_state(self.flat_state(), *filters)
- *states_, rest = (state.to_nested_state() for state in flat_states)
-
- if rest:
- raise ValueError(
- 'Non-exhaustive filters, got a non-empty remainder: '
- f'{rest}.\nUse `...` to match all remaining elements.'
- )
-
- states: State | tuple[State, ...]
- if len(states_) == 1:
- states = states_[0]
- else:
- states = tuple(states_)
- return states # type: ignore
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.split_state` instead.',
+ DeprecationWarning,
+ )
+ return split_state(self, first, *filters)
@tp.overload
def filter(
@@ -440,108 +406,34 @@ def filter(
/,
*filters: filterlib.Filter,
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
- """Filter a ``State`` into one or more ``State``'s. The
- user must pass at least one ``Filter`` (i.e. :class:`Variable`).
- This method is similar to :meth:`split() `,
- except the filters can be non-exhaustive.
-
- Example usage::
-
- >>> from flax import nnx
-
- >>> class Model(nnx.Module):
- ... def __init__(self, rngs):
- ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
- ... self.linear = nnx.Linear(2, 3, rngs=rngs)
- ... def __call__(self, x):
- ... return self.linear(self.batchnorm(x))
-
- >>> model = Model(rngs=nnx.Rngs(0))
- >>> state = nnx.state(model)
- >>> param = state.filter(nnx.Param)
- >>> batch_stats = state.filter(nnx.BatchStat)
- >>> param, batch_stats = state.filter(nnx.Param, nnx.BatchStat)
-
- Arguments:
- first: The first filter
- *filters: The optional, additional filters to group the state into mutually exclusive substates.
- Returns:
- One or more ``States`` equal to the number of filters passed.
- """
- flat_states = _split_state(self.flat_state(), first, *filters)
- *states_, _rest = (state.to_nested_state() for state in flat_states)
-
- assert len(states_) == len(filters) + 1
-
- states: State | tuple[State, ...]
- if len(states_) == 1:
- states = states_[0]
- else:
- states = tuple(states_)
-
- return states # type: ignore
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.filter_state` instead.',
+ DeprecationWarning,
+ )
+ return filter_state(self, first, *filters)
@classmethod
def merge(cls, state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]):
- """The inverse of :meth:`split() `.
-
- ``merge`` takes one or more ``State``'s and creates
- a new ``State``.
-
- Example usage::
-
- >>> from flax import nnx
- >>> import jax.numpy as jnp
-
- >>> class Model(nnx.Module):
- ... def __init__(self, rngs):
- ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
- ... self.linear = nnx.Linear(2, 3, rngs=rngs)
- ... def __call__(self, x):
- ... return self.linear(self.batchnorm(x))
-
- >>> model = Model(rngs=nnx.Rngs(0))
- >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
- >>> params.linear.bias.value += 1
-
- >>> state = nnx.State.merge(params, batch_stats)
- >>> nnx.update(model, state)
- >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
-
- Args:
- state: A ``State`` object.
- *states: Additional ``State`` objects.
- Returns:
- The merged ``State``.
- """
- if not states:
- if isinstance(state, cls):
- return state
- return cls(state)
-
- states = (state, *states)
-
- new_state: dict[PathParts, V] = {}
-
- for state in states:
- new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here
-
- return cls.from_flat_path(new_state)
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.merge_state` instead.',
+ DeprecationWarning,
+ )
+ return merge_state(state, *states)
def __or__(self, other: State[K, V]) -> State[K, V]:
if not other:
return self
- return State.merge(self, other)
+ return merge_state(self, other)
def __sub__(self, other: State[K, V]) -> State[K, V]:
- if not other:
- return self
-
- self_flat = dict(self.flat_state())
- other_flat = dict(other.flat_state())
- diff = {k: v for k, v in self_flat.items() if k not in other_flat}
-
- return State.from_flat_path(diff)
+ warnings.warn(
+ '`flax.nnx.State` will be deprecated and be replaced by the built-in '
+ 'Python dict. Please use the equivalent `nnx.diff` instead.',
+ DeprecationWarning,
+ )
+ return diff(self, other)
def __init_subclass__(cls) -> None:
super().__init_subclass__()
@@ -574,6 +466,253 @@ def _state_unflatten(
)
+def map_state(f: tp.Callable[[tuple, tp.Any], tp.Any], state: State) -> State:
+ flat_state = to_flat_state(state)
+ result = [
+ (path, f(path, variable_state)) for path, variable_state in flat_state
+ ]
+ return from_flat_state(result)
+
+
+def to_flat_state(state: State) -> FlatState:
+ return FlatState(traversals.flatten_to_sequence(state._mapping), sort=True)
+
+
+def from_flat_state(
+ flat_state: tp.Mapping[PathParts, V] | tp.Iterable[tuple[PathParts, V]],
+ *, cls = State, # for compatibility with State subclasses
+) -> State:
+ if not isinstance(flat_state, tp.Mapping):
+ flat_state = dict(flat_state)
+ nested_state = traversals.unflatten_mapping(flat_state)
+ return cls(nested_state)
+
+
+def to_pure_dict(
+ state, extract_fn: ExtractValueFn | None = None
+) -> dict[str, tp.Any]:
+ # Works for nnx.Variable and nnx.VariableState
+ if extract_fn is None:
+ extract_fn = lambda x: x.value if hasattr(x, 'value') else x
+ flat_values = {k: extract_fn(x) for k, x in to_flat_state(state)}
+ return traversals.unflatten_mapping(flat_values)
+
+
+def replace_by_pure_dict(
+ state, pure_dict: dict[str, tp.Any], replace_fn: SetValueFn | None = None
+):
+ def try_convert_int(x):
+ try:
+ return int(x)
+ except ValueError:
+ return x
+
+ # Works for nnx.Variable and nnx.VariableState
+ if replace_fn is None:
+ replace_fn = lambda x, v: x.replace(v) if hasattr(x, 'replace') else v
+ current_flat = dict(to_flat_state(state))
+ for kp, v in traversals.flatten_mapping(pure_dict).items():
+ kp = tuple(map(try_convert_int, kp))
+ if kp not in current_flat:
+ raise ValueError(f'key in pure_dict not available in state: {kp}')
+ current_flat[kp] = replace_fn(current_flat[kp], v)
+ state.update(traversals.unflatten_mapping(current_flat))
+
+
+@tp.overload
+def split_state(state: State, first: filterlib.Filter, /) -> State: ...
+
+
+@tp.overload
+def split_state(
+ state: State,
+ first: filterlib.Filter,
+ second: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+) -> tuple[State, ...]: ...
+
+
+@tp.overload
+def split_state(
+ state: State, /, *filters: filterlib.Filter
+) -> tp.Union[State, tuple[State, ...]]: ...
+
+
+def split_state( # type: ignore[misc]
+ state: State, first: filterlib.Filter, /, *filters: filterlib.Filter
+) -> tp.Union[State, tuple[State, ...]]:
+ """Split a ``State`` into one or more ``State``'s. The
+ user must pass at least one ``Filter`` (i.e. :class:`Variable`),
+ and the filters must be exhaustive (i.e. they must cover all
+ :class:`Variable` types in the ``State``).
+
+ Example usage::
+
+ >>> from flax import nnx
+
+ >>> class Model(nnx.Module):
+ ... def __init__(self, rngs):
+ ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
+ ... self.linear = nnx.Linear(2, 3, rngs=rngs)
+ ... def __call__(self, x):
+ ... return self.linear(self.batchnorm(x))
+
+ >>> model = Model(rngs=nnx.Rngs(0))
+ >>> state = nnx.state(model)
+ >>> param, batch_stats = nnx.split_state(state, nnx.Param, nnx.BatchStat)
+
+ Arguments:
+ state: The ``State`` to split.
+ first: The first filter
+ *filters: The optional, additional filters to group the state into mutually exclusive substates.
+ Returns:
+ One or more ``States`` equal to the number of filters passed.
+ """
+ filters = (first, *filters)
+ flat_states = _split_state(to_flat_state(state), *filters)
+ *states_, rest = (state.to_nested_state() for state in flat_states)
+
+ if rest:
+ raise ValueError(
+ 'Non-exhaustive filters, got a non-empty remainder: '
+ f'{rest}.\nUse `...` to match all remaining elements.'
+ )
+
+ states: State | tuple[State, ...]
+ if len(states_) == 1:
+ states = states_[0]
+ else:
+ states = tuple(states_)
+ return states # type: ignore
+
+
+@tp.overload
+def filter_state(
+ state: State,
+ first: filterlib.Filter,
+ /,
+) -> State: ...
+
+
+@tp.overload
+def filter_state(
+ state: State,
+ first: filterlib.Filter,
+ second: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+) -> tuple[State, ...]: ...
+
+
+def filter_state(
+ state: State,
+ first: filterlib.Filter,
+ /,
+ *filters: filterlib.Filter,
+) -> tp.Union[State, tuple[State, ...]]:
+ """Filter a ``State`` into one or more ``State``'s. The
+ user must pass at least one ``Filter`` (i.e. :class:`Variable`).
+ This function is similar to :func:`split_state`,
+ except the filters can be non-exhaustive.
+
+ Example usage::
+
+ >>> from flax import nnx
+
+ >>> class Model(nnx.Module):
+ ... def __init__(self, rngs):
+ ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
+ ... self.linear = nnx.Linear(2, 3, rngs=rngs)
+ ... def __call__(self, x):
+ ... return self.linear(self.batchnorm(x))
+
+ >>> model = Model(rngs=nnx.Rngs(0))
+ >>> state = nnx.state(model)
+ >>> param = nnx.filter_state(state, nnx.Param)
+ >>> batch_stats = nnx.filter_state(state, nnx.BatchStat)
+ >>> param, batch_stats = nnx.filter_state(state, nnx.Param, nnx.BatchStat)
+
+ Arguments:
+ state: The ``State`` to filter.
+ first: The first filter
+ *filters: The optional, additional filters to group the state into mutually exclusive substates.
+ Returns:
+ One or more ``States`` equal to the number of filters passed.
+ """
+ flat_states = _split_state(to_flat_state(state), first, *filters)
+ *states_, _rest = (state.to_nested_state() for state in flat_states)
+
+ assert len(states_) == len(filters) + 1
+
+ states: State | tuple[State, ...]
+ if len(states_) == 1:
+ states = states_[0]
+ else:
+ states = tuple(states_)
+
+ return states # type: ignore
+
+
+def merge_state(state: tp.Mapping, /, *states: tp.Mapping,
+ cls = State # for compatibility with State subclasses
+ ) -> State:
+  """The inverse of :func:`split_state`.
+
+ ``merge_state`` takes one or more ``State``'s and creates
+ a new ``State``.
+
+ Example usage::
+
+ >>> from flax import nnx
+ >>> import jax.numpy as jnp
+
+ >>> class Model(nnx.Module):
+ ... def __init__(self, rngs):
+ ... self.batchnorm = nnx.BatchNorm(2, rngs=rngs)
+ ... self.linear = nnx.Linear(2, 3, rngs=rngs)
+ ... def __call__(self, x):
+ ... return self.linear(self.batchnorm(x))
+
+ >>> model = Model(rngs=nnx.Rngs(0))
+ >>> params, batch_stats = nnx.state(model, nnx.Param, nnx.BatchStat)
+ >>> params['linear']['bias'].value += 1
+
+ >>> state = nnx.merge_state(params, batch_stats)
+ >>> nnx.update(model, state)
+ >>> assert (model.linear.bias.value == jnp.array([1, 1, 1])).all()
+
+ Args:
+ state: A ``State`` object.
+ *states: Additional ``State`` objects.
+ Returns:
+ The merged ``State``.
+ """
+ if not states:
+ if isinstance(state, cls):
+ return state
+ return cls(state)
+
+ states = (state, *states)
+
+ new_state: dict[PathParts, tp.Any] = {}
+
+ for state in states:
+ new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here
+
+ return from_flat_state(new_state, cls=cls)
+
+
+def diff(state: State, other: State) -> State:
+ if not other:
+ return state
+
+ self_flat = dict(to_flat_state(state))
+ other_flat = dict(to_flat_state(other))
+ diff = {k: v for k, v in self_flat.items() if k not in other_flat}
+
+ return from_flat_state(diff)
+
+
def _split_state(
flat_state: FlatState[V],
*filters: filterlib.Filter,
@@ -607,7 +746,7 @@ def _split_state(
def create_path_filters(state: State):
- flat_state = state.flat_state()
+ flat_state = to_flat_state(state)
value_paths: dict[tp.Any, set[PathParts]] = {}
for path, value in flat_state:
if isinstance(value, (variablelib.Variable, variablelib.VariableState)):
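During the transition the old methods keep working but emit `DeprecationWarning` and delegate to the functions above; a hedged sketch of what callers will observe (the choice of `filter` is arbitrary):

```python
import warnings
from flax import nnx

state = nnx.state(nnx.Linear(2, 2, rngs=nnx.Rngs(0)))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    params = state.filter(nnx.Param)  # deprecated method, still functional
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

params = nnx.filter_state(state, nnx.Param)  # preferred replacement
```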
diff --git a/flax/nnx/summary.py b/flax/nnx/summary.py
index fb2f49d31..7f1aed533 100644
--- a/flax/nnx/summary.py
+++ b/flax/nnx/summary.py
@@ -336,7 +336,7 @@ def tabulate(
nnx.RngState # type: ignore[misc]
if issubclass(variable_state.type, nnx.RngState)
else variable_state.type
- for _, variable_state in nnx.state(obj).flat_state()
+ for _, variable_state in nnx.to_flat_state(nnx.state(obj))
}
variable_types: list[type] = sorted(_variable_types, key=lambda t: t.__name__)
diff --git a/flax/nnx/transforms/deprecated.py b/flax/nnx/transforms/deprecated.py
index 844cea485..97d343729 100644
--- a/flax/nnx/transforms/deprecated.py
+++ b/flax/nnx/transforms/deprecated.py
@@ -21,6 +21,7 @@
from flax import struct
from flax.core.frozen_dict import FrozenDict
from flax.nnx import extract, filterlib, graph, rnglib, spmd, variablelib
+from flax.nnx import statelib
from flax.nnx.module import GraphDef, Module
from flax.nnx.proxy_caller import DelayedAccessor
from flax.nnx.statelib import State
@@ -81,11 +82,14 @@ def _fork_vmap_keys(
split_filter: filterlib.Filter,
num_splits: int,
) -> _VmapForkStates:
- split_keys, split_counts, broadcast_keys, broadcast_counts = state.split(
- filterlib.All(split_filter, rnglib.RngKey),
- filterlib.All(split_filter, rnglib.RngCount),
- rnglib.RngKey,
- rnglib.RngCount,
+ split_keys, split_counts, broadcast_keys, broadcast_counts = (
+ statelib.split_state(
+ state,
+ filterlib.All(split_filter, rnglib.RngKey),
+ filterlib.All(split_filter, rnglib.RngCount),
+ rnglib.RngKey,
+ rnglib.RngCount,
+ )
)
def split_key(key: tp.Any, count: tp.Any) -> jax.Array:
@@ -200,9 +204,13 @@ def vmap_fn(
*filters,
)
- split_keys_out, broadcast_keys_out = rng_state_out.split(split_rngs, ...)
+ split_keys_out, broadcast_keys_out = statelib.split_state(
+ rng_state_out, split_rngs, ...
+ )
- broadcast_state_out = State.merge(broadcast_state_out, broadcast_keys_out)
+ broadcast_state_out = statelib.merge_state(
+ broadcast_state_out, broadcast_keys_out
+ )
# add metadata axis name to Variable.sharding
if spmd.PARTITION_NAME in transform_metadata:
@@ -567,11 +575,11 @@ def pmap_fn(
*filters,
)
- not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split(
- rnglib.NotKey, split_rngs, ...
+ not_keys_out, split_keys_out, broadcast_keys_out = statelib.split_state(
+ rng_state_out, rnglib.NotKey, split_rngs, ...
)
- broadcast_state_out = State.merge(
+ broadcast_state_out = statelib.merge_state(
broadcast_state_out, broadcast_keys_out, not_keys_out
)
@@ -1138,8 +1146,8 @@ def scan_fn(
*filters,
)
- split_rng_state_out, broadcast_rng_state_out = rng_state_out.split(
- broadcasts.split_rngs, ...
+ split_rng_state_out, broadcast_rng_state_out = statelib.split_state(
+ rng_state_out, broadcasts.split_rngs, ...
)
def _extract_carry_state(state: State, /):
@@ -1306,7 +1314,9 @@ def scan_apply_wrapper(*args, **kwargs):
)
# split rng state
- split_rng_state, broadcast_rng_state = rng_state.split(split_rngs, ...)
+ split_rng_state, broadcast_rng_state = statelib.split_state(
+ rng_state, split_rngs, ...
+ )
broadcasts = ScanBroadcasts(
flatdef,
diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py
index d308c8c30..01b3ac27a 100644
--- a/tests/nnx/bridge/wrappers_test.py
+++ b/tests/nnx/bridge/wrappers_test.py
@@ -82,7 +82,7 @@ def __call__(self, x):
assert gdef_before_lazy_init != gdef_full
assert 'nn_dense1' in state
assert 'batchnorm' in state
- assert 'kernel' in state.nn_dense1
+ assert 'kernel' in state['nn_dense1']
y = model(x)
k, b = state['nn_dense1']['kernel'].value, state['b'].value
np.testing.assert_allclose(y, x @ k + b, rtol=1e-5)
@@ -104,7 +104,7 @@ def dot(self, x):
model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0))
bridge.lazy_init(model, x, method=model.module.dot)
y = model(x, method=model.module.dot)
- np.testing.assert_allclose(y, x @ nnx.state(model).w.value)
+ np.testing.assert_allclose(y, x @ nnx.state(model)['w'].value)
# lazy_init only initialized param w inside dot(), so calling __call__ should fail
with self.assertRaises(flax.errors.ScopeParamNotFoundError):
y = model(x)
@@ -121,9 +121,9 @@ def __call__(self, x):
x = lambda: jnp.zeros((), jnp.int32)
model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x)
- self.assertEqual(nnx.state(model).count.value, 0)
+ self.assertEqual(nnx.state(model)['count'].value, 0)
y = model(x, mutable=True)
- self.assertEqual(nnx.state(model).count.value, 1)
+ self.assertEqual(nnx.state(model)['count'].value, 1)
def test_linen_to_nnx_transform(self):
class NNXOuter(nnx.Module):
diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py
index 397198ae4..af50ad61e 100644
--- a/tests/nnx/graph_utils_test.py
+++ b/tests/nnx/graph_utils_test.py
@@ -103,7 +103,7 @@ def test_unflatten_pure_dict(self):
g = List([a, 3, a, nnx.Param(4)])
graphdef, state = nnx.split(g)
- pure_state = state.to_pure_dict()
+ pure_state = nnx.to_pure_dict(state)
g = nnx.merge(graphdef, pure_state)
@@ -175,7 +175,7 @@ def test_update_from_pure_dict(self):
g = [a, 3, a, nnx.Param(4)]
graphdef, state = nnx.split(g)
- pure_state = state.to_pure_dict()
+ pure_state = nnx.to_pure_dict(state)
pure_state[0]['b'] = 3
nnx.update(g, pure_state)
@@ -206,7 +206,7 @@ def test_shared_variables(self):
graphdef, state = nnx.split(g)
- assert len(state.flat_state()) == 1
+ assert len(nnx.to_flat_state(state)) == 1
g2 = nnx.merge(graphdef, state)
@@ -224,7 +224,7 @@ def __init__(self, *, rngs: nnx.Rngs) -> None:
node = Foo(rngs=nnx.Rngs(0))
graphdef, state = nnx.split(node)
- assert len(state.flat_state()) == 3 # 2 bias + 1 kernel
+ assert len(nnx.to_flat_state(state)) == 3 # 2 bias + 1 kernel
node2 = nnx.merge(graphdef, state)
@@ -257,7 +257,7 @@ def __call__(self, x):
model = Encoder(rngs=nnx.Rngs(0))
graphdef, state = nnx.split(model)
- assert len(state.flat_state()) == 1
+ assert len(nnx.to_flat_state(state)) == 1
x = jax.random.randint(jax.random.key(0), (2,), 0, 10)
y = model(x)
@@ -273,16 +273,16 @@ def __init__(self):
graphdef, state = nnx.split(m)
assert isinstance(m.a, nnx.Param)
- assert issubclass(state.a.type, nnx.Param)
- assert m.a is not state.a
- assert m.a.value == state.a.value
+ assert issubclass(state['a'].type, nnx.Param)
+ assert m.a is not state['a']
+ assert m.a.value == state['a'].value
m2 = nnx.merge(graphdef, state)
assert isinstance(m2.a, nnx.Param)
- assert issubclass(state.a.type, nnx.Param)
- assert m2.a is not state.a
- assert m2.a.value == state.a.value
+ assert issubclass(state['a'].type, nnx.Param)
+ assert m2.a is not state['a']
+ assert m2.a.value == state['a'].value
def test_shared_state_variables_not_shared_with_graph(self):
class Foo(nnx.Module):
@@ -296,22 +296,22 @@ def __init__(self):
assert isinstance(m.a, nnx.Param)
assert isinstance(m.b, nnx.Param)
- assert issubclass(state.a.type, nnx.Param)
+ assert issubclass(state['a'].type, nnx.Param)
assert 'b' not in state
- assert m.a is not state.a
- assert m.b is not state.a
- assert m.a.value == state.a.value
- assert m.b.value == state.a.value
+ assert m.a is not state['a']
+ assert m.b is not state['a']
+ assert m.a.value == state['a'].value
+ assert m.b.value == state['a'].value
m2 = nnx.merge(graphdef, state)
assert isinstance(m2.a, nnx.Param)
assert isinstance(m2.b, nnx.Param)
- assert issubclass(state.a.type, nnx.Param)
- assert m2.a is not state.a
- assert m2.b is not state.a
- assert m2.a.value == state.a.value
- assert m2.b.value == state.a.value
+ assert issubclass(state['a'].type, nnx.Param)
+ assert m2.a is not state['a']
+ assert m2.b is not state['a']
+ assert m2.a.value == state['a'].value
+ assert m2.b.value == state['a'].value
assert m2.a is m2.b
def test_pytree_flatten(self):
@@ -349,7 +349,7 @@ def __init__(self):
graphdef, state = nnx.split(m)
assert 'tree' in state
- assert 'a' in state.tree
+ assert 'a' in state['tree']
assert graphdef.attributes[0][1].type is nnx.graph.GenericPytree
m2 = nnx.merge(graphdef, state)
@@ -580,8 +580,8 @@ def test_split_merge_context(self):
self.assertFalse(hasattr(ctx, 'ctxtag'))
self.assertIsInstance(graphdef1, nnx.graph.NodeDef)
self.assertIsInstance(graphdef2, nnx.graph.NodeRef)
- self.assertLen(state1.flat_state(), 2)
- self.assertLen(state2.flat_state(), 0)
+ self.assertLen(nnx.to_flat_state(state1), 2)
+ self.assertLen(nnx.to_flat_state(state2), 0)
with nnx.graph.merge_context() as ctx:
m1 = ctx.merge(graphdef1, state1)
@@ -600,8 +600,8 @@ def test_split_merge_context_nested(self):
self.assertIsInstance(graphdef1, nnx.graph.NodeDef)
self.assertIsInstance(graphdef2, nnx.graph.NodeRef)
- self.assertLen(state1.flat_state(), 2)
- self.assertLen(state2.flat_state(), 0)
+ self.assertLen(nnx.to_flat_state(state1), 2)
+ self.assertLen(nnx.to_flat_state(state2), 0)
with nnx.graph.merge_context() as ctx:
m1 = ctx.merge(graphdef1, state1)
@@ -630,8 +630,8 @@ def __init__(self):
self.assertFalse(hasattr(ctx, 'ctxtag'))
self.assertIsInstance(graphdef1, nnx.graph.NodeDef)
self.assertIsInstance(graphdef2, nnx.graph.NodeRef)
- self.assertLen(state1.flat_state(), 1)
- self.assertLen(state2.flat_state(), 0)
+ self.assertLen(nnx.to_flat_state(state1), 1)
+ self.assertLen(nnx.to_flat_state(state2), 0)
@jax.jit
def f(graphdef1, state1, graphdef2, state2):
@@ -684,8 +684,8 @@ def test_to_tree_simple(self):
assert isinstance(t2, nnx.NodeStates)
self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef)
self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef)
- self.assertLen(t1.states[0].flat_state(), 2)
- self.assertLen(t2.states[0].flat_state(), 0)
+ self.assertLen(nnx.to_flat_state(t1.states[0]), 2)
+ self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
impure_tree2 = nnx.from_tree(pure_tree)
@@ -719,8 +719,8 @@ def __init__(self):
assert isinstance(t2, nnx.NodeStates)
self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef)
self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef)
- self.assertLen(t1.states[0].flat_state(), 1)
- self.assertLen(t2.states[0].flat_state(), 0)
+ self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
+ self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
@jax.jit
def f(pure_tree):
@@ -746,8 +746,8 @@ def f(pure_tree):
assert isinstance(t2, nnx.NodeStates)
self.assertIsInstance(t1.graphdef, nnx.graph.NodeDef)
self.assertIsInstance(t2.graphdef, nnx.graph.NodeRef)
- self.assertLen(t1.states[0].flat_state(), 1)
- self.assertLen(t2.states[0].flat_state(), 0)
+ self.assertLen(nnx.to_flat_state(t1.states[0]), 1)
+ self.assertLen(nnx.to_flat_state(t2.states[0]), 0)
return pure_tree2
diff --git a/tests/nnx/integration_test.py b/tests/nnx/integration_test.py
index 7b572f4b1..1831f4e00 100644
--- a/tests/nnx/integration_test.py
+++ b/tests/nnx/integration_test.py
@@ -205,7 +205,7 @@ def train_step(params, counts, x, y):
def loss_fn(params):
y_pred, (_, updates) = graphdef.apply(params, counts)(x)
loss = jax.numpy.mean((y_pred - y) ** 2)
- return loss, updates.filter(Count)
+ return loss, nnx.filter_state(updates, Count)
# compute gradient
grads, counts = jax.grad(loss_fn, has_aux=True)(params)
@@ -257,7 +257,7 @@ def __call__(self, x):
y, (_, state) = graphdef.apply(state)(jnp.ones((8, 12)))
- intermediates, state = state.split(nnx.Intermediate, ...)
+ intermediates, state = nnx.split_state(state, nnx.Intermediate, ...)
assert 'y' in intermediates
@@ -278,7 +278,7 @@ def __call__(self, x):
assert model(x).shape == (3, 4)
_, state = nnx.split(model)
- pure_dict_state = state.to_pure_dict()
+ pure_dict_state = nnx.to_pure_dict(state)
nnx.display(pure_dict_state)
with tempfile.TemporaryDirectory() as tmpdir:
@@ -295,7 +295,7 @@ def __call__(self, x):
abstract_model = nnx.eval_shape(lambda: MLPs(4, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
- abstract_state.replace_by_pure_dict(restored_pure_dict)
+ nnx.replace_by_pure_dict(abstract_state, restored_pure_dict)
model = nnx.merge(graphdef, abstract_state)
assert model(x).shape == (3, 4) # The model still works!
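
The integration_test.py hunks extend the same move to the filtering and pure-dict helpers: the State.filter, State.split, State.to_pure_dict and State.replace_by_pure_dict methods become nnx.filter_state, nnx.split_state, nnx.to_pure_dict and nnx.replace_by_pure_dict. A hedged sketch of the round trip with a toy Linear module (the MLPs class from the test is not reproduced here):

    import jax.numpy as jnp
    from flax import nnx

    model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
    graphdef, state = nnx.split(model)

    # State filtering/splitting now goes through free functions
    params = nnx.filter_state(state, nnx.Param)
    params_only, rest = nnx.split_state(state, nnx.Param, ...)

    # pure-dict round trip: plain nested dict out, State populated back in
    pure = nnx.to_pure_dict(state)
    abstract_model = nnx.eval_shape(lambda: nnx.Linear(4, 4, rngs=nnx.Rngs(0)))
    graphdef, abstract_state = nnx.split(abstract_model)
    nnx.replace_by_pure_dict(abstract_state, pure)
    restored = nnx.merge(graphdef, abstract_state)
    assert restored(jnp.ones((3, 4))).shape == (3, 4)
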
diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py
index 316389095..25dfc5636 100644
--- a/tests/nnx/module_test.py
+++ b/tests/nnx/module_test.py
@@ -249,7 +249,7 @@ def test_deref_number_of_fields(self):
)
graphdef, p = nnx.split(m)
- assert len(p.flat_state()) == 2
+ assert len(nnx.to_flat_state(p)) == 2
assert len(jax.tree_util.tree_leaves(p)) == 2
def test_clone(self):
@@ -286,7 +286,7 @@ def __call__(self, x):
intermediates = nnx.pop(m, nnx.Intermediate)
- assert issubclass(intermediates.y.type, nnx.Intermediate)
+ assert issubclass(intermediates['y'].type, nnx.Intermediate)
assert intermediates['y'].value == (3, 11)
assert not hasattr(m, 'y')
@@ -607,7 +607,7 @@ def __call__(self, x):
obj = Foo(nnx.Rngs(0))
- leaves = nnx.state(obj).flat_state().leaves
+ leaves = nnx.to_flat_state(nnx.state(obj)).leaves
expected_total = sum(int(np.prod(x.value.shape)) for x in leaves)
expected_total_params = sum(
@@ -688,14 +688,14 @@ class Foo(nnx.Module):
graphdef, state = nnx.split(m)
assert len(state) == 4
- assert state.b.value == 2
- assert state.b.type == nnx.Variable
- assert state.c.value == 3
- assert state.c.type == nnx.Param
- assert state.d.value == 4
- assert state.d.type == nnx.Variable
- assert state.e.value == 5
- assert state.e.type == nnx.BatchStat
+ assert state['b'].value == 2
+ assert state['b'].type == nnx.Variable
+ assert state['c'].value == 3
+ assert state['c'].type == nnx.Param
+ assert state['d'].value == 4
+ assert state['d'].type == nnx.Variable
+ assert state['e'].value == 5
+ assert state['e'].type == nnx.BatchStat
def test_post_init(self):
@@ -733,7 +733,7 @@ def __call__(self, x, *, rngs: nnx.Rngs):
graphdef, states = nnx.split(foo)
assert isinstance(states, nnx.State)
- assert issubclass(states.w.type, nnx.Param)
+ assert issubclass(states['w'].type, nnx.Param)
y, _updates = graphdef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1))
@@ -757,8 +757,8 @@ def __call__(self, x, *, rngs: nnx.Rngs):
assert isinstance(graphdef, nnx.graph.NodeDef | nnx.graph.NodeRef)
assert isinstance(state, nnx.State)
- assert issubclass(state.w.type, nnx.Param)
- assert issubclass(state.c.type, nnx.Variable)
+ assert issubclass(state['w'].type, nnx.Param)
+ assert issubclass(state['c'].type, nnx.Variable)
y, (graphdef, state) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1))
@@ -820,7 +820,7 @@ def __init__(self):
graphdef, state = nnx.split(foo)
assert isinstance(state, nnx.State)
- assert isinstance(state.a, nnx.State)
+ assert isinstance(state['a'], nnx.State)
foo2 = nnx.merge(graphdef, state)
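
The module_test.py changes (and the lora_test.py ones that follow) are purely about addressing: entries of a State are mapping keys, so attribute access such as state.w gives way to item access state['w']. A small sketch of the new access pattern on a hypothetical module:

    from flax import nnx

    model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    _, state = nnx.split(model)

    # read entries with ['key'] rather than .key
    assert issubclass(state['kernel'].type, nnx.Param)
    assert state['bias'].value.shape == (3,)
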
diff --git a/tests/nnx/nn/lora_test.py b/tests/nnx/nn/lora_test.py
index e525b8ad0..db501e5eb 100644
--- a/tests/nnx/nn/lora_test.py
+++ b/tests/nnx/nn/lora_test.py
@@ -107,12 +107,12 @@ def test_lora_param_type(self):
_, lora_params, params = nnx.split(model, nnx.LoRAParam, nnx.Param)
assert params == {}
assert ('lora_a' in lora_params) and ('lora_b' in lora_params)
- np.testing.assert_allclose(lora_params.lora_a.value, model.lora_a.value)
+ np.testing.assert_allclose(lora_params['lora_a'].value, model.lora_a.value)
model = nnx.LoRA(3, 4, 2, lora_param_type=nnx.Param, rngs=rngs)
_, params, lora_params = nnx.split(model, nnx.Param, nnx.LoRAParam)
assert ('lora_a' in params) and ('lora_b' in params)
- np.testing.assert_allclose(params.lora_a.value, model.lora_a.value)
+ np.testing.assert_allclose(params['lora_a'].value, model.lora_a.value)
assert lora_params == {}
diff --git a/tests/nnx/optimizer_test.py b/tests/nnx/optimizer_test.py
index 88763f143..05aac74c7 100644
--- a/tests/nnx/optimizer_test.py
+++ b/tests/nnx/optimizer_test.py
@@ -85,16 +85,16 @@ def test_sharding_propagation(self):
state = nnx.state(optimizer)
partition_spec = nnx.get_partition_spec(state)
- self.assertEqual(state.opt_state[0].mu.kernel.sharding, ('a', 'b'))
+ self.assertEqual(state['opt_state'][0]['mu']['kernel'].sharding, ('a', 'b'))
self.assertEqual(
- partition_spec.opt_state[0].mu.kernel.value,
- jax.sharding.PartitionSpec('a', 'b'),
+ partition_spec['opt_state'][0]['mu']['kernel'].value,
+ jax.sharding.PartitionSpec('a', 'b'),
)
@parameterized.product(
- module_cls=[nnx.Linear, Model],
- jit_decorator=[lambda f: f, nnx.jit, jax.jit],
- optimizer=[optax.sgd, optax.adam],
+ module_cls=[nnx.Linear, Model],
+ jit_decorator=[lambda f: f, nnx.jit, jax.jit],
+ optimizer=[optax.sgd, optax.adam],
)
def test_jit(self, module_cls, jit_decorator, optimizer):
x = jax.random.normal(jax.random.key(0), (1, 2))
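
In optimizer_test.py the same key-based addressing reaches into nested optimizer state, e.g. state['opt_state'][0]['mu']['kernel'] instead of state.opt_state[0].mu.kernel. A hedged sketch of that nesting, assuming an Adam optimizer and a kernel annotated with ('a', 'b') sharding via nnx.with_partitioning (the test's actual model and optimizer setup are not shown in the hunk):

    import optax
    from flax import nnx

    model = nnx.Linear(
        2, 3,
        kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ('a', 'b')),
        rngs=nnx.Rngs(0),
    )
    optimizer = nnx.Optimizer(model, optax.adam(1e-3))
    state = nnx.state(optimizer)

    # every level of nesting is reached with item access
    kernel_mu = state['opt_state'][0]['mu']['kernel']
    spec = nnx.get_partition_spec(state)['opt_state'][0]['mu']['kernel'].value
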
diff --git a/tests/nnx/partitioning_test.py b/tests/nnx/partitioning_test.py
index bb859de3a..a183f5537 100644
--- a/tests/nnx/partitioning_test.py
+++ b/tests/nnx/partitioning_test.py
@@ -185,8 +185,8 @@ def test_get_paritition(self):
self.assertEqual(state['a']['0'].value, m.a['0'].value)
self.assertEqual(state['a']['1'].value, m.a['1'].value)
self.assertEqual(state['b'].value, m.b.value)
- self.assertIsNot(state.b, state.a['0'])
- self.assertLen(state.flat_state(), 3)
+ self.assertIsNot(state['b'], state['a']['0'])
+ self.assertLen(nnx.to_flat_state(state), 3)
if __name__ == '__main__':
diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py
index d3eb2197c..f57fb84fd 100644
--- a/tests/nnx/rngs_test.py
+++ b/tests/nnx/rngs_test.py
@@ -134,9 +134,9 @@ def __call__(self, x):
self.assertEqual(m.rngs.params.count.value, 2)
self.assertEqual(m.rngs['dropout'].count.value, 0)
- self.assertLen(dropout_keys.flat_state(), 1)
- self.assertLen(param_keys.flat_state(), 1)
- self.assertLen(rng_counts.flat_state(), 2)
+ self.assertLen(nnx.to_flat_state(dropout_keys), 1)
+ self.assertLen(nnx.to_flat_state(param_keys), 1)
+ self.assertLen(nnx.to_flat_state(rng_counts), 2)
# split dropout keys
split_dropout_keys = jax.tree.map(
@@ -184,10 +184,10 @@ def test_state_fork_split(self):
self.assertLen(jax.tree.leaves(split_counts), 2)
self.assertEmpty(jax.tree.leaves(broadcast_keys))
self.assertEmpty(jax.tree.leaves(broadcast_counts))
- self.assertEqual(split_keys.params.key.value.shape, (4,))
- self.assertEqual(split_keys.dropout.key.value.shape, (4,))
- self.assertEqual(split_counts.params.count.value, 0)
- self.assertEqual(split_counts.dropout.count.value, 0)
+ self.assertEqual(split_keys['params']['key'].value.shape, (4,))
+ self.assertEqual(split_keys['dropout']['key'].value.shape, (4,))
+ self.assertEqual(split_counts['params']['count'].value, 0)
+ self.assertEqual(split_counts['dropout']['count'].value, 0)
def test_state_fork_split_and_broadcast(self):
rngs = nnx.Rngs(params=0, dropout=1)
@@ -200,10 +200,10 @@ def test_state_fork_split_and_broadcast(self):
self.assertLen(jax.tree.leaves(split_counts), 1)
self.assertLen(jax.tree.leaves(broadcast_keys), 1)
self.assertLen(jax.tree.leaves(broadcast_counts), 1)
- self.assertEqual(split_keys.params.key.value.shape, (4,))
- self.assertEqual(broadcast_keys.dropout.key.value.shape, ())
- self.assertEqual(split_counts.params.count.value, 0)
- self.assertEqual(broadcast_counts.dropout.count.value, 0)
+ self.assertEqual(split_keys['params']['key'].value.shape, (4,))
+ self.assertEqual(broadcast_keys['dropout']['key'].value.shape, ())
+ self.assertEqual(split_counts['params']['count'].value, 0)
+ self.assertEqual(broadcast_counts['dropout']['count'].value, 0)
def test_state_fork_multidimensional_split(self):
rngs = nnx.Rngs(params=0, dropout=1)
@@ -216,10 +216,10 @@ def test_state_fork_multidimensional_split(self):
self.assertLen(jax.tree.leaves(split_counts), 2)
self.assertEmpty(jax.tree.leaves(broadcast_keys))
self.assertEmpty(jax.tree.leaves(broadcast_counts))
- self.assertEqual(split_keys.params.key.value.shape, (4, 1, 3))
- self.assertEqual(split_keys.dropout.key.value.shape, (4, 1, 3))
- self.assertEqual(split_counts.params.count.value, 0)
- self.assertEqual(split_counts.dropout.count.value, 0)
+ self.assertEqual(split_keys['params']['key'].value.shape, (4, 1, 3))
+ self.assertEqual(split_keys['dropout']['key'].value.shape, (4, 1, 3))
+ self.assertEqual(split_counts['params']['count'].value, 0)
+ self.assertEqual(split_counts['dropout']['count'].value, 0)
def test_state_fork_multidimensional_split_mixed(self):
rngs = nnx.Rngs(params=0, dropout=1)
@@ -232,10 +232,10 @@ def test_state_fork_multidimensional_split_mixed(self):
self.assertLen(jax.tree.leaves(split_counts), 1)
self.assertLen(jax.tree.leaves(broadcast_keys), 1)
self.assertLen(jax.tree.leaves(broadcast_counts), 1)
- self.assertEqual(split_keys.params.key.value.shape, (4, 1, 3))
- self.assertEqual(broadcast_keys.dropout.key.value.shape, ())
- self.assertEqual(split_counts.params.count.value, 0)
- self.assertEqual(broadcast_counts.dropout.count.value, 0)
+ self.assertEqual(split_keys['params']['key'].value.shape, (4, 1, 3))
+ self.assertEqual(broadcast_keys['dropout']['key'].value.shape, ())
+ self.assertEqual(split_counts['params']['count'].value, 0)
+ self.assertEqual(broadcast_counts['dropout']['count'].value, 0)
def test_reseed(self):
class Model(nnx.Module):
diff --git a/tests/nnx/state_test.py b/tests/nnx/state_test.py
index 3cfde22e0..e1f78d1d5 100644
--- a/tests/nnx/state_test.py
+++ b/tests/nnx/state_test.py
@@ -80,15 +80,15 @@ def __init__(self, *, rngs: nnx.Rngs):
def test_pure_dict(self):
module = nnx.Linear(4, 5, rngs=nnx.Rngs(0))
state = nnx.state(module)
- pure_dict = state.to_pure_dict()
+ pure_dict = nnx.to_pure_dict(state)
assert isinstance(pure_dict, dict)
assert isinstance(pure_dict['kernel'], jax.Array)
assert isinstance(pure_dict['bias'], jax.Array)
- state.replace_by_pure_dict(jax.tree.map(jnp.zeros_like, pure_dict))
+ nnx.replace_by_pure_dict(state, jax.tree.map(jnp.zeros_like, pure_dict))
assert isinstance(state, nnx.State)
- assert isinstance(state.kernel, nnx.VariableState)
- assert jnp.array_equal(state.kernel.value, jnp.zeros((4, 5)))
- assert state.kernel.type == nnx.Param
+ assert isinstance(state['kernel'], nnx.VariableState)
+ assert jnp.array_equal(state['kernel'].value, jnp.zeros((4, 5)))
+ assert state['kernel'].type == nnx.Param
nnx.update(module, state)
assert jnp.array_equal(module(jnp.ones((3, 4))), jnp.zeros((3, 5)))
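
The state_test.py hunk exercises the in-place variant of the pure-dict helpers: nnx.replace_by_pure_dict overwrites an existing State from a plain dict of arrays, and nnx.update then pushes the result back into the live module. Written out linearly, under the same Linear(4, 5) setup the test uses:

    import jax
    import jax.numpy as jnp
    from flax import nnx

    module = nnx.Linear(4, 5, rngs=nnx.Rngs(0))
    state = nnx.state(module)

    pure = nnx.to_pure_dict(state)                 # nested dict of jax.Arrays
    nnx.replace_by_pure_dict(state, jax.tree.map(jnp.zeros_like, pure))
    nnx.update(module, state)                      # write the zeroed values back

    assert jnp.array_equal(module(jnp.ones((3, 4))), jnp.zeros((3, 5)))
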
diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py
index 41208e882..0b846b757 100644
--- a/tests/nnx/transforms_test.py
+++ b/tests/nnx/transforms_test.py
@@ -1,4 +1,3 @@
-import dataclasses
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import dataclasses
from functools import partial
import typing as tp
@@ -439,14 +439,14 @@ def f(m: Dict):
assert m.a[0] is m.b
assert isinstance(grads, nnx.State)
assert grads['a']['0'].value == 2.0
- assert issubclass(grads.a['0'].type, nnx.Variable)
+ assert issubclass(grads['a']['0'].type, nnx.Variable)
assert grads['a']['1'].value == 1.0
- assert issubclass(grads.a['1'].type, nnx.Variable)
- assert len(grads.flat_state()) == 2
+ assert issubclass(grads['a']['1'].type, nnx.Variable)
+ assert len(nnx.to_flat_state(grads)) == 2
nnx.update(m, grads)
- assert m.a[0] is m.b
+ assert m['a'][0] is m.b
assert m['a'][0].value == 2.0
assert m['a'][1].value == 1.0
assert m['b'].value == 2.0
@@ -470,7 +470,7 @@ def f(m: Dict):
assert isinstance(grads, nnx.State)
assert grads['a']['0'].value == 1.0
- assert issubclass(grads.a['0'].type, nnx.Param)
+ assert issubclass(grads['a']['0'].type, nnx.Param)
assert len(grads) == 2
nnx.update(m, grads)
@@ -498,7 +498,7 @@ def f(m: Dict):
assert isinstance(grads, nnx.State)
assert grads['a']['1'].value == 1.0
- assert issubclass(grads.a['1'].type, nnx.BatchStat)
+ assert issubclass(grads['a']['1'].type, nnx.BatchStat)
assert len(grads) == 1
nnx.update(m, grads)
@@ -519,9 +519,9 @@ def test_multiple_inputs(self):
grads = grad_fn(m, x, y)
assert 'kernel' in grads
- assert grads.kernel.value.shape == (2, 3)
+ assert grads['kernel'].value.shape == (2, 3)
assert 'bias' in grads
- assert grads.bias.value.shape == (3,)
+ assert grads['bias'].value.shape == (3,)
@parameterized.parameters(
{
@@ -546,13 +546,13 @@ def test_multiple_graph_nodes(self, loss_fn, argnums):
grads_m1, grads_m2 = grad_fn(*inputs)
assert 'kernel' in grads_m1
- assert grads_m1.kernel.value.shape == (2, 3)
+ assert grads_m1['kernel'].value.shape == (2, 3)
assert 'bias' in grads_m1
- assert grads_m1.bias.value.shape == (3,)
+ assert grads_m1['bias'].value.shape == (3,)
assert 'kernel' in grads_m2
- assert grads_m2.kernel.value.shape == (3, 3)
+ assert grads_m2['kernel'].value.shape == (3, 3)
assert 'bias' in grads_m2
- assert grads_m2.bias.value.shape == (3,)
+ assert grads_m2['bias'].value.shape == (3,)
def test_multiple_args(self):
m1 = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
@@ -695,8 +695,8 @@ def f_bwd(res, g):
self.assertIsInstance(m, Foo)
# m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
- m_g.x.value = cos_x * out_g * m.y
- m_g.y.value = sin_x * out_g
+ m_g['x'].value = cos_x * out_g * m.y
+ m_g['y'].value = sin_x * out_g
return (m_g,)
f.defvjp(f_fwd, f_bwd)
@@ -738,7 +738,7 @@ def f_bwd(res, g):
self.assertEqual(out_g.shape, ())
self.assertIsInstance(m, Foo)
- m_g.x.value = cos_x * out_g * m.y
+ m_g['x'].value = cos_x * out_g * m.y
del m_g['y']
return (m_g,)
@@ -779,8 +779,8 @@ def f_bwd(res, g):
self.assertIsInstance(m, Foo)
# m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
- m_g.x.value = cos_x * out_g * m.y
- m_g.y.value = sin_x * out_g
+ m_g['x'].value = cos_x * out_g * m.y
+ m_g['y'].value = sin_x * out_g
return (m_g,)
f.defvjp(f_fwd, f_bwd)
@@ -884,8 +884,8 @@ def f_bwd(a, b, res, g):
self.assertIsInstance(m, Foo)
# m_g = nnx.State({'x': cos_x * out_g * m.y, 'y': sin_x * out_g})
- m_g.x.value = cos_x * out_g * m.y
- m_g.y.value = sin_x * out_g
+ m_g['x'].value = cos_x * out_g * m.y
+ m_g['y'].value = sin_x * out_g
return (m_g,)
f.defvjp(f_fwd, f_bwd)
@@ -1712,10 +1712,10 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]:
# test sharding layer axes is not present inside scan
state = nnx.state(self.linear)
- assert state.kernel.value.shape == (3, 3) # type: ignore
- assert state.kernel.sharding == ('din', 'dout') # type: ignore
- assert state.bias.value.shape == (3,) # type: ignore
- assert state.bias.sharding == ('dout',) # type: ignore
+ assert state['kernel'].value.shape == (3, 3) # type: ignore
+ assert state['kernel'].sharding == ('din', 'dout') # type: ignore
+ assert state['bias'].value.shape == (3,) # type: ignore
+ assert state['bias'].sharding == ('dout',) # type: ignore
return x, None
@@ -1730,20 +1730,28 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]:
# test sharding layers axes is set
state = nnx.state(m)
- assert state.scan_module.linear.kernel.value.shape == (5, 3, 3)
- assert state.scan_module.linear.kernel.sharding == ('layers', 'din', 'dout')
- assert state.scan_module.linear.bias.value.shape == (5, 3)
- assert state.scan_module.linear.bias.sharding == ('layers', 'dout')
+ assert state['scan_module']['linear']['kernel'].value.shape == (5, 3, 3)
+ assert state['scan_module']['linear']['kernel'].sharding == (
+ 'layers',
+ 'din',
+ 'dout',
+ )
+ assert state['scan_module']['linear']['bias'].value.shape == (5, 3)
+ assert state['scan_module']['linear']['bias'].sharding == ('layers', 'dout')
x = jnp.ones((1, 3))
y, out = m(x, None)
# test sharding axes is preserved
state = nnx.state(m)
- assert state.scan_module.linear.kernel.value.shape == (5, 3, 3)
- assert state.scan_module.linear.kernel.sharding == ('layers', 'din', 'dout')
- assert state.scan_module.linear.bias.value.shape == (5, 3)
- assert state.scan_module.linear.bias.sharding == ('layers', 'dout')
+ assert state['scan_module']['linear']['kernel'].value.shape == (5, 3, 3)
+ assert state['scan_module']['linear']['kernel'].sharding == (
+ 'layers',
+ 'din',
+ 'dout',
+ )
+ assert state['scan_module']['linear']['bias'].value.shape == (5, 3)
+ assert state['scan_module']['linear']['bias'].sharding == ('layers', 'dout')
def test_type_error_less_than_one_args(self):
class Block(nnx.Module):