Correct loss accumulation for grpo_fast #1161
base: main
Conversation
Code Review
This pull request correctly adjusts the loss accumulation for grpo_fast by introducing a new normalization method based on the total number of tokens across all ranks. This prevents samples with fewer tokens from having an outsized effect on the gradient updates. The implementation looks solid. I've added a couple of suggestions to improve code readability and maintainability by reducing code duplication and simplifying conditional logic.
```python
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
    mb_tool_mask = collated_tool_masks[i]
    mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
```
There's a small inefficiency here where mb_response_masks[:, 1:].bool() is computed twice if the if condition is met. You can use an in-place &= operation to avoid this.
Additionally, this logic for calculating mb_response_masks_bool is duplicated later in the main training loop (lines 1160-1163). To improve maintainability and avoid potential bugs from inconsistent changes, consider refactoring this logic into a helper method.
Suggested change:
```diff
  mb_response_masks_bool = mb_response_masks[:, 1:].bool()
  if args.mask_tool_use and args.tool_use:
      mb_tool_mask = collated_tool_masks[i]
-     mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
+     mb_response_masks_bool &= mb_tool_mask[:, 1:].bool()
```
I think the idea of a helper method is good tbh
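One possible shape for that helper (a sketch only; the name and signature are placeholders, reusing the locals from the snippet above):

```python
import torch


def build_response_mask_bool(
    mb_response_masks: torch.Tensor,
    mb_tool_mask: torch.Tensor | None,
    mask_tool_use: bool,
) -> torch.Tensor:
    """Boolean response mask shifted to align with next-token targets,
    optionally AND-ed with the tool-use mask."""
    response_mask_bool = mb_response_masks[:, 1:].bool()
    if mask_tool_use and mb_tool_mask is not None:
        response_mask_bool &= mb_tool_mask[:, 1:].bool()
    return response_mask_bool
```

Both call sites could then become a single line, e.g. `mb_response_masks_bool = build_response_mask_bool(mb_response_masks, collated_tool_masks[i] if args.tool_use else None, args.mask_tool_use)`.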
```diff
  loss = masked_mean(
-     pg_loss_max + (args.beta * kl),
+     loss_values,
      mb_response_masks_bool,
      args.masked_mean_axis,
-     args.masked_mean_denominator,
+     args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens,
  )
- loss = loss / accumulation_steps
+ # When using global normalization (masked_mean_denominator == "token"), total_batch_tokens is the sum
+ # of tokens across all minibatches in the accumulation group. Since we normalize by this total,
+ # we should NOT divide by accumulation_steps (the normalization already accounts for all minibatches).
+ # For other normalization modes, we divide by accumulation_steps to properly scale gradients.
+ if args.masked_mean_denominator != "token":
+     loss = loss / accumulation_steps
  self.model.backward(loss)
  if (local_step + 1) % accumulation_steps == 0:
      self.model.step()
  local_step += 1
  with torch.no_grad():
+     # Convert "token" to total_batch_tokens for statistics computation
+     stats_denominator = (
+         args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens
+     )
```
The condition args.masked_mean_denominator != "token" is checked multiple times in this block. To improve readability and avoid redundant checks, you could store this boolean result in a variable at the beginning of this section, for example:
```python
is_global_token_norm = args.masked_mean_denominator == "token"
```

Then you can use `is_global_token_norm` in the subsequent logic (e.g., for calculating `loss` and `stats_denominator`), which would make the code cleaner and easier to follow.
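For instance, reusing the names from the diff above (just a sketch of how the block might then read):

```python
is_global_token_norm = args.masked_mean_denominator == "token"
denominator = total_batch_tokens if is_global_token_norm else args.masked_mean_denominator

loss = masked_mean(loss_values, mb_response_masks_bool, args.masked_mean_axis, denominator)
if not is_global_token_norm:
    # Global token normalization already covers every minibatch in the
    # accumulation group, so only scale down by accumulation_steps otherwise.
    loss = loss / accumulation_steps
self.model.backward(loss)
...
with torch.no_grad():
    stats_denominator = denominator
```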
```python
assert self.masked_mean_denominator > 0, (
    f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
)
if isinstance(self.masked_mean_denominator, str):
```
Can you change this block of code into a function call, `get_denominator`, which combines all of this logic and has a few simple tests (using `parameterized`)?
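A rough sketch of what that could look like (the exact validation rules here are guesses, not the final behaviour):

```python
import unittest

from parameterized import parameterized


def get_denominator(masked_mean_denominator, masked_mean_axis):
    """Validate masked_mean_denominator and return it unchanged if valid."""
    if isinstance(masked_mean_denominator, str):
        if masked_mean_denominator != "token":
            raise ValueError(f"Unknown masked_mean_denominator: {masked_mean_denominator!r}")
        if masked_mean_axis is not None:
            raise ValueError('masked_mean_denominator="token" requires masked_mean_axis=None')
    elif masked_mean_denominator is not None and masked_mean_denominator <= 0:
        raise ValueError(f"masked_mean_denominator (={masked_mean_denominator}) must be greater than 0!")
    return masked_mean_denominator


class TestGetDenominator(unittest.TestCase):
    @parameterized.expand([("token", None), (4.0, None), (None, 1)])
    def test_valid(self, denominator, axis):
        self.assertEqual(get_denominator(denominator, axis), denominator)

    @parameterized.expand([("tokens", None), (-1.0, None), ("token", 1)])
    def test_invalid(self, denominator, axis):
        with self.assertRaises(ValueError):
            get_denominator(denominator, axis)
```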
```python
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
    mb_tool_mask = collated_tool_masks[i]
    mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
```
I think the idea of a helper method is good tbh
| """the axis to compute the mean of the masked values""" | ||
| masked_mean_denominator: float | None = None | ||
| """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum""" | ||
| masked_mean_denominator: float | str | None = None |
Can you help me understand why we need this? Can we get rid of this and just do the right thing?
What if we either:
- get rid of this and always do "token", or
- have it default to "token" if no float is set? (a quick sketch of the second option is below)
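For the second option, a minimal sketch of the field (the same fragment of the Args dataclass as above; the docstring wording is just a suggestion):

```python
masked_mean_denominator: float | str = "token"
"""Denominator for masked_mean: a positive float for a constant denominator (e.g. Dr GRPO style),
or "token" (the default) to normalize by the total token count across ranks per accumulation group."""
```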
```python
    accumulation_group_tokens[group_start] = local_group_tokens_tensor.item()
else:
    accumulation_group_tokens[group_start] = local_group_tokens
```
Can you make this a function?
`accumulation_group_tokens = maybe_calculate_group_tokens(...)`
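Perhaps something like this (a sketch only; the argument names and where the masks come from are placeholders, not the actual signature):

```python
import torch
import torch.distributed as dist


def maybe_calculate_group_tokens(
    response_mask_bools: list[torch.Tensor],
    accumulation_steps: int,
    use_global_token_norm: bool,
) -> dict[int, int]:
    """Per accumulation group, sum response tokens over minibatches and (if
    distributed) over ranks, keyed by the group's starting minibatch index."""
    accumulation_group_tokens: dict[int, int] = {}
    if not use_global_token_norm:
        return accumulation_group_tokens
    for group_start in range(0, len(response_mask_bools), accumulation_steps):
        group = response_mask_bools[group_start : group_start + accumulation_steps]
        local_group_tokens = sum(int(mask.sum().item()) for mask in group)
        if dist.is_available() and dist.is_initialized():
            local_group_tokens_tensor = torch.tensor(
                local_group_tokens, dtype=torch.long, device=group[0].device
            )
            # Every rank must use the same denominator, so sum the counts across ranks.
            dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM)
            accumulation_group_tokens[group_start] = int(local_group_tokens_tensor.item())
        else:
            accumulation_group_tokens[group_start] = local_group_tokens
    return accumulation_group_tokens
```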
Previously, grpo_fast only averaged the loss locally, not globally across ranks, and it did not take gradient accumulation into account. This meant that small samples had an outsized effect. See https://huggingface.co/blog/gradient_accumulation for more details. This slipped my notice... we probably need to check the DPO and SFT code too.
Anyway, the fix is: introduce a new value for `masked_mean_denominator`: `"token"`. When set, the loss for each minibatch is divided by the total number of tokens in that minibatch across all ranks, which scales the loss correctly. So while before we had:
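Schematically, ignoring constant factors (here $r$ indexes ranks, $a$ accumulation steps, $\ell_{r,a,i}$ is the per-token loss, and $\mathrm{toks}_{r,a}$ is that minibatch's token count, i.e. the rankX_accumY_toks below):

$$
\text{loss}_{\text{old}} \;\propto\; \frac{1}{N}\sum_{r,a}\frac{\sum_i \ell_{r,a,i}}{\mathrm{toks}_{r,a}}
$$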
Note that if e.g. rank0_accum0_toks >> rank1_accum1_toks, the tokens in rank1_accum1 would overwhelm the update, since each of them carries a much larger effective weight.
Now we have:
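Schematically, again ignoring constant factors:

$$
\text{loss}_{\text{new}} \;\propto\; \frac{\sum_{r,a}\sum_i \ell_{r,a,i}}{\sum_{r,a}\mathrm{toks}_{r,a}}
$$

i.e. a single global token-count denominator per accumulation group, so short minibatches no longer get an outsized per-token weight.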
Gave it a test locally and it seems fine. The default right now sticks to the older (technically incorrect) behaviour. Setting `masked_mean_denominator` to an integer value (e.g. following Dr GRPO) also technically fixes this, since we replace the denominators (previously token counts) with a constant.

Note

Introduces `masked_mean_denominator="token"` to normalize losses by total tokens across all ranks per accumulation group, and adjusts accumulation to prevent double-scaling.

- Extends `Args.masked_mean_denominator` to accept the special value `"token"` (with validation and axis constraints).
- Computes the total token count per accumulation group across all ranks via `dist.all_reduce`.
- Passes that total into `masked_mean` for the loss; skips dividing by `accumulation_steps` when in `"token"` mode.
- Converts `"token"` to the total token count for stats to avoid unintended scaling.
- Minor refactors (e.g. `loss_values`) and guardrails for the `masked_mean_denominator` argument.

Written by Cursor Bugbot for commit b37079c. This will update automatically on new commits.