95 changes: 79 additions & 16 deletions open_instruct/grpo_fast.py
@@ -252,8 +252,10 @@ class Args:
"""the length of the pack (you should prob set to the max length of the model)"""
masked_mean_axis: int | None = None
"""the axis to compute the mean of the masked values"""
masked_mean_denominator: float | None = None
"""Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum"""
masked_mean_denominator: float | str | None = None
"""Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum.
The special value "token" divides by total_batch_tokens, gathered via allreduce across all
ranks in distributed training so every rank normalizes by the same global token count."""
alpha: float = 0.6
"""The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param)
reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly
@@ -442,9 +444,17 @@ def __post_init__(self):
"use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
)
if self.masked_mean_denominator is not None:
assert self.masked_mean_denominator > 0, (
f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
)
if isinstance(self.masked_mean_denominator, str):
assert self.masked_mean_denominator == "token", (
f"masked_mean_denominator must be a positive number or the string 'token', got {self.masked_mean_denominator!r}"
)
assert self.masked_mean_axis is None, (
"masked_mean_axis must not be provided when using 'token' normalization"
)
else:
assert self.masked_mean_denominator > 0, (
f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
)
assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
if self.num_samples_per_prompt_rollout == 1:
logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -1117,6 +1127,33 @@ def train(
ratio_stats = torch.zeros(len(collated_query_responses))
entropy_stats = torch.zeros(len(collated_query_responses))
for epoch_idx in range(args.num_epochs):
# Pre-compute total tokens for each accumulation group if using "token" normalization
# This ensures all minibatches in an accumulation group are normalized by the same total
accumulation_group_tokens = {}
if args.masked_mean_denominator == "token":
for group_start in range(0, len(collated_query_responses), accumulation_steps):
group_end = min(group_start + accumulation_steps, len(collated_query_responses))
# Calculate local tokens for all minibatches in this accumulation group
local_group_tokens = 0.0
for i in range(group_start, group_end):
mb_response_masks = collated_response_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool()
if args.mask_tool_use and args.tool_use:
mb_tool_mask = collated_tool_masks[i]
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
local_group_tokens += mb_response_masks_bool.sum().item()

# Gather total tokens across all ranks for this accumulation group
if dist.is_available() and dist.is_initialized():
dist.barrier()
local_group_tokens_tensor = torch.tensor(
local_group_tokens, dtype=torch.float32, device=self.device
)
dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None)
accumulation_group_tokens[group_start] = local_group_tokens_tensor.item()
else:
accumulation_group_tokens[group_start] = local_group_tokens

Collaborator: Can you make this a function?

accumulation_group_tokens = maybe_calculate_group_tokens(...)
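
One possible shape for that refactor, assembled from the loop above; the name maybe_calculate_group_tokens comes from the reviewer's suggestion, and the device parameter stands in for self.device:

import torch
import torch.distributed as dist

def maybe_calculate_group_tokens(
    args, collated_response_masks, collated_tool_masks, accumulation_steps, device
):
    """Return {group_start: total_tokens} per accumulation group, allreduced
    across ranks, or an empty dict when "token" normalization is not in use."""
    accumulation_group_tokens = {}
    if args.masked_mean_denominator != "token":
        return accumulation_group_tokens
    num_minibatches = len(collated_response_masks)
    for group_start in range(0, num_minibatches, accumulation_steps):
        group_end = min(group_start + accumulation_steps, num_minibatches)
        local_group_tokens = 0.0
        for i in range(group_start, group_end):
            mb_response_masks_bool = collated_response_masks[i][:, 1:].bool()
            if args.mask_tool_use and args.tool_use:
                mb_response_masks_bool &= collated_tool_masks[i][:, 1:].bool()
            local_group_tokens += mb_response_masks_bool.sum().item()
        if dist.is_available() and dist.is_initialized():
            dist.barrier()  # mirrors the inline version above
            tokens = torch.tensor(local_group_tokens, dtype=torch.float32, device=device)
            dist.all_reduce(tokens, op=dist.ReduceOp.SUM)
            accumulation_group_tokens[group_start] = tokens.item()
        else:
            accumulation_group_tokens[group_start] = local_group_tokens
    return accumulation_group_tokens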

for i in range(len(collated_query_responses)):
mb_ref_logprob = collated_ref_logprobs[i]
mb_query_responses = collated_query_responses[i]
@@ -1127,6 +1164,13 @@
# if masking snippets, do it here.
if args.mask_tool_use and args.tool_use:
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()

# Get total tokens for this accumulation group if using "token" normalization
# This ensures all minibatches in the accumulation group are normalized by the same total
if args.masked_mean_denominator == "token":
group_start = (i // accumulation_steps) * accumulation_steps
total_batch_tokens = accumulation_group_tokens[group_start]

mb_attention_mask = collated_attention_masks[i]
mb_position_id = collated_position_ids[i]
mb_local_logprobs, mb_entropy = self.forward(
@@ -1250,30 +1294,49 @@ def train(
kl = kl4

# grpo change: directly subtract KL in loss (add)
loss_values = pg_loss_max + (args.beta * kl)

# Four loss cases:
# masked_mean_denominator is a number: we use sum and divide loss by this constant.
# masked_mean_denominator is set to "token": we use sum and divide loss by total number of tokens in batch.
# masked_mean_denominator is None, masked_mean_axis is None: we take mean across tokens in minibatch (old behaviour)
# masked_mean_denominator is None, masked_mean_axis is 1: we use sample-wise averaging across the sequence axis.
loss = masked_mean(
pg_loss_max + (args.beta * kl),
loss_values,
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
args.masked_mean_denominator
if args.masked_mean_denominator != "token"
else total_batch_tokens,
)
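
For orientation, a minimal sketch of the masked_mean semantics these call sites assume; the real helper is defined elsewhere in open_instruct, so treat this stand-in as illustrative only:

import torch

def masked_mean_sketch(values, mask, axis=None, denominator=None):
    # Cases 1 & 2: constant denominator (a number, or the precomputed
    # total_batch_tokens when masked_mean_denominator == "token").
    if denominator is not None:
        return (values * mask).sum() / denominator
    # Case 4: sample-wise averaging along the sequence axis, then mean.
    if axis is not None:
        return ((values * mask).sum(axis=axis) / mask.sum(axis=axis)).mean()
    # Case 3: plain mean over all unmasked tokens in the minibatch.
    return (values * mask).sum() / mask.sum()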
loss = loss / accumulation_steps
# When using global normalization (masked_mean_denominator == "token"), total_batch_tokens is the sum
# of tokens across all minibatches in the accumulation group. Since we normalize by this total,
# we should NOT divide by accumulation_steps (the normalization already accounts for all minibatches).
# For other normalization modes, we divide by accumulation_steps to properly scale gradients.
if args.masked_mean_denominator != "token":
loss = loss / accumulation_steps
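
A toy check of the claim in the comment above, with hypothetical token counts:

import torch

# Two minibatches with 3 and 5 response tokens form one accumulation group
# of 8 tokens total.
losses = [torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0, 7.0, 8.0])]
total_batch_tokens = sum(l.numel() for l in losses)  # 8

# "token" mode: each minibatch sum is divided by the group-wide total and the
# results add up across accumulation steps -- no further division needed.
token_mode = sum(l.sum() / total_batch_tokens for l in losses)

# This equals a single mean over every token in the accumulation group.
assert torch.isclose(token_mode, torch.cat(losses).mean())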
self.model.backward(loss)
if (local_step + 1) % accumulation_steps == 0:
self.model.step()
local_step += 1
with torch.no_grad():
# For stats computation, no constant denominator is used under "token"
# normalization; a numeric masked_mean_denominator is still applied.
stats_denominator = (
args.masked_mean_denominator if args.masked_mean_denominator != "token" else None
)
# NOTE: in the packed implementation, kl calculations are averages over response tokens
kl1_stats[i] = masked_mean(
kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
kl1, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
).float()
kl2_stats[i] = masked_mean(
kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
kl2, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
).float()
kl3_stats[i] = masked_mean(
kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
kl3, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
).float()
kl4_stats[i] = masked_mean(
kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
kl4, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
).float()
if args.kl_estimator == "kl1":
kl_loss_stats[i] = kl1_stats[i] * args.beta
Expand All @@ -1287,19 +1350,19 @@ def train(
(pg_losses2 > pg_losses).float(),
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
stats_denominator,
)
pg_loss_stats[i] = masked_mean(
pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
)
loss_stats[i] = loss
ratio_stats[i] = masked_mean(
ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
ratio, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
)
if args.record_entropy:
# Calculate entropy statistics
entropy_stats[i] = masked_mean(
mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
mb_entropy, mb_response_masks_bool, args.masked_mean_axis, stats_denominator
).float()

with torch.no_grad():