From b1cb9521fb82ef29f4da234d1dbcf82935576635 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 24 May 2024 08:53:04 -0700 Subject: [PATCH] [nnx] fix UpdateContextManager PiperOrigin-RevId: 636934066 --- flax/experimental/nnx/nnx/graph.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/flax/experimental/nnx/nnx/graph.py b/flax/experimental/nnx/nnx/graph.py index d30a3885c..957418513 100644 --- a/flax/experimental/nnx/nnx/graph.py +++ b/flax/experimental/nnx/nnx/graph.py @@ -990,21 +990,23 @@ def merge( @dataclasses.dataclass class UpdateContextManager: tag: str - ctx: UpdateContext | None def __enter__(self): - self.ctx = UpdateContext(self.tag, None, None) - GRAPH_CONTEXT.update_context_stacks[self.tag].append(self.ctx) - return self.ctx + ctx = UpdateContext(self.tag, None, None) + GRAPH_CONTEXT.update_context_stacks[self.tag].append(ctx) + return ctx def __exit__(self, *args): - if self.ctx is None: - raise RuntimeError('ctx should not be None, this is a bug.') + stack = GRAPH_CONTEXT.update_context_stacks[self.tag] + if not stack: + raise RuntimeError( + f'No update context found for tag {self.tag!r}, this is a bug.' + ) - GRAPH_CONTEXT.update_context_stacks[self.tag].pop() - self.ctx.refmap = None - self.ctx.idxmap = None - self.ctx = None + ctx = GRAPH_CONTEXT.update_context_stacks[self.tag].pop() + # clear references + ctx.refmap = None + ctx.idxmap = None def __call__(self, f: F) -> F: @functools.wraps(f) @@ -1107,7 +1109,7 @@ def update_context(tag: str): Args: tag: A string tag to identify the context. """ - return UpdateContextManager(tag, None) + return UpdateContextManager(tag) def current_update_context(tag: str) -> UpdateContext: