Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix trace-level detection #4527

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Functionality for inspecting jax tracers."""

import jax
import jax.core


def current_trace():
Expand All @@ -29,5 +30,9 @@ def current_trace():
return jax.core.get_opaque_trace_state(convention="flax")

def check_trace_level(base_level):
# TODO(cgarciae): skipping for now as it breaks
# too many internal tests.
# level = current_trace()
# if level != base_level:
# raise errors.JaxTransformError()
pass
# TODO: re-enable when we update flax to use stackless trace context
3 changes: 1 addition & 2 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ def jax_trace(self):
return self._jax_trace

def is_valid(self) -> bool:
# TODO: re-enable when we update nnx to use stackless trace context
return True
return self._jax_trace == current_jax_trace()

def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
Expand Down
55 changes: 25 additions & 30 deletions tests/nnx/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pickle
import tempfile
from typing import TypeVar
import typing as tp

from absl.testing import absltest
import cloudpickle
Expand All @@ -39,28 +40,35 @@ def __setitem__(self, idx, value):


class Dict(nnx.Module):
@tp.overload
def __init__(self, iterable: tp.Iterable[tp.Tuple[str, A]], /): ...

@tp.overload
def __init__(
self, mapping: tp.Optional[tp.Mapping[str, A]] = None, /, **kwargs: A
): ...

def __init__(self, *args, **kwargs):
vars(self)['items'] = dict(*args, **kwargs)
for name, value in dict(*args, **kwargs).items():
setattr(self, name, value)

def __getitem__(self, key):
return vars(self)['items'][key]
def __getitem__(self, key) -> A:
return getattr(self, key)

def __setitem__(self, key, value):
vars(self)['items'][key] = value
setattr(self, key, value)

def __getattr__(self, key) -> A:
return super().__getattribute__(key)

def __setattr__(self, key, value):
if key == 'items':
object.__setattr__(self, key, value)
else:
vars(self)['items'][key] = value
super().__setattr__(key, value)

def __iter__(self) -> tp.Iterator[str]:
return (k for k in vars(self) if k != '_object__state')

def __getattr__(self, key):
attrs = vars(self)
if 'items' not in attrs:
raise AttributeError('items')
elif key not in attrs['items']:
raise AttributeError(key)
return attrs['items'][key]
def __len__(self) -> int:
return len(vars(self))


class TestModule(absltest.TestCase):
Expand All @@ -71,15 +79,14 @@ class Foo(nnx.Module): ...

assert hasattr(foo, '_object__state')

@absltest.skip("Context checking doesn't work yet with stackless")
def test_trace_level(self):
m = Dict(a=nnx.Param(1))

@jax.jit
def f():
with self.assertRaisesRegex(
errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
errors.TraceContextError,
"Cannot mutate 'Dict' from different trace level",
):
m.a = 2

Expand Down Expand Up @@ -113,18 +120,6 @@ def g(graphdef: nnx.GraphDef[Dict], state: nnx.State):

assert m2.a == 2

def test_no_trace_level_error_on_grad(self):
# No trace level error occurs because jax doesn't update
# its top trace for grad.
m = Dict(a=nnx.Param(1.0))

@jax.grad
def f(_):
m.a = 2.0
return 1.0

f(1.0)

def test_call(self):
class Foo(nnx.Module):
def __init__(self, c: float, *, rngs: nnx.Rngs):
Expand Down
9 changes: 4 additions & 5 deletions tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ def test_rng_stream(self):
self.assertIs(rngs.params.key.value, key0)
self.assertFalse(jnp.allclose(key1, key2))

@absltest.skip("Context checking doesn't work yet with stackless")
def test_rng_trace_level_constraints(self):
rngs = nnx.Rngs(0)

Expand Down Expand Up @@ -128,7 +127,7 @@ def __call__(self, x):

rngs = nnx.Rngs(params=0, dropout=1)
m = Foo(rngs)
_, params, rng_counts, dropout_keys, param_keys = nnx.split(
graphdef, params, rng_counts, dropout_keys, param_keys = nnx.split(
m, nnx.Param, nnx.RngCount, 'dropout', 'params'
)

Expand All @@ -151,10 +150,10 @@ def __call__(self, x):
out_axes=(0, 0, None),
)
def f(params, dropout_keys, param_keys, rng_counts, x):
nnx.update(m, params, dropout_keys, param_keys, rng_counts)
m = nnx.merge(graphdef, params, dropout_keys, param_keys, rng_counts)
y = m(x)
_, params, dropout_keys, param_keys, rng_counts = nnx.split(
m, nnx.Param, 'dropout', 'params', nnx.RngCount
_, params, rng_counts, dropout_keys, param_keys = nnx.split(
m, nnx.Param, nnx.RngCount, 'dropout', 'params'
)
return y, params, rng_counts

Expand Down
6 changes: 3 additions & 3 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax.experimental import checkify, mesh_utils
import jax.numpy as jnp
import numpy as np
from flax import errors


class List(nnx.Module):
Expand Down Expand Up @@ -2512,7 +2513,6 @@ def f(m1, m2, m3):

f(m, m, m)

@absltest.skip('Enable once jax#19586 resolved')
def test_captured_module_in_return_error(self):
class Foo(nnx.Module):

Expand All @@ -2526,8 +2526,8 @@ def f(x):
return x, m

with self.assertRaisesRegex(
ValueError,
r'Trying to extract graph node from different trace level.*Foo',
errors.TraceContextError,
r'Trying to extract graph node from different trace level.*Foo',
):
x = jnp.zeros((5,))
f(x)
Expand Down
Loading