Skip to content

Commit 90715be

Browse files
Cristian GarciaFlax Authors
authored andcommitted
[nnx] disallow Array leaves
Numpy and JAX Array's are no longer consider state leaves. This makes the structure of the State completely determined by Variables, which apart from being more predictable it produces structural stability invariant to leaf type changes which let to issues such as #4142. ```python class Foo(nnx.Module): def __init__(self): self.a = jnp.array(1) # no longer allowed, instead... self.b = nnx.Param(jnp.array(1)) # just use Variables ``` Also migrates all remaining tests from pytest to absl to ensure they are tested correctly internally. PiperOrigin-RevId: 671372717
1 parent aded9ac commit 90715be

File tree

10 files changed

+240
-232
lines changed

10 files changed

+240
-232
lines changed

flax/nnx/nnx/graph.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,18 @@
5050
Leaf = tp.TypeVar('Leaf')
5151
AuxData = tp.TypeVar('AuxData')
5252

53-
StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array]
53+
StateLeaf = VariableState[tp.Any]
54+
NodeLeaf = VariableState[tp.Any]
5455
GraphState = State[Key, StateLeaf]
5556
GraphFlatState = FlatState[StateLeaf]
5657

5758

5859
def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
59-
return isinstance(x, (VariableState, np.ndarray, jax.Array))
60+
return isinstance(x, VariableState)
6061

6162

62-
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]:
63-
return isinstance(x, (Variable, np.ndarray, jax.Array))
63+
def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]:
64+
return isinstance(x, Variable)
6465

6566

