Skip to content

[skyrl-train] Fix loss reduction by moving normalization to the advantage computation#925

Open
justinvyu wants to merge 32 commits intoNovaSky-AI:mainfrom
justinvyu:fix_loss_reduction2
Open

[skyrl-train] Fix loss reduction by moving normalization to the advantage computation#925
justinvyu wants to merge 32 commits intoNovaSky-AI:mainfrom
justinvyu:fix_loss_reduction2

Conversation

@justinvyu
Copy link
Contributor

@justinvyu justinvyu commented Jan 23, 2026

Summary

The previous implementation for ppo policy loss reduction had a "mean of means" bias — when computing token-mean loss across micro-batches and workers with varying token counts, the naive averaging gave incorrect results where:

  • microbatches with fewer tokens are weighted more heavily (since we take a mean across microbatches within a minibatch)
    • Micro-batch 1: 100 tokens, average loss = 0.5, micro-batch 2: 900 tokens, average loss = 0.3
    • -> Naive mean: (0.5 + 0.3) / 2 = 0.4, Correct token-mean: (100×0.5 + 900×0.3) / 1000 = 0.32
  • worker minibatches with fewer tokens were weighted more heavily (since DDP all-reduce takes a mean across minibatches)
    • Same example as above, but the average is across workers instead.

After this PR, ppo_policy_loss used within forward_backward now just sums the per-token loss for all sequences and relies on the advantages passed in by the user to handle the loss normalization.

This aligns with Tinker semantics:

Notice that for all objectives we sum the token-level losses over the sequence length unlike some other loss implementations. If you would like to explore different aggregation schemes, you can include that in the advantage tensor computation.

Example for loss_reduction="token_mean":

  • Move the 1/num_minibatch_tokens normalization into the advantage: loss = sum( -advantage_i * ratio_i for i in range(num_minibatch_tokens) ) / num_minibatch_tokens
  • -> sum( -(advantage_i / num_minibatch_tokens) * ratio_i for i in range(num_minibatch_tokens) )

DDP all-reduce

DDP/FSDP defaults to a mean all-reduce for gradients across workers. This PR counteracts this by multiplying by the DP world size.

Additional details

This was the first attempt: #909

This method was to track total tokens and then do one big normalization at the optim_step in order to get an average per-token loss. But, we decided to align with Tinker's way of just summing up the loss at the end, and pushing any loss normalization to the user's advantage calculation.

The benefit is that users have full control of customizing their loss reduction strategy, rather than having it happen in our opaque forward_backward, optim_step implementation which would require some configuration argument that diverges from tinker's API. For example, we would need to add a config somewhere to determine how to average/sum the loss:

client.forward_backward(...)
client.optim_step(..., loss_reduction="token_mean")  # no longer tinker compatible

Follow-up work

The ppo_critic_loss has the same problem but is not as important as the policy loss.


Open with Devin

Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Signed-off-by: Justin Yu <justinvyu@anyscale.com>
Comment on lines 787 to 789
for param in self.model.parameters():
if param.grad is not None:
param.grad.mul_(self.strategy.world_size)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could do this at the advantage computation level, but i thought it was a bit weird to have ddp all-reduce implementation details there so i separated it to be here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah i agree that this is the right separation

@erictang000 erictang000 marked this pull request as ready for review January 31, 2026 17:05
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request effectively addresses the 'mean of means' bias in PPO policy loss reduction by moving the normalization logic from the loss function to the advantage computation. However, a potential division-by-zero vulnerability was identified in the new normalize_minibatch_advantages function in trainer.py. This could lead to numerical instability (NaNs) and training failure if a mini-batch contains only masked-out sequences; a fix using .clamp(min=1.0) is recommended. Additionally, I have one suggestion to improve the robustness of the configuration validation.

Comment on lines 257 to 264
# assert cfg.trainer.algorithm.loss_reduction in (
# "token_mean",
# "sequence_mean",
# "seq_mean_token_sum_norm",
# ), (
# f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
# f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
# )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This assertion for loss_reduction has been commented out. While the normalization logic has moved to trainer.py, this validation is still crucial. If an invalid loss_reduction value is provided in the configuration, normalize_minibatch_advantages will silently fail to normalize the advantages, as it lacks an else block for unknown values. This would result in an un-normalized sum for the loss, which could be very large and lead to training instability. It's safer to fail fast with an explicit error.

I recommend re-enabling this assertion to ensure only valid loss_reduction options are accepted.

Suggested change
# assert cfg.trainer.algorithm.loss_reduction in (
# "token_mean",
# "sequence_mean",
# "seq_mean_token_sum_norm",
# ), (
# f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
# f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
# )
assert cfg.trainer.algorithm.loss_reduction in (
"token_mean",
"sequence_mean",
"seq_mean_token_sum_norm",
), (
f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
)

Comment on lines 1034 to 1041
# Option 1: token mean
if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
data["advantages"] = advantages / loss_mask.sum()

# Option 2: sequence mean
elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
batch_size = len(data)
data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The normalize_minibatch_advantages function performs division by loss_mask.sum() (line 1036) and loss_mask.sum(dim=-1, keepdim=True) (line 1041) without verifying if the divisor is zero. In Reinforcement Learning training, if a mini-batch consists entirely of sequences that are masked out (e.g., due to filtering or empty responses), the sum of the loss_mask will be zero. Dividing by zero will result in inf or nan values in the advantages tensor, which will propagate to the gradients and corrupt the model weights during the optimizer step. This effectively causes a Denial of Service (DoS) on the training process.

Recommendation: Use .clamp(min=1.0) on the divisor to ensure it is never zero, consistent with the implementation of masked_mean in skyrl_train/utils/ppo_utils.py.

Suggested change
# Option 1: token mean
if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
data["advantages"] = advantages / loss_mask.sum()
# Option 2: sequence mean
elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
batch_size = len(data)
data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True))
# Option 1: token mean
if self.cfg.trainer.algorithm.loss_reduction == "token_mean":
data["advantages"] = advantages / loss_mask.sum().clamp(min=1.0)
# Option 2: sequence mean
elif self.cfg.trainer.algorithm.loss_reduction == "sequence_mean":
batch_size = len(data)
data["advantages"] = advantages / (batch_size * loss_mask.sum(dim=-1, keepdim=True).clamp(min=1.0))

Copy link
Contributor

@devin-ai-integration devin-ai-integration bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Devin Review found 2 potential issues.

View 8 additional findings in Devin Review.

Open in Devin Review

Comment on lines +1098 to +1105
# iterate over mini-batches to do mini batch level normalization
for local_step in range(num_mini_batches):
start_idx = local_step * mini_batch_size
end_idx = (local_step + 1) * mini_batch_size
mini_batch = data[start_idx:end_idx]
mini_batch = self.normalize_minibatch_advantages(mini_batch)
# Copy normalized advantages back to original batch
data["advantages"][start_idx:end_idx] = mini_batch["advantages"]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Double normalization of advantages when critic model is enabled

When a critic model is configured (has_critic=True), train_critic_and_policy calls _execute_training_step first for the critic and then for the policy. Both calls execute normalize_minibatch_advantages, which writes the normalized advantages back to the shared data object in-place via data["advantages"][start_idx:end_idx] = mini_batch["advantages"] at skyrl-train/skyrl_train/trainer.py:1105. The critic's training doesn't use advantages, but its call to _execute_training_step still modifies data["advantages"]. When the policy's _execute_training_step runs next, it normalizes the already-normalized advantages a second time.

Root Cause and Impact

The flow in train_critic_and_policy (skyrl-train/skyrl_train/trainer.py:1130-1144):

  1. _execute_training_step("critic", data) → normalizes data["advantages"] in-place
  2. _execute_training_step("policy", data) → normalizes the already-normalized advantages again

For example, with loss_reduction="token_mean", advantages get divided by loss_mask.sum() twice, making them far too small. With loss_reduction="seq_mean_token_sum_norm", they get divided by (batch_size * max_seq_len) twice. Additionally, if critic_mini_batch_size != policy_mini_batch_size, the mini-batch boundaries differ, so the normalization uses different batch compositions each time, compounding the error.

Impact: Policy gradients would be drastically wrong (too small by a factor proportional to the batch token count), severely degrading training quality for any configuration with a critic model.

Prompt for agents
The normalize_minibatch_advantages loop in _execute_training_step runs for both critic and policy models, but should only run for the policy model. The simplest fix is to guard the normalization loop with a check like `if model == "policy":` so that the critic path does not modify data["advantages"]. Specifically, in skyrl_train/trainer.py around lines 1098-1105, wrap the normalization loop:

    if model == "policy":
        for local_step in range(num_mini_batches):
            start_idx = local_step * mini_batch_size
            end_idx = (local_step + 1) * mini_batch_size
            mini_batch = data[start_idx:end_idx]
            mini_batch = self.normalize_minibatch_advantages(mini_batch)
            data["advantages"][start_idx:end_idx] = mini_batch["advantages"]
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +20 to +23
# # torch profiler config
# ENABLE_TORCH_PROFILER=false
# RANKS_TO_PROFILE="[0]"
# SAVE_PATH="$HOME/megatron_prof/tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_${MODEL_NAME}"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Example script references undefined shell variables after commenting out their definitions

In the run_megatron.sh example script, the torch profiler config variables (ENABLE_TORCH_PROFILER, RANKS_TO_PROFILE, SAVE_PATH) are commented out on lines 20-23, but their references on lines 42-44 remain in the uv run command. In bash (without set -u), these expand to empty strings, causing the training script to receive empty/invalid values for profiler config options.

Detailed Explanation

Lines 20-23 comment out the variable definitions:

# ENABLE_TORCH_PROFILER=false
# RANKS_TO_PROFILE="[0]"
# SAVE_PATH="..."

But lines 42-44 still reference them:

trainer.policy.megatron_config.torch_profiler_config.enable=$ENABLE_TORCH_PROFILER \
trainer.policy.megatron_config.torch_profiler_config.ranks=$RANKS_TO_PROFILE \
trainer.policy.megatron_config.torch_profiler_config.save_path=$SAVE_PATH \

Impact: The script would pass empty values for these config keys, which could cause config parsing errors or unexpected behavior at runtime.

Prompt for agents
In skyrl-train/examples/megatron/run_megatron.sh, either uncomment the variable definitions on lines 20-23, or also comment out/remove the references on lines 42-44 that use $ENABLE_TORCH_PROFILER, $RANKS_TO_PROFILE, and $SAVE_PATH. If the profiler is not needed, remove both the definitions and the references from the uv run command.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants

Comments