Skip to content

Commit d7b0fb7

Browse files
committed
fix trace-level detection
1 parent 45a8f84 commit d7b0fb7

File tree

5 files changed

+39
-41
lines changed

5 files changed

+39
-41
lines changed

flax/core/tracers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Functionality for inspecting jax tracers."""
1616

1717
import jax
18+
import jax.core
1819

1920

2021
def current_trace():
@@ -29,5 +30,9 @@ def current_trace():
2930
return jax.core.get_opaque_trace_state(convention="flax")
3031

3132
def check_trace_level(base_level):
33+
# TODO(cgarciae): skipping for now as it breaks
34+
# too many internal tests.
35+
# level = current_trace()
36+
# if level != base_level:
37+
# raise errors.JaxTransformError()
3238
pass
33-
# TODO: re-enable when we update flax to use stackless trace context

flax/nnx/tracers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ def jax_trace(self):
4040
return self._jax_trace
4141

4242
def is_valid(self) -> bool:
43-
# TODO: re-enable when we update nnx to use stackless trace context
44-
return True
43+
return self._jax_trace == current_jax_trace()
4544

4645
def __nnx_repr__(self):
4746
yield reprlib.Object(f'{type(self).__name__}')

tests/nnx/module_test.py

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pickle
1818
import tempfile
1919
from typing import TypeVar
20+
import typing as tp
2021

2122
from absl.testing import absltest
2223
import cloudpickle
@@ -39,28 +40,35 @@ def __setitem__(self, idx, value):
3940

4041

4142
class Dict(nnx.Module):
43+
@tp.overload
44+
def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...
45+
46+
@tp.overload
47+
def __init__(
48+
self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
49+
): ...
50+
4251
def __init__(self, *args, **kwargs):
43-
vars(self)['items'] = dict(*args, **kwargs)
52+
for name, value in dict(*args, **kwargs).items():
53+
setattr(self, name, value)
4454

45-
def __getitem__(self, key):
46-
return vars(self)['items'][key]
55+
def __getitem__(self, key) -> A:
56+
return getattr(self, key)
4757

4858
def __setitem__(self, key, value):
49-
vars(self)['items'][key] = value
59+
setattr(self, key, value)
60+
61+
def __getattr__(self, key) -> A:
62+
return super().__getattribute__(key)
5063

5164
def __setattr__(self, key, value):
52-
if key == 'items':
53-
object.__setattr__(self, key, value)
54-
else:
55-
vars(self)['items'][key] = value
65+
super().__setattr__(key, value)
66+
67+
def __iter__(self) -> tp.Iterator[str]:
68+
return (k for k in vars(self) if k != '_object__state')
5669

57-
def __getattr__(self, key):
58-
attrs = vars(self)
59-
if 'items' not in attrs:
60-
raise AttributeError('items')
61-
elif key not in attrs['items']:
62-
raise AttributeError(key)
63-
return attrs['items'][key]
70+
def __len__(self) -> int:
71+
return len(vars(self))
6472

6573

6674
class TestModule(absltest.TestCase):
@@ -71,15 +79,14 @@ class Foo(nnx.Module): ...
7179

7280
assert hasattr(foo, '_object__state')
7381

74-
@absltest.skip("Context checking doesn't work yet with stackless")
7582
def test_trace_level(self):
7683
m = Dict(a=nnx.Param(1))
7784

7885
@jax.jit
7986
def f():
8087
with self.assertRaisesRegex(
81-
errors.TraceContextError,
82-
"Cannot mutate 'Dict' from different trace level",
88+
errors.TraceContextError,
89+
"Cannot mutate 'Dict' from different trace level",
8390
):
8491
m.a = 2
8592

@@ -113,18 +120,6 @@ def g(graphdef: nnx.GraphDef[Dict], state: nnx.State):
113120

114121
assert m2.a == 2
115122

116-
def test_no_trace_level_error_on_grad(self):
117-
# No trace level error occurs because jax doesn't update
118-
# its top trace for grad.
119-
m = Dict(a=nnx.Param(1.0))
120-
121-
@jax.grad
122-
def f(_):
123-
m.a = 2.0
124-
return 1.0
125-
126-
f(1.0)
127-
128123
def test_call(self):
129124
class Foo(nnx.Module):
130125
def __init__(self, c: float, *, rngs: nnx.Rngs):

tests/nnx/rngs_test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def test_rng_stream(self):
5353
self.assertIs(rngs.params.key.value, key0)
5454
self.assertFalse(jnp.allclose(key1, key2))
5555

56-
@absltest.skip("Context checking doesn't work yet with stackless")
5756
def test_rng_trace_level_constraints(self):
5857
rngs = nnx.Rngs(0)
5958

@@ -128,7 +127,7 @@ def __call__(self, x):
128127

129128
rngs = nnx.Rngs(params=0, dropout=1)
130129
m = Foo(rngs)
131-
_, params, rng_counts, dropout_keys, param_keys = nnx.split(
130+
graphdef, params, rng_counts, dropout_keys, param_keys = nnx.split(
132131
m, nnx.Param, nnx.RngCount, 'dropout', 'params'
133132
)
134133

@@ -151,10 +150,10 @@ def __call__(self, x):
151150
out_axes=(0, 0, None),
152151
)
153152
def f(params, dropout_keys, param_keys, rng_counts, x):
154-
nnx.update(m, params, dropout_keys, param_keys, rng_counts)
153+
m = nnx.merge(graphdef, params, dropout_keys, param_keys, rng_counts)
155154
y = m(x)
156-
_, params, dropout_keys, param_keys, rng_counts = nnx.split(
157-
m, nnx.Param, 'dropout', 'params', nnx.RngCount
155+
_, params, rng_counts, dropout_keys, param_keys = nnx.split(
156+
m, nnx.Param, nnx.RngCount, 'dropout', 'params'
158157
)
159158
return y, params, rng_counts
160159

tests/nnx/transforms_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from jax.experimental import checkify, mesh_utils
2525
import jax.numpy as jnp
2626
import numpy as np
27+
from flax import errors
2728

2829

2930
class List(nnx.Module):
@@ -2512,7 +2513,6 @@ def f(m1, m2, m3):
25122513

25132514
f(m, m, m)
25142515

2515-
@absltest.skip('Enable once jax#19586 resolved')
25162516
def test_captured_module_in_return_error(self):
25172517
class Foo(nnx.Module):
25182518

@@ -2526,8 +2526,8 @@ def f(x):
25262526
return x, m
25272527

25282528
with self.assertRaisesRegex(
2529-
ValueError,
2530-
r'Trying to extract graph node from different trace level.*Foo',
2529+
errors.TraceContextError,
2530+
r'Trying to extract graph node from different trace level.*Foo',
25312531
):
25322532
x = jnp.zeros((5,))
25332533
f(x)

0 commit comments

Comments
 (0)