Optimise memory usage in MultiSteps.
Change the implementation to allow JAX/XLA to re-use memory buffers. #472

PiperOrigin-RevId: 561129449
hbq1 authored and OptaxDev committed Aug 29, 2023
1 parent 4c04ca5 commit 2a1748f
Showing 1 changed file with 41 additions and 39 deletions.
80 changes: 41 additions & 39 deletions optax/_src/wrappers.py
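The note added in the diff below ("we do not enclose variables to allow JAX to re-use memory buffers") summarises the change: the jax.lax.cond branches inside MultiSteps.update now receive state, params and the accumulated gradients as explicit operands instead of closing over them. A minimal, hypothetical sketch of that operand-passing style, not taken from the commit (the names step, _apply and _zero are invented for illustration):

import jax
import jax.numpy as jnp

def _apply(acc):
  # Stand-in for a real optimiser update on the accumulated gradients.
  return jax.tree_util.tree_map(lambda g: 0.5 * g, acc)

def _zero(acc):
  # Stand-in for a mini-step: emit zero updates with the same structure.
  return jax.tree_util.tree_map(jnp.zeros_like, acc)

@jax.jit
def step(is_final, acc_grads):
  # Branches take acc_grads as an operand of lax.cond rather than capturing
  # it from the enclosing scope.
  return jax.lax.cond(is_final, _apply, _zero, acc_grads)

updates = step(jnp.array(True), {'w': jnp.ones((3,)), 'b': jnp.zeros((2,))})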
@@ -14,16 +14,12 @@
# ==============================================================================
"""Transformation wrappers."""

-import functools
from typing import Any, Callable, NamedTuple, Optional, Protocol, Tuple, Union

import chex
import jax
from jax import lax
import jax.numpy as jnp
-from jax.tree_util import tree_flatten
-from jax.tree_util import tree_map
-from jax.tree_util import tree_unflatten
import numpy as np
from optax._src import base
from optax._src import numerics
@@ -52,12 +48,12 @@ def flatten(

def _flatten(params):
"""Flattens and concatenates all tensors in params to a single vector."""
-params, _ = tree_flatten(params)
+params, _ = jax.tree_util.tree_flatten(params)
return jnp.concatenate([jnp.reshape(param, [-1]) for param in params])

def _unflatten(updates, flat):
"""Extracts tensors from flat, using the structure and shapes of params."""
-updates_flat, treedef = tree_flatten(updates)
+updates_flat, treedef = jax.tree_util.tree_flatten(updates)
offsets = []
for update in updates_flat:
size = np.prod(update.shape)
@@ -71,7 +67,7 @@ def _unflatten(updates, flat):
jnp.reshape(flat_update, update.shape)
for flat_update, update in zip(flat_split, updates_flat)
]
-return tree_unflatten(treedef, reshaped)
+return jax.tree_util.tree_unflatten(treedef, reshaped)

def init_fn(params):
flat = _flatten(params)
@@ -144,7 +140,7 @@ def init(params):

def update(updates, state, params=None, **extra_args):
inner_state = state.inner_state
-flat_updates = tree_flatten(updates)[0]
+flat_updates = jax.tree_util.tree_flatten(updates)[0]
isfinite = jnp.all(
jnp.array([jnp.all(jnp.isfinite(p)) for p in flat_updates]))
notfinite_count = jnp.where(
@@ -154,7 +150,7 @@ def update(updates, state, params=None, **extra_args):
def do_update(_):
return inner.update(updates, inner_state, params, **extra_args)
def reject_update(_):
-return (tree_map(jnp.zeros_like, updates), inner_state)
+return (jax.tree_util.tree_map(jnp.zeros_like, updates), inner_state)

updates, new_inner_state = lax.cond(
jnp.logical_or(isfinite, notfinite_count > max_consecutive_errors),
@@ -171,7 +167,7 @@ def reject_update(_):
return base.GradientTransformationExtraArgs(init=init, update=update)


-def _zeros_tree_like(inp_tree):
+def _zeros_tree_like(inp_tree: chex.ArrayTree) -> chex.ArrayTree:
return jax.tree_util.tree_map(jnp.zeros_like, inp_tree)


@@ -376,15 +372,18 @@ def update(self,
) -> Tuple[base.Updates, MultiStepsState]:
"""Accumulates gradients and proposes non-zero updates every `k_steps`."""
k_steps = self._every_k_schedule(state.gradient_step)
-acc_grads = jax.tree_util.tree_map(
-functools.partial(self._acc_update, n_acc=state.mini_step),
-updates, state.acc_grads)

should_skip_update, skip_state = self._should_skip_update_fn(
updates, state.gradient_step, params)
+if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()):
+raise ValueError(
+'The `should_skip_update_fn` function should return a boolean scalar '
+f'array, but it returned an array of dtype {should_skip_update.dtype}'
+f' and shape {should_skip_update.shape}'
+)

-def final_step(args):
-del args
+# Note: we do not enclose variables to allow JAX to re-use memory buffers.

+def _final_step(state, params, acc_grads):
final_updates, new_inner_state = self._opt.update(
acc_grads, state.inner_opt_state, params=params, **extra_args)
new_state = MultiStepsState(
@@ -395,8 +394,7 @@ def final_step(args):
skip_state=skip_state)
return final_updates, new_state

-def mid_step(args):
-del args
+def _mid_step(state, params, acc_grads):
updates_shape_dtype, _ = jax.eval_shape(
self._opt.update, acc_grads, state.inner_opt_state, params=params)
mid_updates = jax.tree_util.tree_map(
@@ -409,27 +407,29 @@ def mid_step(args):
skip_state=skip_state)
return mid_updates, new_state

-new_updates, new_state = jax.lax.cond(
-state.mini_step < k_steps - 1, (), mid_step, (), final_step)

-if (should_skip_update.dtype, should_skip_update.shape) != (jnp.bool_, ()):
-raise ValueError(
-'The `should_skip_update_fn` function should return a boolean scalar '
-f'array, but it returned an array of dtype {should_skip_update.dtype}'
-f' and shape {should_skip_update.shape}')
+def _do_update(updates, state, params):
+acc_grads = jax.tree_util.tree_map(
+lambda upd, acc: self._acc_update(upd, acc, n_acc=state.mini_step),
+updates, state.acc_grads)
+new_updates, new_state = jax.lax.cond(
+state.mini_step < k_steps - 1,
+_mid_step, _final_step, *(state, params, acc_grads))
+return new_updates, new_state

+def _skip_update(updates, state, params):
+del updates, params
+multi_state_when_skip = MultiStepsState(
+mini_step=state.mini_step,
+gradient_step=state.gradient_step,
+inner_opt_state=state.inner_opt_state,
+acc_grads=state.acc_grads,
+skip_state=skip_state,
+)
+zero_updates = _zeros_tree_like(state.acc_grads)
+return zero_updates, multi_state_when_skip

-multi_state_when_skip = MultiStepsState(
-mini_step=state.mini_step,
-gradient_step=state.gradient_step,
-inner_opt_state=state.inner_opt_state,
-acc_grads=state.acc_grads,
-skip_state=skip_state)
-zero_updates = jax.tree_util.tree_map(jnp.zeros_like, updates)
new_updates, new_state = jax.lax.cond(
-should_skip_update,
-(), lambda args: (zero_updates, multi_state_when_skip),
-(), lambda args: (new_updates, new_state))

+should_skip_update, _skip_update, _do_update, *(updates, state, params))
return new_updates, new_state

def has_updated(self, state: Union[MultiStepsState, chex.ArrayTree]) -> Array:
Expand Down Expand Up @@ -497,7 +497,9 @@ def masked(
inner = base.with_extra_args_support(inner)

def mask_pytree(pytree, mask_tree):
-return tree_map(lambda m, p: p if m else MaskedNode(), mask_tree, pytree)
+return jax.tree_util.tree_map(
+lambda m, p: p if m else MaskedNode(), mask_tree, pytree
+)

def init_fn(params):
# This is a workaround to make tree_map_params work with masking.
@@ -527,7 +529,7 @@ def update_fn(updates, state, params=None, **extra_args):
new_masked_updates, new_inner_state = inner.update(
masked_updates, state.inner_state, masked_params, **extra_args)

-new_updates = tree_map(
+new_updates = jax.tree_util.tree_map(
lambda m, new_u, old_u: new_u if m else old_u,
mask_tree, new_masked_updates, updates)
return new_updates, MaskedState(inner_state=new_inner_state)
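For context, a short, hypothetical usage sketch of the wrapper touched by this commit (optax.MultiSteps, optax.adam and optax.apply_updates are the public optax API; the parameters and loss below are invented): MultiSteps accumulates gradients over every_k_schedule calls and emits zero updates on the intermediate mini-steps, so a real accumulated update is applied only every fourth call here.

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((4,))}
opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
opt_state = opt.init(params)

def loss_fn(p, x):
  return jnp.sum((p['w'] * x) ** 2)

for i in range(8):
  x = jnp.full((4,), float(i + 1))
  grads = jax.grad(loss_fn)(params, x)
  # Zero updates on mini-steps; an accumulated Adam update every 4th call.
  updates, opt_state = opt.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)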
