Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions slime/backends/megatron_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ def log_rollout_data(
"rollout_routed_experts",
"max_seq_lens",
"dynamic_global_batch_size",
"raw_rewards_partitioned",
]:
continue
# Upload the per-sample mean for each rollout value
Expand Down Expand Up @@ -505,13 +506,15 @@ def quantile(total_value, n_quantiles, data) -> dict:
percentile = {f"p{min(math.ceil(q*100),100)}": p for q, p in zip(quantiles, percentile, strict=True)}
return percentile

raw_rewards = rollout_data["raw_reward"]
# raw_reward may stay global for pass-rate metrics; correct-sample metrics
# must use rewards aligned with the DP-local response_lengths/total_lengths.
raw_rewards_partitioned = rollout_data["raw_rewards_partitioned"]
# Additional metrics for correct cases are calculated separately below.
correct_response_lengths = []
correct_total_lengths = []
correct_loss_masks = []
correct_entropy = []
for i, raw_reward in enumerate(raw_rewards):
for i, raw_reward in enumerate(raw_rewards_partitioned):
if raw_reward == 1:
correct_response_lengths.append(response_lengths[i])
correct_total_lengths.append(total_lengths[i])
Expand Down
4 changes: 4 additions & 0 deletions slime/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,5 +306,9 @@ def process_rollout_data(args, rollout_data_ref, dp_rank, dp_size):
# Save the sequence lengths of the whole rollout batch
Timer().seq_lens = total_lengths
rollout_data["total_lengths"] = [total_lengths[i] for i in partition]
if "raw_reward" in rollout_data:
# Keep raw_reward global for passrate metrics, but provide a DP-local view
# for correct-sample metrics that index local response_lengths/total_lengths.
rollout_data["raw_rewards_partitioned"] = [rollout_data["raw_reward"][i] for i in partition]

return rollout_data