Skip to content

Commit

Permalink
fix trace-level detection
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Mar 6, 2025
1 parent 45a8f84 commit d7b0fb7
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 41 deletions.
7 changes: 6 additions & 1 deletion flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Functionality for inspecting jax tracers."""

import jax
import jax.core


def current_trace():
Expand All @@ -29,5 +30,9 @@ def current_trace():
return jax.core.get_opaque_trace_state(convention="flax")

def check_trace_level(base_level):
# TODO(cgarciae): skipping for now as it breaks
# too many internal tests.
# level = current_trace()
# if level != base_level:
# raise errors.JaxTransformError()
pass
# TODO: re-enable when we update flax to use stackless trace context
3 changes: 1 addition & 2 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def jax_trace(self):
return self._jax_trace

def is_valid(self) -> bool:
# TODO: re-enable when we update nnx to use stackless trace context
return True
return self._jax_trace == current_jax_trace()

def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
Expand Down
55 changes: 25 additions & 30 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pickle
import tempfile
from typing import TypeVar
import typing as tp

from absl.testing import absltest
import cloudpickle
Expand All @@ -39,28 +40,35 @@ def __setitem__(self, idx, value):


class Dict(nnx.Module):
@tp.overload
def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...

@tp.overload
def __init__(
self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
): ...

def __init__(self, *args, **kwargs):
vars(self)['items'] = dict(*args, **kwargs)
for name, value in dict(*args, **kwargs).items():
setattr(self, name, value)

def __getitem__(self, key):
return vars(self)['items'][key]
def __getitem__(self, key) -> A:
return getattr(self, key)

def __setitem__(self, key, value):
vars(self)['items'][key] = value
setattr(self, key, value)

def __getattr__(self, key) -> A:
return super().__getattribute__(key)

def __setattr__(self, key, value):
if key == 'items':
object.__setattr__(self, key, value)
else:
vars(self)['items'][key] = value
super().__setattr__(key, value)

def __iter__(self) -> tp.Iterator[str]:
return (k for k in vars(self) if k != '_object__state')

def __getattr__(self, key):
attrs = vars(self)
if 'items' not in attrs:
raise AttributeError('items')
elif key not in attrs['items']:
raise AttributeError(key)
return attrs['items'][key]
def __len__(self) -> int:
return len(vars(self))


class TestModule(absltest.TestCase):
Expand All @@ -71,15 +79,14 @@ class Foo(nnx.Module): ...

assert hasattr(foo, '_object__state')

@absltest.skip("Context checking doesn't work yet with stackless")
def test_trace_level(self):
m = Dict(a=nnx.Param(1))

@jax.jit
def f():
with self.assertRaisesRegex(
errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
):
m.a = 2

Expand Down Expand Up @@ -113,18 +120,6 @@ def g(graphdef: nnx.GraphDef[Dict], state: nnx.State):

assert m2.a == 2

def test_no_trace_level_error_on_grad(self):
# No trace level error occurs because jax doesn't update
# its top trace for grad.
m = Dict(a=nnx.Param(1.0))

@jax.grad
def f(_):
m.a = 2.0
return 1.0

f(1.0)

def test_call(self):
class Foo(nnx.Module):
def __init__(self, c: float, *, rngs: nnx.Rngs):
Expand Down
9 changes: 4 additions & 5 deletions tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test_rng_stream(self):
self.assertIs(rngs.params.key.value, key0)
self.assertFalse(jnp.allclose(key1, key2))

@absltest.skip("Context checking doesn't work yet with stackless")
def test_rng_trace_level_constraints(self):
rngs = nnx.Rngs(0)

Expand Down Expand Up @@ -128,7 +127,7 @@ def __call__(self, x):

rngs = nnx.Rngs(params=0, dropout=1)
m = Foo(rngs)
_, params, rng_counts, dropout_keys, param_keys = nnx.split(
graphdef, params, rng_counts, dropout_keys, param_keys = nnx.split(
m, nnx.Param, nnx.RngCount, 'dropout', 'params'
)

Expand All @@ -151,10 +150,10 @@ def __call__(self, x):
out_axes=(0, 0, None),
)
def f(params, dropout_keys, param_keys, rng_counts, x):
nnx.update(m, params, dropout_keys, param_keys, rng_counts)
m = nnx.merge(graphdef, params, dropout_keys, param_keys, rng_counts)
y = m(x)
_, params, dropout_keys, param_keys, rng_counts = nnx.split(
m, nnx.Param, 'dropout', 'params', nnx.RngCount
_, params, rng_counts, dropout_keys, param_keys = nnx.split(
m, nnx.Param, nnx.RngCount, 'dropout', 'params'
)
return y, params, rng_counts

Expand Down
6 changes: 3 additions & 3 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax.experimental import checkify, mesh_utils
import jax.numpy as jnp
import numpy as np
from flax import errors


class List(nnx.Module):
Expand Down Expand Up @@ -2512,7 +2513,6 @@ def f(m1, m2, m3):

f(m, m, m)

@absltest.skip('Enable once jax#19586 resolved')
def test_captured_module_in_return_error(self):
class Foo(nnx.Module):

Expand All @@ -2526,8 +2526,8 @@ def f(x):
return x, m

with self.assertRaisesRegex(
ValueError,
r'Trying to extract graph node from different trace level.*Foo',
errors.TraceContextError,
r'Trying to extract graph node from different trace level.*Foo',
):
x = jnp.zeros((5,))
f(x)
Expand Down

0 comments on commit d7b0fb7

Please sign in to comment.