Correct loss accumulation for grpo_fast #1161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open

hamishivi wants to merge 17 commits into main from accum-loss-fix-grpo-fast
+257 −65
Changes shown from 8 of 17 commits.
Commits (all by hamishivi):

1b6c5e1 trying out accum fix
23ea9ce Merge branch 'main' into accum-loss-fix-grpo-fast
92a4bc2 fix
a9ae929 fix
c7ccc15 Fix up
0aad4ac fix
89fb420 fix
b37079c fix
0331016 Merge branch 'main' into accum-loss-fix-grpo-fast
1390a6e loss fixes
f94b4b2 fix
1931b99 lint
90eb98e fix
6e698e9 fix
802a7e7 quick and dirty group-level
849ebb4 fix quality
8c7559d correct hacky group level
```diff
@@ -252,8 +252,10 @@ class Args:
     """the length of the pack (you should prob set to the max length of the model)"""
     masked_mean_axis: int | None = None
     """the axis to compute the mean of the masked values"""
-    masked_mean_denominator: float | None = None
-    """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum"""
+    masked_mean_denominator: float | str | None = None
+    """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum.
+    Special value "token" means use total_batch_tokens (computed across all ranks in distributed training).
+    When using "token", total_batch_tokens is gathered via allreduce across all ranks."""
     alpha: float = 0.6
     """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param)
     reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly
```
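The arithmetic the docstring describes can be sketched on plain Python scalars (the real masked_mean in the repo operates on torch tensors and takes an axis argument; the simplified signature below is an assumption for illustration):

```python
# Minimal sketch of masked_mean with an optional constant denominator.
# Names and signature are simplified; the actual implementation differs.
def masked_mean(values, mask, denominator=None):
    masked_sum = sum(v * m for v, m in zip(values, mask))
    if denominator is not None:
        # Constant (or externally computed) denominator instead of mask.sum().
        return masked_sum / denominator
    return masked_sum / sum(mask)

values = [1.0, 2.0, 3.0, 4.0]
mask = [1.0, 1.0, 0.0, 0.0]  # only the first two tokens are response tokens
```

With the mask-sum denominator this gives (1 + 2) / 2 = 1.5; with a constant denominator of 4 it gives 3 / 4 = 0.75, which is the behavior the "token" mode relies on.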
```diff
@@ -442,9 +444,17 @@ def __post_init__(self):
                 "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
             )
         if self.masked_mean_denominator is not None:
-            assert self.masked_mean_denominator > 0, (
-                f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
-            )
+            if isinstance(self.masked_mean_denominator, str):
+                assert self.masked_mean_denominator == "token", (
+                    f"masked_mean_denominator string value must be 'token' or number, got {self.masked_mean_denominator}"
+                )
+                assert self.masked_mean_axis is None, (
+                    "masked_mean_axis must not be provided when using 'token' normalization"
+                )
+            else:
+                assert self.masked_mean_denominator > 0, (
+                    f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
+                )
         assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
         if self.num_samples_per_prompt_rollout == 1:
             logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
```
```diff
@@ -1117,6 +1127,33 @@ def train(
         ratio_stats = torch.zeros(len(collated_query_responses))
         entropy_stats = torch.zeros(len(collated_query_responses))
         for epoch_idx in range(args.num_epochs):
+            # Pre-compute total tokens for each accumulation group if using "token" normalization
+            # This ensures all minibatches in an accumulation group are normalized by the same total
+            accumulation_group_tokens = {}
+            if args.masked_mean_denominator == "token":
+                for group_start in range(0, len(collated_query_responses), accumulation_steps):
+                    group_end = min(group_start + accumulation_steps, len(collated_query_responses))
+                    # Calculate local tokens for all minibatches in this accumulation group
+                    local_group_tokens = 0.0
+                    for i in range(group_start, group_end):
+                        mb_response_masks = collated_response_masks[i]
+                        mb_response_masks_bool = mb_response_masks[:, 1:].bool()
+                        if args.mask_tool_use and args.tool_use:
+                            mb_tool_mask = collated_tool_masks[i]
+                            mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
+                        local_group_tokens += mb_response_masks_bool.sum().item()
+
+                    # Gather total tokens across all ranks for this accumulation group
+                    if dist.is_available() and dist.is_initialized():
+                        dist.barrier()
+                        local_group_tokens_tensor = torch.tensor(
+                            local_group_tokens, dtype=torch.float32, device=self.device
+                        )
+                        dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None)
+                        accumulation_group_tokens[group_start] = local_group_tokens_tensor.item()
+                    else:
+                        accumulation_group_tokens[group_start] = local_group_tokens
+
             for i in range(len(collated_query_responses)):
                 mb_ref_logprob = collated_ref_logprobs[i]
                 mb_query_responses = collated_query_responses[i]
```

Collaborator review comment on this block: Can you make this a function?
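The grouping arithmetic above can be checked without torch or distributed collectives. In the PR, each rank's per-group sum is additionally all-reduced across ranks; the helper below (name and inputs made up for illustration) shows just the single-process part:

```python
# Hypothetical helper: given the response-token count of each minibatch,
# compute the token total for each gradient-accumulation group, keyed by
# the group's starting minibatch index (mirroring accumulation_group_tokens).
def group_token_counts(minibatch_token_counts, accumulation_steps):
    counts = {}
    n = len(minibatch_token_counts)
    for group_start in range(0, n, accumulation_steps):
        group_end = min(group_start + accumulation_steps, n)
        counts[group_start] = float(sum(minibatch_token_counts[group_start:group_end]))
    return counts

# 5 minibatches with accumulation_steps=2: groups are [0, 1], [2, 3], [4]
counts = group_token_counts([10, 12, 8, 9, 7], 2)
```

In multi-rank training, each value in `counts` would then be summed across ranks via `dist.all_reduce` before use.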
```diff
@@ -1127,6 +1164,13 @@ def train(
                 # if masking snippets, do it here.
                 if args.mask_tool_use and args.tool_use:
                     mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
+
+                # Get total tokens for this accumulation group if using "token" normalization
+                # This ensures all minibatches in the accumulation group are normalized by the same total
+                if args.masked_mean_denominator == "token":
+                    group_start = (i // accumulation_steps) * accumulation_steps
+                    total_batch_tokens = accumulation_group_tokens[group_start]
+
                 mb_attention_mask = collated_attention_masks[i]
                 mb_position_id = collated_position_ids[i]
                 mb_local_logprobs, mb_entropy = self.forward(
```
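The group lookup is plain integer arithmetic: every minibatch index floors down to the starting index of its accumulation group, so all minibatches in a group retrieve the same pre-computed token total.

```python
# Each minibatch i maps back to the start index of its accumulation group.
accumulation_steps = 4
group_starts = [(i // accumulation_steps) * accumulation_steps for i in range(10)]
# Minibatches 0-3 share group_start 0, 4-7 share 4, and 8-9 share 8.
```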
```diff
@@ -1250,30 +1294,49 @@ def train(
                     kl = kl4

+                # grpo change: directly subtract KL in loss (add)
+                loss_values = pg_loss_max + (args.beta * kl)
+
+                # Four loss cases:
+                # masked_mean_denominator is set: we use sum and divide loss by this constant.
+                # masked_mean_denominator is set to "token": we use sum and divide loss by total number of tokens in batch.
+                # masked_mean_denominator is None, masked_mean_axis is None: we take mean across tokens in minibatch (old behaviour)
+                # masked_mean_denominator is None, masked_mean_axis is 1: we use sample-wise averaging across the sequence axis.
                 loss = masked_mean(
-                    pg_loss_max + (args.beta * kl),
+                    loss_values,
                     mb_response_masks_bool,
                     args.masked_mean_axis,
-                    args.masked_mean_denominator,
+                    args.masked_mean_denominator
+                    if args.masked_mean_denominator != "token"
+                    else total_batch_tokens,
                 )
-                loss = loss / accumulation_steps
+                # When using global normalization (masked_mean_denominator == "token"), total_batch_tokens is the sum
+                # of tokens across all minibatches in the accumulation group. Since we normalize by this total,
+                # we should NOT divide by accumulation_steps (the normalization already accounts for all minibatches).
+                # For other normalization modes, we divide by accumulation_steps to properly scale gradients.
+                if args.masked_mean_denominator != "token":
+                    loss = loss / accumulation_steps
                 self.model.backward(loss)
                 if (local_step + 1) % accumulation_steps == 0:
                     self.model.step()
                 local_step += 1
                 with torch.no_grad():
+                    # for stats computation, for now no denominator is used
+                    # unless masked_mean_denominator is a numeric value.
+                    stats_denominator = (
+                        args.masked_mean_denominator if args.masked_mean_denominator != "token" else None
+                    )
                     # NOTE: in packed implementation, kl calculation are averages over response tokens
                     kl1_stats[i] = masked_mean(
-                        kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                        kl1, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                     ).float()
                     kl2_stats[i] = masked_mean(
-                        kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                        kl2, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                     ).float()
                     kl3_stats[i] = masked_mean(
-                        kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                        kl3, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                     ).float()
                     kl4_stats[i] = masked_mean(
-                        kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                        kl4, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                     ).float()
                     if args.kl_estimator == "kl1":
                         kl_loss_stats[i] = kl1_stats[i] * args.beta
```

```diff
@@ -1287,19 +1350,19 @@ def train(
                         (pg_losses2 > pg_losses).float(),
                         mb_response_masks_bool,
                         args.masked_mean_axis,
-                        args.masked_mean_denominator,
+                        stats_denominator,
                     )
                     pg_loss_stats[i] = masked_mean(
-                        pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                        pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                     )
                     loss_stats[i] = loss
                     ratio_stats[i] = masked_mean(
-                        ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                        ratio, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                     )
                     if args.record_entropy:
                         # Calculate entropy statistics
                         entropy_stats[i] = masked_mean(
-                            mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
+                            mb_entropy, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
                         ).float()

                 with torch.no_grad():
```
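A small numeric check of the normalization this PR implements, using made-up per-token losses: dividing each minibatch's summed loss by the group-wide token total, and skipping the division by accumulation_steps, recovers the exact per-token mean over the whole accumulation group, whereas the old mean-then-divide path weights short and long minibatches equally.

```python
# Two minibatches in one accumulation group, with different token counts.
per_token_losses = [[1.0, 3.0], [2.0, 4.0, 6.0]]
total_tokens = sum(len(mb) for mb in per_token_losses)  # 5 tokens overall

# New behaviour: each minibatch contributes sum / total_tokens, and gradient
# accumulation adds the contributions, so no extra accumulation_steps division.
accumulated = sum(sum(mb) / total_tokens for mb in per_token_losses)

# The true per-token mean over the whole group: 16 / 5 = 3.2
global_mean = sum(sum(mb) for mb in per_token_losses) / total_tokens

# Old behaviour: per-minibatch mean, then divide by the number of minibatches
# (accumulation_steps). The 2-token and 3-token minibatches get equal weight.
naive = sum(sum(mb) / len(mb) for mb in per_token_losses) / len(per_token_losses)
```

Here `accumulated` matches `global_mean` (3.2), while `naive` gives 3.0, which is the token-weighting bias the PR corrects.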