Skip to content

Commit 43876d7

Browse files
committed
moved stop rate up
1 parent 344fae6 commit 43876d7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

open_instruct/grpo_fast.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,6 +1844,10 @@ def data_preparation_thread(
18441844
else:
18451845
raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}")
18461846

1847+
# Calculate stop rate before potentially masking truncated completions
1848+
stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
1849+
result.finish_reasons
1850+
)
18471851
if args.mask_truncated_completions:
18481852
stop_idxes = torch.tensor(
18491853
[i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"]
@@ -1930,10 +1934,6 @@ def data_preparation_thread(
19301934
sequence_length_unsolved = (
19311935
np.array([]) if np.all(scores == args.max_possible_score) else np.array(sequence_lengths[scores == 0])
19321936
)
1933-
stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len(
1934-
result.finish_reasons
1935-
)
1936-
19371937
batch_metrics = asdict(batch_stats)
19381938
total_reward_groups = real_num_responses / args.num_samples_per_prompt_rollout
19391939
batch_metrics["percent_filtered_groups"] = batch_metrics["total_prompts"] / total_reward_groups

0 commit comments

Comments
 (0)