39
39
from flax .core .scope import ( # pylint: disable=g-multiple-import
40
40
CollectionFilter , DenyList , FrozenVariableDict , Variable , VariableDict ,
41
41
union_filters )
42
+ from flax .ids import FlaxId
42
43
from flax .ids import uuid
43
44
44
45
@@ -115,7 +116,7 @@ def _module_repr(module: 'Module', num_spaces: int = 4):
115
116
else :
116
117
return f'{ cls_name } ()'
117
118
118
- #
119
+ # Tabulation utilities.
119
120
# -----------------------------------------------------------------------------
120
121
121
122
_find_non_lifted_module = re .compile (r'.*\((.*)\)' )
@@ -507,7 +508,7 @@ def reimport(self, other: '_ModuleInternalState') -> None:
507
508
'__reduce__' , '__reduce_ex__' , '__copy__' , '__deepcopy__' )
508
509
509
510
510
- _caches : 'weakref.WeakKeyDictionary[Scope, Dict[int , Module]]' = (
511
+ _caches : 'weakref.WeakKeyDictionary[Scope, weakref.WeakValueDictionary[FlaxId , Module]]' = (
511
512
weakref .WeakKeyDictionary ())
512
513
513
514
@@ -697,7 +698,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
697
698
if add_call_info :
698
699
call_index = _context .call_info_stack [- 1 ].get_call_index (self )
699
700
scope_path = jax .tree_util .tree_map (_fix_path_part , self .scope .path )
700
-
701
+
701
702
# call method
702
703
if _use_named_call :
703
704
with jax .named_scope (_derive_profiling_name (self , fun )):
@@ -883,7 +884,7 @@ def _register_submodules(self, name, val):
883
884
"""Registers a submodule."""
884
885
assert self .scope , 'Trying to register submodules on unbound scope.'
885
886
root = self .scope .root
886
- cache = _caches .get (root , {} )
887
+ cache = _caches .get (root , weakref . WeakValueDictionary () )
887
888
_caches [root ] = cache
888
889
queue = []
889
890
def adopt_attr_modules (cache , queue , suffix , subvalue ):
@@ -895,9 +896,12 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
895
896
# Preserve sharing-by-reference relationships during adoption
896
897
# via cache keyed on unique instance ids.
897
898
key = subvalue ._id
898
- if key not in cache :
899
- cache [key ] = subvalue .clone ()
900
- subvalue = cache [key ]
899
+ if key in cache :
900
+ subvalue = cache [key ]
901
+ else :
902
+ # We must bind to local variable before adding to weakvalue dict.
903
+ subvalue = subvalue .clone ()
904
+ cache [key ] = subvalue
901
905
if subvalue .name is None :
902
906
object .__setattr__ (subvalue , 'parent' , self )
903
907
object .__setattr__ (subvalue , 'name' , f'{ name } { suffix } ' )
@@ -1464,13 +1468,13 @@ def __call__(self, x):
1464
1468
1465
1469
def perturb (self , name : str , value : T , collection : str = 'perturbations' ) -> T :
1466
1470
"""Add an zero-value variable ('perturbation') to the intermediate value.
1467
-
1471
+
1468
1472
The gradient of `value` would be the same as the gradient of this
1469
1473
perturbation variable. Therefore, if you define your loss function with
1470
1474
both params and perturbations as standalone arguments, you can get the
1471
1475
intermediate gradients of `value` by running `jax.grad` on the perturbation
1472
1476
argument.
1473
-
1477
+
1474
1478
Note: this is an experimental API and may be tweaked later for better
1475
1479
performance and usability.
1476
1480
At its current stage, it creates extra dummy variables that occupies extra
@@ -1505,7 +1509,7 @@ def loss(params, perturbations, inputs, targets):
1505
1509
"""
1506
1510
value += self .variable (collection , name , lambda : jnp .zeros_like (value )).value
1507
1511
return value
1508
-
1512
+
1509
1513
def tabulate (
1510
1514
self ,
1511
1515
rngs : Union [PRNGKey , RNGSequences ],
@@ -1541,7 +1545,7 @@ def __call__(self, x):
1541
1545
1542
1546
This gives the following output::
1543
1547
1544
- Foo Summary
1548
+ Foo Summary
1545
1549
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
1546
1550
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
1547
1551
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
@@ -1559,7 +1563,7 @@ def __call__(self, x):
1559
1563
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
1560
1564
│ │ │ │ Total │ 50 (200 B) │
1561
1565
└─────────┴────────┴───────────────┴───────────────┴──────────────────────┘
1562
-
1566
+
1563
1567
Total Parameters: 50 (200 B)
1564
1568
1565
1569
**Note**: rows order in the table does not represent execution order,
@@ -1591,7 +1595,7 @@ def __call__(self, x):
1591
1595
A string summarizing the Module.
1592
1596
"""
1593
1597
from flax .linen import summary
1594
-
1598
+
1595
1599
tabulate_fn = summary .tabulate (self , rngs , depth = depth ,
1596
1600
show_repeated = show_repeated , mutable = mutable ,
1597
1601
console_kwargs = console_kwargs )
0 commit comments