Skip to content

Commit 2650849

Browse files
author
Flax Authors
committed
Merge pull request #2495 from levskaya:lifetimefix
PiperOrigin-RevId: 478475003
2 parents ef67f1c + 53a4862 commit 2650849

File tree

1 file changed

+17
-13
lines changed

1 file changed

+17
-13
lines changed

flax/linen/module.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from flax.core.scope import ( # pylint: disable=g-multiple-import
4040
CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict,
4141
union_filters)
42+
from flax.ids import FlaxId
4243
from flax.ids import uuid
4344

4445

@@ -115,7 +116,7 @@ def _module_repr(module: 'Module', num_spaces: int = 4):
115116
else:
116117
return f'{cls_name}()'
117118

118-
#
119+
# Tabulation utilities.
119120
# -----------------------------------------------------------------------------
120121

121122
_find_non_lifted_module = re.compile(r'.*\((.*)\)')
@@ -507,7 +508,7 @@ def reimport(self, other: '_ModuleInternalState') -> None:
507508
'__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__')
508509

509510

510-
_caches: 'weakref.WeakKeyDictionary[Scope, Dict[int, Module]]' = (
511+
_caches: 'weakref.WeakKeyDictionary[Scope, weakref.WeakValueDictionary[FlaxId, Module]]' = (
511512
weakref.WeakKeyDictionary())
512513

513514

@@ -697,7 +698,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
697698
if add_call_info:
698699
call_index = _context.call_info_stack[-1].get_call_index(self)
699700
scope_path = jax.tree_util.tree_map(_fix_path_part, self.scope.path)
700-
701+
701702
# call method
702703
if _use_named_call:
703704
with jax.named_scope(_derive_profiling_name(self, fun)):
@@ -883,7 +884,7 @@ def _register_submodules(self, name, val):
883884
"""Registers a submodule."""
884885
assert self.scope, 'Trying to register submodules on unbound scope.'
885886
root = self.scope.root
886-
cache = _caches.get(root, {})
887+
cache = _caches.get(root, weakref.WeakValueDictionary())
887888
_caches[root] = cache
888889
queue = []
889890
def adopt_attr_modules(cache, queue, suffix, subvalue):
@@ -895,9 +896,12 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
895896
# Preserve sharing-by-reference relationships during adoption
896897
# via cache keyed on unique instance ids.
897898
key = subvalue._id
898-
if key not in cache:
899-
cache[key] = subvalue.clone()
900-
subvalue = cache[key]
899+
if key in cache:
900+
subvalue = cache[key]
901+
else:
902+
# We must bind to local variable before adding to weakvalue dict.
903+
subvalue = subvalue.clone()
904+
cache[key] = subvalue
901905
if subvalue.name is None:
902906
object.__setattr__(subvalue, 'parent', self)
903907
object.__setattr__(subvalue, 'name', f'{name}{suffix}')
@@ -1464,13 +1468,13 @@ def __call__(self, x):
14641468

14651469
def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
14661470
"""Add an zero-value variable ('perturbation') to the intermediate value.
1467-
1471+
14681472
The gradient of `value` would be the same as the gradient of this
14691473
perturbation variable. Therefore, if you define your loss function with
14701474
both params and perturbations as standalone arguments, you can get the
14711475
intermediate gradients of `value` by running `jax.grad` on the perturbation
14721476
argument.
1473-
1477+
14741478
Note: this is an experimental API and may be tweaked later for better
14751479
performance and usability.
14761480
At its current stage, it creates extra dummy variables that occupies extra
@@ -1505,7 +1509,7 @@ def loss(params, perturbations, inputs, targets):
15051509
"""
15061510
value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
15071511
return value
1508-
1512+
15091513
def tabulate(
15101514
self,
15111515
rngs: Union[PRNGKey, RNGSequences],
@@ -1541,7 +1545,7 @@ def __call__(self, x):
15411545
15421546
This gives the following output::
15431547
1544-
Foo Summary
1548+
Foo Summary
15451549
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
15461550
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
15471551
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
@@ -1559,7 +1563,7 @@ def __call__(self, x):
15591563
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
15601564
│ │ │ │ Total │ 50 (200 B) │
15611565
└─────────┴────────┴───────────────┴───────────────┴──────────────────────┘
1562-
1566+
15631567
Total Parameters: 50 (200 B)
15641568
15651569
**Note**: rows order in the table does not represent execution order,
@@ -1591,7 +1595,7 @@ def __call__(self, x):
15911595
A string summarizing the Module.
15921596
"""
15931597
from flax.linen import summary
1594-
1598+
15951599
tabulate_fn = summary.tabulate(self, rngs, depth=depth,
15961600
show_repeated=show_repeated, mutable=mutable,
15971601
console_kwargs=console_kwargs)

0 commit comments

Comments
 (0)