Skip to content

Conversation

@hamishivi
Copy link
Collaborator

@hamishivi hamishivi commented Nov 10, 2025

Previously, grpo_fast only averaged loss locally, not globally across ranks, and not taking into account gradient accumulation. This meant that small samples had an outsized effect. See https://huggingface.co/blog/gradient_accumulation for more details. This slipped my notice... we need to check the DPO and SFT code too probably.

Anyway, the fix is: introduce a new value for masked_mean_denominator: tokens.

When set, this divides the loss by the total number of tokens across all ranks for each minibatch, which correctly scales the loss: now we normalize by the total number of tokens in the minibatch across all ranks. So while before we had:

loss = avg(rank0(loss_accum0 / rank0_accum0_toks + loss_accum1 / rank0_accum1_toks), rank1(loss_accum0 / rank1_accum0_toks + loss_accum1 / rank1_accum1_toks))

Note here if e.g. rank0_accum0_toks >> rank1_accum1_toks, rank1_accum1_toks would overwhelm the update.

Now we have:

loss = avg(rank0((loss_accum0 + loss_accum1) / all_toks), rank1((loss_accum0 + loss_accum1) / all_toks))
= all_loss / all_toks

Gave it a test locally and it seems fine. Default right now sticks to the older (technically incorrect) behaviour. Setting masked_mean_denominator to an integer value (e.g. following Dr GRPO) also technically fixes since we replace the denominators (previously token counts) with a constant.


Note

Introduces masked_mean_denominator="token" to normalize losses by total tokens across all ranks per accumulation group, and adjusts accumulation to prevent double-scaling.

  • Training/Loss normalization:
    • Add support for Args.masked_mean_denominator to accept special value "token" (with validation and axis constraints).
    • Pre-compute per-accumulation-group total token counts (honoring tool masks) and gather globally via dist.all_reduce.
    • Use these global token totals in masked_mean for loss; skip dividing by accumulation_steps when in "token" mode.
  • Metrics/statistics:
    • Stats aggregation now uses numeric denominator only; ignores "token" for stats to avoid unintended scaling.
  • Refactors:
    • Minor loss computation cleanup (loss_values), and guardrails for masked_mean_denominator argument.

Written by Cursor Bugbot for commit b37079c. This will update automatically on new commits. Configure here.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @hamishivi, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request resolves an inaccuracy in grpo_fast's loss accumulation, particularly affecting distributed training and gradient accumulation. The core change introduces a 'token'-based normalization strategy for the masked_mean_denominator. This ensures that the loss is correctly scaled by the total number of tokens across all participating ranks and accumulation steps, leading to more stable and accurate training dynamics by preventing smaller batches from disproportionately influencing updates.

Highlights

  • Global Loss Accumulation: Introduced a new mechanism to correctly accumulate loss globally across all distributed ranks and gradient accumulation steps, addressing a previous issue where small samples had an outsized effect on the loss.
  • New masked_mean_denominator Option: Added a special string value 'token' to the masked_mean_denominator argument. When set, this normalizes the loss by the total number of tokens in a minibatch across all ranks, ensuring proper scaling.
  • Distributed Token Counting: Implemented pre-computation and all_reduce operations to gather the total token count across all distributed ranks for each accumulation group, which is then used for global loss normalization.
  • Loss Normalization Adjustment: Modified the loss calculation to conditionally skip division by accumulation_steps when 'token' normalization is active, as the global token count already accounts for the entire accumulation group.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request correctly adjusts the loss accumulation for grpo_fast by introducing a new normalization method based on the total number of tokens across all ranks. This prevents samples with fewer tokens from having an outsized effect on the gradient updates. The implementation looks solid. I've added a couple of suggestions to improve code readability and maintainability by reducing code duplication and simplifying conditional logic.

Comment on lines +1137 to +1140
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
mb_tool_mask = collated_tool_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There's a small inefficiency here where mb_response_masks[:, 1:].bool() is computed twice if the if condition is met. You can use an in-place &= operation to avoid this.

Additionally, this logic for calculating mb_response_masks_bool is duplicated later in the main training loop (lines 1160-1163). To improve maintainability and avoid potential bugs from inconsistent changes, consider refactoring this logic into a helper method.

Suggested change
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
mb_tool_mask = collated_tool_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
mb_tool_mask = collated_tool_masks[i]
mb_response_masks_bool &= mb_tool_mask[:, 1:].bool()

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea of a helper method is good tbh

Comment on lines 1253 to 1321
loss = masked_mean(
pg_loss_max + (args.beta * kl),
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
)
loss = loss / accumulation_steps
loss_values,
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens,
)
# When using global normalization (masked_mean_denominator == "token"), total_batch_tokens is the sum
# of tokens across all minibatches in the accumulation group. Since we normalize by this total,
# we should NOT divide by accumulation_steps (the normalization already accounts for all minibatches).
# For other normalization modes, we divide by accumulation_steps to properly scale gradients.
if args.masked_mean_denominator != "token":
loss = loss / accumulation_steps
self.model.backward(loss)
if (local_step + 1) % accumulation_steps == 0:
self.model.step()
local_step += 1
with torch.no_grad():
# Convert "token" to total_batch_tokens for statistics computation
stats_denominator = (
args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition args.masked_mean_denominator != "token" is checked multiple times in this block. To improve readability and avoid redundant checks, you could store this boolean result in a variable at the beginning of this section, for example:

is_global_token_norm = args.masked_mean_denominator == "token"

Then you can use is_global_token_norm in the subsequent logic (e.g., for calculating loss and stats_denominator), which would make the code cleaner and easier to follow.

assert self.masked_mean_denominator > 0, (
f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
)
if isinstance(self.masked_mean_denominator, str):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change this block of code to be a function call? get_denominator, which combines all of this logic and has a few simple tests? (Using parameterized)

Comment on lines +1137 to +1140
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
mb_tool_mask = collated_tool_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the idea of a helper method is good tbh

"""the axis to compute the mean of the masked values"""
masked_mean_denominator: float | None = None
"""Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum"""
masked_mean_denominator: float | str | None = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you help me understand why we need this? Can we get rid of this and just do the right thing?

What if we either:

  1. Get rid of this and always do "token"
  2. have it default to "token" if no float is set?

accumulation_group_tokens[group_start] = local_group_tokens_tensor.item()
else:
accumulation_group_tokens[group_start] = local_group_tokens

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you makek this a function?

accumulation_group_tokens = maybe_calculate_group_tokens(...)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants