Makes logic in optimizers use 1-based steps. #1118

Draft: wants to merge 13 commits into main.
18 changes: 7 additions & 11 deletions axlearn/common/factorized_rms.py
@@ -134,9 +134,12 @@ def update_fn(grads, state, params):
if params is None:
raise ValueError("param is None")

def _update(grad, v_row, v_col, v, param, step):
count_inc = optax.safe_int32_increment(state.count)

def _update(grad, v_row, v_col, v, param):
grad = grad.astype(jnp.float32)
decay_rate_t = decay_rate(step)
# `decay_rate` assumes that the first step is 1, so we pass `count_inc` to it.
decay_rate_t = decay_rate(count_inc)

# Scaled by factorized second moment statistics.
new_v_row = jnp.zeros((1,), dtype=jnp.float32)
@@ -169,18 +172,11 @@ def _update(grad, v_row, v_col, v, param, step):
return _UpdateResult(update, new_v_row, new_v_col, new_v)

# Transform grad and compute new per-parameter stats.
output = jax.tree.map(
lambda *args: _update(*args, state.count),
grads,
state.v_row,
state.v_col,
state.v,
params,
)
output = jax.tree.map(_update, grads, state.v_row, state.v_col, state.v, params)

# Unpack updates / stats and return.
updates = jax.tree.map(lambda o: o.update, output)
return updates, _to_state(optax.safe_int32_increment(state.count), output)
return updates, _to_state(count_inc, output)

@dataclasses.dataclass
class VxSpec:
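
For context on why the increment now happens before `decay_rate` is called: the Adafactor-style second-moment decay schedule is typically `1 - step**(-exponent)`, which is only well behaved for steps >= 1. A minimal sketch under that assumption (the schedule below is illustrative, not code from this PR):

```python
import jax.numpy as jnp
import optax

def adafactor_decay_rate(step, exponent=0.8):
    # Typical Adafactor second-moment decay: 1 - step**(-exponent).
    return 1.0 - jnp.power(jnp.asarray(step, jnp.float32), -exponent)

count = jnp.zeros([], jnp.int32)               # state.count before the first update
count_inc = optax.safe_int32_increment(count)  # 1 on the first update
adafactor_decay_rate(count_inc)                # 0.0: the first v is just grad**2
# adafactor_decay_rate(count) would evaluate 0.0 ** -0.8 == inf, i.e. a decay of
# -inf, which is why passing the 0-based count to such a schedule is a problem.
```
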
14 changes: 7 additions & 7 deletions axlearn/common/learner_test.py
@@ -186,7 +186,7 @@ def test_learner(self, ema_decay: Optional[float], method: str):
)
learning_rate_fn = schedule.as_schedule_fn(learning_rate)
weight_decay = 1e-4
step = 0
step = 1
sgd_cfg = config_for_function(sgd_optimizer).set(
learning_rate=learning_rate,
decouple_weight_decay=True,
@@ -315,9 +315,9 @@ def loss_fn(model_params, inputs):
self.assertNestedAllClose(
{
"learning_rate": learning_rate_fn(step),
"lr_schedule_step": 0,
"lr_schedule_step": step,
"gradient_norm": 1.0093285,
"schedule_step": 0,
"schedule_step": step,
"schedule_scale": -1.0 * learning_rate_fn(step),
},
summaries,
@@ -487,9 +487,9 @@ def loss_fn(x):
self.assertNestedAllClose(
{
"learning_rate": learning_rate_fn(step),
"lr_schedule_step": 0,
"lr_schedule_step": 1,
"gradient_norm": expected_grad_norm,
"schedule_step": 0,
"schedule_step": 1,
"schedule_scale": -1.0 * learning_rate_fn(step),
},
summaries["optimizer"],
@@ -649,13 +649,13 @@ def loss_fn(x):
self.assertNestedAllClose(
{
"optimizer/learning_rate": 1.0,
"optimizer/lr_schedule_step": 0,
"optimizer/lr_schedule_step": 1,
"optimizer/gradient_norm": jnp.sqrt(jnp.sum(2 * expected_grad**2)),
"param_rms/weight": jnp.sqrt((0 + 4 + 4 + 9) / 4),
"param_rms/moving_mean": 0.5,
"grad_rms/weight": jnp.sqrt(jnp.mean(expected_grad**2)),
"grad_rms/moving_mean": jnp.sqrt(jnp.mean(expected_grad**2)),
"optimizer/schedule_step": 0,
"optimizer/schedule_step": 1,
"optimizer/schedule_scale": -1.0,
},
output_collection.summaries,
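
The expectation changes in this test file all follow one pattern: the optimizer state count starts at 0, is incremented before the schedule is evaluated, and the reported step therefore becomes 1 on the first update. A trivial sketch of the new convention (illustrative only):

```python
import jax.numpy as jnp
import optax

count = jnp.zeros([], jnp.int32)          # fresh optimizer state
step = optax.safe_int32_increment(count)  # the value summaries now report as
assert int(step) == 1                     # lr_schedule_step / schedule_step
```
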
97 changes: 54 additions & 43 deletions axlearn/common/optimizers.py
@@ -207,8 +207,11 @@ def scale_by_schedule(
) -> PartitionedGradientTransformation:
"""Scales updates using a custom schedule for the step size.

Unlike optax.scale_by_schedule, this implementation uses 1-based steps, i.e., the first
step will be 1, to be consistent with the step count in the trainer, summaries, and checkpoints.

Args:
step_size_fn: A function that takes an update count as input and returns a scale factor
step_size_fn: A function that takes the current step as input and returns a scale factor
to multiply the updates by.
name: Name for this transformation (used to group logged summaries).
If None, will not group logged summaries under a name.
@@ -220,17 +223,25 @@ def scale_by_schedule(
schedule_fn = schedule.as_schedule_fn(step_size_fn)
summary_name_prefix = "" if name is None else f"{name}/"

def wrapped_schedule_fn(step):
scale_multiplier = schedule_fn(step)
def init_fn(params):
del params
return optax.ScaleByScheduleState(count=jnp.zeros([], jnp.int32))

def update_fn(updates, state, params=None):
del params
count_inc = optax.safe_int32_increment(state.count)
step_size = schedule_fn(count_inc)
context = current_context()
if context:
context.add_summary(summary_name_prefix + "schedule_step", step)
context.add_summary(summary_name_prefix + "schedule_scale", scale_multiplier)
return scale_multiplier
context.add_summary(summary_name_prefix + "schedule_step", count_inc)
context.add_summary(summary_name_prefix + "schedule_scale", step_size)
updates = jax.tree.map(lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates)
return updates, optax.ScaleByScheduleState(count=count_inc)

return with_partition_fn(
optax.scale_by_schedule(wrapped_schedule_fn),
lambda _: optax.ScaleByScheduleState(
return PartitionedGradientTransformation(
init=init_fn,
update=update_fn,
partition=lambda _: optax.ScaleByScheduleState(
count=OptStateSpec(shape=[], dtype=jnp.int32, mesh_axes=PartitionSpec())
),
)
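
A hedged usage sketch of what the 1-based convention changes on the very first update (the warmup schedule below is illustrative, not part of the PR):

```python
import optax

warmup = optax.linear_schedule(init_value=0.0, end_value=1.0, transition_steps=10)

# optax.scale_by_schedule evaluates the schedule at the pre-increment count, so
# the first update would be scaled by warmup(0) == 0.0, i.e. dropped entirely.
# With 1-based steps the first update is scaled by warmup(1) == 0.1 instead,
# matching step 1 as seen by the trainer, summaries, and checkpoints.
first_scale_zero_based = warmup(0)  # 0.0
first_scale_one_based = warmup(1)   # 0.1
```
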
@@ -562,12 +573,15 @@ def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: Nest
if params is None:
raise ValueError(optax.NO_PARAMS_MSG) # pylint: disable=no-member

if not learning_rate_exponent:
if learning_rate_exponent is None:
lr_scale = 1.0
updated_state = state
else:
learning_rate_fn = schedule.as_schedule_fn(learning_rate)
lr = learning_rate_fn(state.count)
count_inc = optax.safe_int32_increment(state.count)
lr = learning_rate_fn(count_inc)
lr_scale = lr**learning_rate_exponent
updated_state = AddDecayedWeightsState(count_inc)

param_scales = _weight_decay_scales(params, per_param_scale=per_param_scale)
f = lambda g, p, s: g + weight_decay * lr_scale * p.value * s
@@ -578,10 +592,6 @@ def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: Nest
param_scales,
is_leaf=lambda x: x is None,
)
if learning_rate_exponent is None:
updated_state = state
else:
updated_state = AddDecayedWeightsState(optax.safe_int32_increment(state.count))
return updates, updated_state

def partition_fn(param_specs):
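
Two things change in this hunk: the guard now distinguishes `None` from an explicit exponent of 0 (an exponent of 0 now advances the schedule state instead of being treated as unset), and the schedule is evaluated at the incremented count. A hedged sketch of the resulting weight decay term, assuming a cosine learning rate schedule and a scalar per-parameter scale:

```python
import optax

schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=1000)

def decayed_update(grad, param, count, weight_decay=1e-4, exponent=1.0, scale=1.0):
    # Mirrors the update above: the schedule is evaluated at the 1-based step, so
    # the decay matches the learning rate actually applied on that step.
    count_inc = optax.safe_int32_increment(count)
    lr_scale = schedule(count_inc) ** exponent
    return grad + weight_decay * lr_scale * param * scale
```
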
@@ -926,7 +936,8 @@ class _UpdateResult:

def update_fn(updates, state, params=None):
del params
decay_t = decay_fn(state.count)
count_inc = optax.safe_int32_increment(state.count)
decay_t = decay_fn(count_inc)

def _to_qint_tensor_ema(value: Tensor) -> _TensorEma:
# Map value to integer with a symmetric quantization scheme.
@@ -952,17 +963,16 @@ def _to_float(value: Tensor, qstep_size: Tensor) -> Tensor:
return value.astype(qstep_size.dtype) * jnp.expand_dims(qstep_size, axis=0)

# pylint: disable-next=redefined-outer-name
def _update(value: Tensor, ema: Tensor, qstep_size: Tensor, count: Tensor) -> _UpdateResult:
def _update(value: Tensor, ema: Tensor, qstep_size: Tensor) -> _UpdateResult:
update = new_ema = (1 - decay_t) * value + decay_t * _to_float(ema, qstep_size)
if debias:
bias_correction = 1 - decay_t**count
bias_correction = 1 - decay_t**count_inc
update = new_ema / bias_correction.astype(new_ema.dtype)
return _UpdateResult(update=update, tensor_ema=_to_tensor_ema(new_ema))

# Transform updates and compute new per-tensor EMA.
count_inc = optax.safe_int32_increment(state.count)
update_results = jax.tree.map(
lambda update, ema, scale: _update(update, ema=ema, qstep_size=scale, count=count_inc),
lambda update, ema, scale: _update(update, ema=ema, qstep_size=scale),
updates,
state.ema,
state.scale,
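
The debias term is another place where the 1-based count matters: with a zero-initialized EMA and a 0-based count, the correction `1 - decay**0` would be 0 and the debiased update would divide by zero. A small standalone check (standard EMA debiasing, not PR code):

```python
import jax.numpy as jnp
import optax

decay = 0.999
value = jnp.array(2.0)

count = jnp.zeros([], jnp.int32)
count_inc = optax.safe_int32_increment(count)

ema = (1 - decay) * value + decay * 0.0   # first update of a zero-initialized EMA
ema / (1 - decay**count_inc)              # ~2.0: the unbiased estimate of `value`
# (1 - decay**count) == 0 here, so the pre-increment count would blow up.
```
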
@@ -1157,9 +1167,9 @@ def __call__(self, *, count: Tensor, mean: Tensor, stddev: Tensor) -> dict[str,
"""Returns the drop_norm thresholds given the gradient norm stats.

Args:
count: the number of training steps.
count: the number of previous updates to mean/stddev.
mean: the running average of gradient norms.
stdev: the running average of gradient norm variance.
stddev: the running average of gradient norm variance.

Returns:
A dict where keys represent threshold names and values are scalar tensors representing
@@ -1191,7 +1201,7 @@ def drop_norm_by_grad_norm_stddev(
"""Return drop norm thresholds based on grad norm stddev."""

def fn(count: Tensor, mean: Tensor, stddev: Tensor) -> dict[str, Tensor]:
# We do not drop norm for the first `min_count` data batches,
# We do not drop norm until we have collected stats for at least `min_count` steps,
# otherwise the threshold is `mean + stddev * k` for multiplier `k`.
thresholds = {}
for v in multipliers:
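
The loop body is cut off by the hunk, so the following is only a hedged sketch of the rule described in the updated comment; the key names, the `min_count` guard, and the use of `inf` as the "disabled" threshold are assumptions, not code from this PR:

```python
import jax.numpy as jnp

def thresholds_sketch(count, mean, stddev, multipliers=(4.0, 8.0), min_count=100):
    # No drop-norm until the stats have been updated at least `min_count` times;
    # afterwards a step is dropped when its grad norm exceeds mean + k * stddev.
    return {
        f"mean_plus_{k}x_stddev": jnp.where(count >= min_count, mean + k * stddev, jnp.inf)
        for k in multipliers
    }
```
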
@@ -1302,7 +1312,6 @@ def init_fn(params):

def update_fn(updates, state, params=None):
inner_state = state.inner_state
count = state.count
grad_norm_ema = state.grad_norm_ema
grad_norm_square_ema = state.grad_norm_square_ema
drop_stats = state.drop_stats
@@ -1319,7 +1328,7 @@ def _moment(
# bias-corrected decay
# Sec 7.1 https://arxiv.org/pdf/1804.04235.pdf
decay = grad_norm_ema_decay
decay *= (1 - decay ** (count - 1)) / (1 - decay**count)
decay *= (1 - decay**count) / (1 - decay ** (optax.safe_int32_increment(count)))
new_norm_ema = decay * norm_ema + (1 - decay) * val
new_square_ema = decay * norm_square_ema + (1 - decay) * (val**2)
return new_norm_ema, new_square_ema
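
With `count` now holding the number of previous updates to the stats, the corrected decay on the first update is `decay * (1 - decay**0) / (1 - decay**1) == 0`, so the EMA is seeded directly with the first observed norm, matching the increasing-decay scheme of Adafactor Sec 7.1. A quick standalone check (illustrative only):

```python
import jax.numpy as jnp
import optax

beta = 0.99
count = jnp.zeros([], jnp.int32)  # no previous updates to the stats yet
decay = beta * (1 - beta**count) / (1 - beta**optax.safe_int32_increment(count))
# decay == 0.0, so new_norm_ema == val on the very first update.
```
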
@@ -1361,7 +1370,9 @@ def _is_valid_step(
drop_norm,
norm_ema=grad_norm_ema,
norm_square_ema=grad_norm_square_ema,
count=count,
# Note that `count` for `drop_norm` represents the number of updates to
# mean/stddev, so we pass `state.count` rather than `count_inc`.
count=state.count,
drop_stats=drop_stats,
)
is_valid_step = jnp.logical_and(is_finite, is_valid_step)
@@ -1376,20 +1387,20 @@ def _is_valid_step(
optax.safe_int32_increment(state.nonvalid_count),
)
if use_adaptive_drop_norm:
inc_count = jnp.where(
count_inc = jnp.where(
is_valid_step,
optax.safe_int32_increment(count),
count,
optax.safe_int32_increment(state.count),
state.count,
)
new_norm_ema, new_norm_square_ema = _moment(
g_norm, grad_norm_ema, grad_norm_square_ema, inc_count
g_norm, grad_norm_ema, grad_norm_square_ema, state.count
)
new_norm_ema = jnp.where(is_valid_step, new_norm_ema, grad_norm_ema)
new_norm_square_ema = jnp.where(
is_valid_step, new_norm_square_ema, grad_norm_square_ema
)
else:
inc_count = None
count_inc = None
new_norm_ema = None
new_norm_square_ema = None
context = current_context()
@@ -1402,8 +1413,8 @@ def _is_valid_step(
context.add_summary(
"gradient_norm_std_ema", _stddev(new_norm_ema, new_norm_square_ema)
)
if inc_count is not None:
context.add_summary("count", inc_count)
if count_inc is not None:
context.add_summary("count", count_inc)
if new_drop_stats is not None:
for key, val in new_drop_stats.items():
context.add_summary(f"count_exceeds_{key}", val)
@@ -1430,7 +1441,7 @@ def _is_valid_step(
)

return final_updates, SkipClipState(
count=inc_count,
count=count_inc,
nonvalid_count=nonvalid_count,
grad_norm_ema=new_norm_ema,
grad_norm_square_ema=new_norm_square_ema,
@@ -1588,10 +1599,10 @@ def update_fn(updates, state, params):
if params is None:
raise ValueError("params are required for param_ema.")

decay_t = decay_fn(state.count)
count_inc = optax.safe_int32_increment(state.count)
decay_t = decay_fn(count_inc)

# Transform updates and compute new per-tensor EMA.
count_inc = optax.safe_int32_increment(state.count)
new_ema = jax.tree.map(
lambda param, ema: (1 - decay_t) * param.value + decay_t * ema,
params,
@@ -1840,7 +1851,7 @@ def _init(param: OptParam):

def update_fn(grads: NestedTensor, state: _AdastarState, params: NestedOptParam):
"""Applies (stage 1) gradient transformation to compute raw_updates."""
incremented_count = optax.safe_int32_increment(state.count)
count_inc = optax.safe_int32_increment(state.count)

if params is None:
raise ValueError("param is None")
@@ -1852,7 +1863,7 @@ def _moment(
return x, None
value = acc = decay * acc + (1 - decay) * x
if debias:
value = optax.bias_correction(acc, decay=decay, count=incremented_count)
value = optax.bias_correction(acc, decay=decay, count=count_inc)
return value, acc

def _split_update_results(
Expand Down Expand Up @@ -1967,7 +1978,7 @@ def _smoothed_updates(
),
summary_suffix="corr_param_smoothed_updates",
)
return smoothed_updates, _AdastarState(count=incremented_count, pps=pps_tree)
return smoothed_updates, _AdastarState(count=count_inc, pps=pps_tree)

def partition_fn(param_specs):
def _partition(param_spec: ParameterSpec):
Expand All @@ -1988,15 +1999,15 @@ def _partition(param_spec: ParameterSpec):
)

def update2_fn(updates, state: Tensor, params: NestedOptParam):
step = state
step_inc = optax.safe_int32_increment(state)

def _update2(u: Tensor, param: OptParam):
lr_scaled_updates = learning_rate * u
updates_with_wd = lr_scaled_updates + weight_decay * param.value
schedule_scale = update_schedule(step)
schedule_scale = update_schedule(step_inc)
context = current_context()
if context:
context.add_summary("schedule_step", step)
context.add_summary("schedule_step", step_inc)
context.add_summary("schedule_scale", schedule_scale)
context.add_summary("learning_rate", learning_rate * schedule_scale)
context.add_summary("weight_decay_rate", weight_decay * schedule_scale)
@@ -2008,7 +2019,7 @@ def _update2(u: Tensor, param: OptParam):
params,
is_leaf=lambda x: x is None,
)
return updates2, optax.safe_int32_increment(step)
return updates2, step_inc

# Stage 1.
tx = {