diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/experimental/nnx/nnx/graph.py index d30a3885c..957418513 100644 --- a/flax/experimental/nnx/nnx/graph.py +++ b/flax/experimental/nnx/nnx/graph.py @@ -990,21 +990,23 @@ def merge( @dataclasses.dataclass class UpdateContextManager: tag: str - ctx: UpdateContext | None def __enter__(self): - self.ctx = UpdateContext(self.tag, None, None) - GRAPH_CONTEXT.update_context_stacks[self.tag].append(self.ctx) - return self.ctx + ctx = UpdateContext(self.tag, None, None) + GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx) + return ctx def __exit__(self, *args): - if self.ctx is None: - raise RuntimeError('ctx should not be None, this is a bug.') + stack = GRAPH_CONTEXT.update_context_stacks[self.tag] + if not stack: + raise RuntimeError( + f'No update context found for tag {self.tag!r}, this is a bug.' + ) - GRAPH_CONTEXT.update_context_stacks[self.tag].pop() - self.ctx.refmap = None - self.ctx.idxmap = None - self.ctx = None + ctx = GRAPH_CONTEXT.update_context_stacks[self.tag].pop() + # clear references + ctx.refmap = None + ctx.idxmap = None def __call__(self, f: F) -> F: @functools.wraps(f) @@ -1107,7 +1109,7 @@ def update_context(tag: str): Args: tag: A string tag to identify the context. """ - return UpdateContextManager(tag, None) + return UpdateContextManager(tag) def current_update_context(tag: str) -> UpdateContext: