Skip to content

Commit

Permalink
Merge pull request #2495 from levskaya:lifetimefix
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 478475003
  • Loading branch information
Flax Authors committed Oct 3, 2022
2 parents ef67f1c + 53a4862 commit 2650849
Showing 1 changed file with 17 additions and 13 deletions.
30 changes: 17 additions & 13 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from flax.core.scope import ( # pylint: disable=g-multiple-import
CollectionFilter, DenyList, FrozenVariableDict, Variable, VariableDict,
union_filters)
from flax.ids import FlaxId
from flax.ids import uuid


Expand Down Expand Up @@ -115,7 +116,7 @@ def _module_repr(module: 'Module', num_spaces: int = 4):
else:
return f'{cls_name}()'

#
# Tabulation utilities.
# -----------------------------------------------------------------------------

_find_non_lifted_module = re.compile(r'.*\((.*)\)')
Expand Down Expand Up @@ -507,7 +508,7 @@ def reimport(self, other: '_ModuleInternalState') -> None:
'__reduce__', '__reduce_ex__', '__copy__', '__deepcopy__')


_caches: 'weakref.WeakKeyDictionary[Scope, Dict[int, Module]]' = (
_caches: 'weakref.WeakKeyDictionary[Scope, weakref.WeakValueDictionary[FlaxId, Module]]' = (
weakref.WeakKeyDictionary())


Expand Down Expand Up @@ -697,7 +698,7 @@ def _call_wrapped_method(self, fun, args, kwargs):
if add_call_info:
call_index = _context.call_info_stack[-1].get_call_index(self)
scope_path = jax.tree_util.tree_map(_fix_path_part, self.scope.path)

# call method
if _use_named_call:
with jax.named_scope(_derive_profiling_name(self, fun)):
Expand Down Expand Up @@ -883,7 +884,7 @@ def _register_submodules(self, name, val):
"""Registers a submodule."""
assert self.scope, 'Trying to register submodules on unbound scope.'
root = self.scope.root
cache = _caches.get(root, {})
cache = _caches.get(root, weakref.WeakValueDictionary())
_caches[root] = cache
queue = []
def adopt_attr_modules(cache, queue, suffix, subvalue):
Expand All @@ -895,9 +896,12 @@ def adopt_attr_modules(cache, queue, suffix, subvalue):
# Preserve sharing-by-reference relationships during adoption
# via cache keyed on unique instance ids.
key = subvalue._id
if key not in cache:
cache[key] = subvalue.clone()
subvalue = cache[key]
if key in cache:
subvalue = cache[key]
else:
# We must bind to local variable before adding to weakvalue dict.
subvalue = subvalue.clone()
cache[key] = subvalue
if subvalue.name is None:
object.__setattr__(subvalue, 'parent', self)
object.__setattr__(subvalue, 'name', f'{name}{suffix}')
Expand Down Expand Up @@ -1464,13 +1468,13 @@ def __call__(self, x):

def perturb(self, name: str, value: T, collection: str = 'perturbations') -> T:
"""Add an zero-value variable ('perturbation') to the intermediate value.
The gradient of `value` would be the same as the gradient of this
perturbation variable. Therefore, if you define your loss function with
both params and perturbations as standalone arguments, you can get the
intermediate gradients of `value` by running `jax.grad` on the perturbation
argument.
Note: this is an experimental API and may be tweaked later for better
performance and usability.
At its current stage, it creates extra dummy variables that occupies extra
Expand Down Expand Up @@ -1505,7 +1509,7 @@ def loss(params, perturbations, inputs, targets):
"""
value += self.variable(collection, name, lambda: jnp.zeros_like(value)).value
return value

def tabulate(
self,
rngs: Union[PRNGKey, RNGSequences],
Expand Down Expand Up @@ -1541,7 +1545,7 @@ def __call__(self, x):
This gives the following output::
Foo Summary
Foo Summary
┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ path ┃ module ┃ inputs ┃ outputs ┃ params ┃
┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
Expand All @@ -1559,7 +1563,7 @@ def __call__(self, x):
├─────────┼────────┼───────────────┼───────────────┼──────────────────────┤
│ │ │ │ Total │ 50 (200 B) │
└─────────┴────────┴───────────────┴───────────────┴──────────────────────┘
Total Parameters: 50 (200 B)
**Note**: rows order in the table does not represent execution order,
Expand Down Expand Up @@ -1591,7 +1595,7 @@ def __call__(self, x):
A string summarizing the Module.
"""
from flax.linen import summary

tabulate_fn = summary.tabulate(self, rngs, depth=depth,
show_repeated=show_repeated, mutable=mutable,
console_kwargs=console_kwargs)
Expand Down

0 comments on commit 2650849

Please sign in to comment.