6667
class _HashById(tp.Hashable, tp.Generic[A]):
@@ -416,6 +417,11 @@ def _graph_flatten(
416417
flat_state[(*path, key)] = value
417418
leaves.append((key, None))
418419
else:
420+
if isinstance(value, (jax.Array, np.ndarray)):
421+
path_str = '/'.join(map(str, (*path, key)))
422+
raise ValueError(
423+
f'Arrays leaves are not supported, at {path_str!r}: {value}'
424+
)
419425
static_fields.append((key, value))
420426

421427
nodedef = NodeDef.create(

flax/nnx/nnx/training/optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def __init__(
133133
self.step = OptState(jnp.array(0, dtype=jnp.uint32))
134134
self.model = model
135135
self.tx = tx
136-
self.opt_state = tx.init(nnx.state(model, wrt))
136+
self.opt_state = OptState(tx.init(nnx.state(model, wrt)))
137137
self.wrt = wrt
138138

139139
def split(self, *filters: filterlib.Filter):
@@ -198,10 +198,10 @@ def update(self, grads):
198198
"""
199199
state = nnx.state(self.model, self.wrt)
200200

201-
updates, new_opt_state = self.tx.update(grads, self.opt_state, state)
201+
updates, new_opt_state = self.tx.update(grads, self.opt_state.value, state)
202202
new_params = optax.apply_updates(state, updates)
203203
assert isinstance(new_params, nnx.State)
204204

205205
self.step.value += 1
206206
nnx.update(self.model, new_params)
207-
self.opt_state = new_opt_state
207+
self.opt_state.value = new_opt_state

flax/nnx/tests/deprecated_transforms_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,3 +368,6 @@ def __call__(self, x: jax.Array) -> jax.Array:
368368
y = module(x)
369369

370370
assert y.shape == (1, 5, 3)
371+
372+
if __name__ == '__main__':
373+
absltest.main()

flax/nnx/tests/experimental_test.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

flax/nnx/tests/filters_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
from absl.testing import absltest
1716

1817
from flax import nnx
@@ -30,4 +29,7 @@ def __init__(self, rngs):
3029
head_state = nnx.state(model, nnx.PathContains('head'))
3130

3231
self.assertIn('head', head_state)
33-
self.assertNotIn('backbone', head_state)
32+
self.assertNotIn('backbone', head_state)
33+
34+
if __name__ == '__main__':
35+
absltest.main()

flax/nnx/tests/graph_utils_test.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from collections.abc import Callable
1516
import dataclasses
1617
from functools import partial
1718
from threading import Thread
1819
from typing import Any
1920

20-
from absl.testing import absltest
21+
from absl.testing import absltest, parameterized
2122
from flax import linen, nnx, struct
2223
import jax
2324
import jax.numpy as jnp
24-
import pytest
2525

2626

2727
class StatefulLinear(nnx.Module):
@@ -77,7 +77,7 @@ def test_unflatten_empty(self):
7777

7878
graphdef, state = nnx.split(g)
7979

80-
with pytest.raises(ValueError, match='Expected key'):
80+
with self.assertRaisesRegex(ValueError, 'Expected key'):
8181
nnx.graph.unflatten(graphdef, nnx.State({}))
8282

8383
def test_update_dynamic(self):
@@ -109,8 +109,8 @@ def test_update_static_inconsistent_types(self):
109109
g = [a, 3, a, nnx.Param(4)]
110110
g2 = [a, a, 3, nnx.Param(4)]
111111

112-
with pytest.raises(
113-
ValueError, match='Trying to update a node with a different type'
112+
with self.assertRaisesRegex(
113+
ValueError, 'Trying to update a node with a different type'
114114
):
115115
nnx.graph.graph_update_static(g, g2)
116116

@@ -130,7 +130,7 @@ def test_update_static_add_shared_error(self):
130130
g = nnx.List([a, 3, a, nnx.Param(4)])
131131
g2 = nnx.List([a, 3, a, nnx.Param(4), a])
132132

133-
with pytest.raises(ValueError, match='Trying to add a new node at path'):
133+
with self.assertRaisesRegex(ValueError, 'Trying to add a new node at path'):
134134
nnx.graph.graph_update_static(g, g2)
135135

136136
def test_module_list(self):
@@ -428,10 +428,10 @@ def f_pure(graphdef: nnx.graph.GraphDef[Foo], state):
428428
def test_call_jit_update(self):
429429
class Counter(nnx.Module):
430430
def __init__(self):
431-
self.count = jnp.zeros(())
431+
self.count = nnx.Param(jnp.zeros(()))
432432

433433
def inc(self):
434-
self.count += 1
434+
self.count.value += 1
435435
return 1
436436

437437
graph_state = nnx.split(Counter())
@@ -447,7 +447,7 @@ def update(graph_state: nnx.PureState[Counter]):
447447

448448
counter = nnx.merge(*graph_state)
449449

450-
self.assertEqual(counter.count, 2)
450+
self.assertEqual(counter.count.value, 2)
451451

452452
def test_stateful_linear(self):
453453
linear = StatefulLinear(3, 2, nnx.Rngs(0))
@@ -714,7 +714,7 @@ def test_to_tree_consistent_prefix(self):
714714
pure_tree = nnx.to_tree(impure_tree, prefix=prefix)
715715

716716
prefix = (0, None, 1)
717-
with pytest.raises(ValueError, match='Inconsistent aliasing detected'):
717+
with self.assertRaisesRegex(ValueError, 'Inconsistent aliasing detected'):
718718
nnx.to_tree(impure_tree, prefix=prefix)
719719

720720
def test_simple_vmap(self):
@@ -798,12 +798,24 @@ class SimplePyTreeModule(nnx.Module, experimental_pytree=True):
798798
pass
799799

800800

801-
@pytest.mark.parametrize(['x'], [(SimpleModule(),), (SimplePyTreeModule(),)])
802-
def test_threading(x: nnx.Module):
803-
class MyThread(Thread):
804-
def run(self) -> None:
805-
nnx.graph.split(x)
801+
class TestThreading(parameterized.TestCase):
806802

807-
thread = MyThread()
808-
thread.start()
809-
thread.join()
803+
@parameterized.parameters(
804+
(SimpleModule,),
805+
(SimplePyTreeModule,),
806+
)
807+
def test_threading(self, module_fn: Callable[[], nnx.Module]):
808+
x = module_fn()
809+
810+
class MyThread(Thread):
811+
812+
def run(self) -> None:
813+
nnx.graph.split(x)
814+
815+
thread = MyThread()
816+
thread.start()
817+
thread.join()
818+
819+
820+
if __name__ == '__main__':
821+
absltest.main()

flax/nnx/tests/module_test.py

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,20 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import dataclasses
1615
from copy import deepcopy
16+
import dataclasses
1717
from typing import Any, TypeVar
1818

19+
from absl.testing import absltest
20+
from flax import nnx
1921
import jax
2022
import jax.numpy as jnp
2123
import numpy as np
22-
import pytest
23-
24-
from flax import nnx
2524

2625
A = TypeVar('A')
2726

2827

29-
class TestModule:
28+
class TestModule(absltest.TestCase):
3029
def test_has_module_state(self):
3130
class Foo(nnx.Module): ...
3231

@@ -39,9 +38,9 @@ def test_trace_level(self):
3938

4039
@jax.jit
4140
def f():
42-
with pytest.raises(
43-
nnx.errors.TraceContextError,
44-
match="Cannot mutate 'Dict' from different trace level",
41+
with self.assertRaisesRegex(
42+
nnx.errors.TraceContextError,
43+
"Cannot mutate 'Dict' from different trace level",
4544
):
4645
m.a = 2
4746

@@ -265,7 +264,7 @@ def __call__(self, x):
265264

266265
m = Foo()
267266

268-
with pytest.raises(ValueError, match='to be a Variable, got'):
267+
with self.assertRaisesRegex(ValueError, 'to be a Variable, got'):
269268
m(2)
270269

271270
def test_sow_wrong_collection(self):
@@ -280,7 +279,7 @@ def __call__(self, x):
280279

281280
m = Foo()
282281

283-
with pytest.raises(ValueError, match='to be of type'):
282+
with self.assertRaisesRegex(ValueError, 'to be of type'):
284283
m(2)
285284

286285
def test_update_static_state_submodules(self):
@@ -466,9 +465,12 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs):
466465

467466
block = Block(2, 5, rngs=nnx.Rngs(0))
468467

469-
with pytest.raises(
470-
ValueError,
471-
match="Could not find at least one instance of the following attributes: {'unknown'}",
468+
with self.assertRaisesRegex(
469+
ValueError,
470+
(
471+
'Could not find at least one instance of the following attributes:'
472+
" {'unknown'}"
473+
),
472474
):
473475
block.set_attributes(
474476
deterministic=True, use_running_average=True, unknown=True
@@ -662,26 +664,10 @@ def __init__(self, *, rngs: nnx.Rngs):
662664
assert modules[1][0] == 'linear'
663665
assert isinstance(modules[1][1], nnx.Linear)
664666

665-
def test_array_in_module(self):
666-
class Foo(nnx.Module):
667-
def __init__(self):
668-
self.a = jnp.array(1.0)
669-
670-
foo = Foo()
671-
672-
graphdef, state = nnx.split(foo)
673-
674-
assert isinstance(state, nnx.State)
675-
assert isinstance(state.a, jax.Array)
676-
677-
foo2 = nnx.merge(graphdef, state)
678-
679-
assert isinstance(foo2.a, jax.Array)
680-
681667
def test_state_in_module(self):
682668
class Foo(nnx.Module):
683669
def __init__(self):
684-
self.a = nnx.State({'b': jnp.array(1.0)})
670+
self.a = nnx.State({'b': nnx.Param(jnp.array(1.0))})
685671

686672
foo = Foo()
687673

@@ -693,3 +679,6 @@ def __init__(self):
693679
foo2 = nnx.merge(graphdef, state)
694680

695681
assert isinstance(foo2.a, nnx.State)
682+
683+
if __name__ == '__main__':
684+
absltest.main()

0 commit comments

Comments
 (0)