diff --git a/flax/nnx/bridge/module.py b/flax/nnx/bridge/module.py index 927ed479..36b93dc6 100644 --- a/flax/nnx/bridge/module.py +++ b/flax/nnx/bridge/module.py @@ -460,13 +460,16 @@ def init_with_output( **kwargs, ) + def is_initializing(self) -> bool: + return self._object__state._initializing + def compact(f: F) -> F: @functools.wraps(f) def compact_wrapper(self, *args, **kwargs): if not isinstance(self, Module): raise ValueError( - f"Expected 'self' to be a nnx.compat.Module, got {type(self).__name__}" + f"Expected 'self' to be a nnx.bridge.Module, got {type(self).__name__}" ) MODULE_CONTEXT.parent_stack.append(CompactContext(self))