From d7b0fb7bd606c9c64774a1db61ded9bb335e7e6f Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Mon, 3 Feb 2025 18:44:36 -0500 Subject: [PATCH] fix trace-level detection --- flax/core/tracers.py | 7 ++++- flax/nnx/tracers.py | 3 +- tests/nnx/module_test.py | 55 ++++++++++++++++-------------------- tests/nnx/rngs_test.py | 9 +++--- tests/nnx/transforms_test.py | 6 ++-- 5 files changed, 39 insertions(+), 41 deletions(-) diff --git a/flax/core/tracers.py b/flax/core/tracers.py index b380f6aaf..ba82ab845 100644 --- a/flax/core/tracers.py +++ b/flax/core/tracers.py @@ -15,6 +15,7 @@ """Functionality for inspecting jax tracers.""" import jax +import jax.core def current_trace(): @@ -29,5 +30,9 @@ def current_trace(): return jax.core.get_opaque_trace_state(convention="flax") def check_trace_level(base_level): + # TODO(cgarciae): skipping for now as it breaks + # too many internal tests. + # level = current_trace() + # if level != base_level: + # raise errors.JaxTransformError() pass - # TODO: re-enable when we update flax to use stackless trace context diff --git a/flax/nnx/tracers.py b/flax/nnx/tracers.py index 60221b3f4..18056a8a8 100644 --- a/flax/nnx/tracers.py +++ b/flax/nnx/tracers.py @@ -40,8 +40,7 @@ def jax_trace(self): return self._jax_trace def is_valid(self) -> bool: - # TODO: re-enable when we update nnx to use stackless trace context - return True + return self._jax_trace == current_jax_trace() def __nnx_repr__(self): yield reprlib.Object(f'{type(self).__name__}') diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py index 25dfc5636..ded6fa189 100644 --- a/tests/nnx/module_test.py +++ b/tests/nnx/module_test.py @@ -17,6 +17,7 @@ import pickle import tempfile from typing import TypeVar +import typing as tp from absl.testing import absltest import cloudpickle @@ -39,28 +40,35 @@ def __setitem__(self, idx, value): class Dict(nnx.Module): + @tp.overload + def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ... + + @tp.overload + def __init__( + self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A + ): ... + def __init__(self, *args, **kwargs): - vars(self)['items'] = dict(*args, **kwargs) + for name, value in dict(*args, **kwargs).items(): + setattr(self, name, value) - def __getitem__(self, key): - return vars(self)['items'][key] + def __getitem__(self, key) -> A: + return getattr(self, key) def __setitem__(self, key, value): - vars(self)['items'][key] = value + setattr(self, key, value) + + def __getattr__(self, key) -> A: + return super().__getattribute__(key) def __setattr__(self, key, value): - if key == 'items': - object.__setattr__(self, key, value) - else: - vars(self)['items'][key] = value + super().__setattr__(key, value) + + def __iter__(self) -> tp.Iterator[str]: + return (k for k in vars(self) if k != '_object__state') - def __getattr__(self, key): - attrs = vars(self) - if 'items' not in attrs: - raise AttributeError('items') - elif key not in attrs['items']: - raise AttributeError(key) - return attrs['items'][key] + def __len__(self) -> int: + return len(vars(self)) class TestModule(absltest.TestCase): @@ -71,15 +79,14 @@ class Foo(nnx.Module): ... assert hasattr(foo, '_object__state') - @absltest.skip("Context checking doesn't work yet with stackless") def test_trace_level(self): m = Dict(a=nnx.Param(1)) @jax.jit def f(): with self.assertRaisesRegex( - errors.TraceContextError, - "Cannot mutate 'Dict' from different trace level", + errors.TraceContextError, + "Cannot mutate 'Dict' from different trace level", ): m.a = 2 @@ -113,18 +120,6 @@ def g(graphdef: nnx.GraphDef[Dict], state: nnx.State): assert m2.a == 2 - def test_no_trace_level_error_on_grad(self): - # No trace level error occurs because jax doesn't update - # its top trace for grad. - m = Dict(a=nnx.Param(1.0)) - - @jax.grad - def f(_): - m.a = 2.0 - return 1.0 - - f(1.0) - def test_call(self): class Foo(nnx.Module): def __init__(self, c: float, *, rngs: nnx.Rngs): diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py index f57fb84fd..f14a5ea1a 100644 --- a/tests/nnx/rngs_test.py +++ b/tests/nnx/rngs_test.py @@ -53,7 +53,6 @@ def test_rng_stream(self): self.assertIs(rngs.params.key.value, key0) self.assertFalse(jnp.allclose(key1, key2)) - @absltest.skip("Context checking doesn't work yet with stackless") def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) @@ -128,7 +127,7 @@ def __call__(self, x): rngs = nnx.Rngs(params=0, dropout=1) m = Foo(rngs) - _, params, rng_counts, dropout_keys, param_keys = nnx.split( + graphdef, params, rng_counts, dropout_keys, param_keys = nnx.split( m, nnx.Param, nnx.RngCount, 'dropout', 'params' ) @@ -151,10 +150,10 @@ def __call__(self, x): out_axes=(0, 0, None), ) def f(params, dropout_keys, param_keys, rng_counts, x): - nnx.update(m, params, dropout_keys, param_keys, rng_counts) + m = nnx.merge(graphdef, params, dropout_keys, param_keys, rng_counts) y = m(x) - _, params, dropout_keys, param_keys, rng_counts = nnx.split( - m, nnx.Param, 'dropout', 'params', nnx.RngCount + _, params, rng_counts, dropout_keys, param_keys = nnx.split( + m, nnx.Param, nnx.RngCount, 'dropout', 'params' ) return y, params, rng_counts diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index 9c812227f..2c3fc2f25 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -24,6 +24,7 @@ from jax.experimental import checkify, mesh_utils import jax.numpy as jnp import numpy as np +from flax import errors class List(nnx.Module): @@ -2512,7 +2513,6 @@ def f(m1, m2, m3): f(m, m, m) - @absltest.skip('Enable once jax#19586 resolved') def test_captured_module_in_return_error(self): class Foo(nnx.Module): @@ -2526,8 +2526,8 @@ def f(x): return x, m with self.assertRaisesRegex( - ValueError, - r'Trying to extract graph node from different trace level.*Foo', + errors.TraceContextError, + r'Trying to extract graph node from different trace level.*Foo', ): x = jnp.zeros((5,)) f(x)