Commit 6f2c44d

Merge branch 'main' of github.com:allenai/open-instruct into olmo3-rlzero
2 parents: 90f9ae6 + d18452e

6 files changed: +139 −54 lines

open_instruct/grpo_fast.py

Lines changed: 67 additions & 5 deletions
@@ -270,6 +270,8 @@ class Args:
 
     active_sampling: bool = False
     """Whether to continue sampling responses until you get a full batch."""
+    no_resampling_pass_rate: float | None = None
+    """If a prompt is solved at a rate at or above this threshold, do not resample it."""
 
     record_entropy: bool = False
     """whether to record the entropy of the policy during training. Uses extra memory."""
@@ -349,6 +351,12 @@ class Args:
     """vLLM top p for nucleus sampling"""
     deepspeed_stage: int = 0
     """the deepspeed stage"""
+    deepspeed_zpg: int = 8
+    """the deepspeed zpg (ZeRO++ hierarchical partitioning) group size. Higher values shrink the secondary intra-node weight copy, so they are more memory efficient but slower; set to 1 to disable zpg entirely, which uses the least memory but is significantly slower. Ideally set to the number of GPUs per node (usually 8, the default)."""
+    deepspeed_offload_param: bool = False
+    """whether to offload parameters to CPU (reduces GPU memory usage)"""
+    deepspeed_offload_optimizer: bool = False
+    """whether to offload optimizer states to CPU (reduces GPU memory usage)"""
     gather_whole_model: bool = True
     """whether to gather the whole model to broadcast (not doable for 70B but can be faster for 8B)"""
     enable_queue_dashboard: bool = True
@@ -670,16 +678,17 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None):
         self.epoch_number = 0
         self.rng = np.random.default_rng(seed)
         self.rng.shuffle(self.data)
+        self.exclude_list = []
 
-        # Ensure the effective dataset size is divisible by batch_size
-        self.effective_size = len(self.data) - (len(self.data) % batch_size)
+        self._update_effective_size()
 
     def __iter__(self) -> Iterator[list[int]]:
         return self
 
     def __next__(self) -> list[int]:
         if self.index >= self.effective_size:
             self.index = 0
+            self._update_effective_size()
             self.epoch_number += 1
             self.rng.shuffle(self.data)
 
@@ -696,6 +705,7 @@ def get_state(self) -> dict[str, Any]:
             "epoch_number": self.epoch_number,
             "data": self.data.copy(),
             "rng_state": self.rng.bit_generator.state,
+            "exclude_list": self.exclude_list.copy(),
         }
 
     def set_state(self, state: dict[str, Any]) -> None:
@@ -704,6 +714,21 @@ def set_state(self, state: dict[str, Any]) -> None:
         self.epoch_number = state.get("epoch_number", 0)
         self.data = state["data"].copy()
         self.rng.bit_generator.state = state["rng_state"]
+        self.exclude_list = state.get("exclude_list", [])
+        self._update_effective_size()
+
+    def exclude_index(self, index: int) -> None:
+        """Exclude the provided data point from future sampling."""
+        self.exclude_list.append(index)
+
+    def _update_effective_size(self) -> None:
+        """Keep the effective dataset size divisible by batch_size, dropping any indices excluded during the last epoch."""
+        if self.exclude_list:
+            mask = ~np.isin(self.data, self.exclude_list)
+            self.data = self.data[mask]
+            self.exclude_list = []
+
+        self.effective_size = len(self.data) - (len(self.data) % self.batch_size)
 
 
 @ray.remote(num_gpus=1)
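
The exclusion is deferred: exclude_index only records an index, and _update_effective_size drops the recorded indices at the next epoch boundary, so an excluded prompt can still be drawn until the current epoch wraps. A minimal self-contained sketch of these mechanics (the batch-slicing lines in __next__ are assumed from context, since the diff only shows the epoch-boundary logic):

import numpy as np

class ShufflingIteratorSketch:
    def __init__(self, data, batch_size, seed=None):
        self.data = data.copy()
        self.batch_size = batch_size
        self.index = 0
        self.epoch_number = 0
        self.rng = np.random.default_rng(seed)
        self.rng.shuffle(self.data)
        self.exclude_list = []
        self._update_effective_size()

    def __next__(self):
        if self.index >= self.effective_size:
            self.index = 0
            self._update_effective_size()  # apply exclusions recorded last epoch
            self.epoch_number += 1
            self.rng.shuffle(self.data)
        batch = self.data[self.index : self.index + self.batch_size].tolist()  # assumed
        self.index += self.batch_size  # assumed
        return batch

    def exclude_index(self, index):
        self.exclude_list.append(index)  # deferred until the epoch boundary

    def _update_effective_size(self):
        if self.exclude_list:
            self.data = self.data[~np.isin(self.data, self.exclude_list)]
            self.exclude_list = []
        # keep the effective size divisible by batch_size
        self.effective_size = len(self.data) - (len(self.data) % self.batch_size)

it = ShufflingIteratorSketch(np.arange(8), batch_size=4, seed=0)
first = next(it)            # one batch of four dataset indices
it.exclude_index(first[0])  # mark a solved prompt: never resample it
next(it)                    # second batch; the epoch then wraps
next(it)                    # new epoch: the excluded index has been filtered out
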
@@ -749,7 +774,13 @@ def load(self, path: str, map_location=None):
 
         deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout))
 
-        ds_config = get_train_ds_config(offload=False, adam_offload=False, stage=args.deepspeed_stage, bf16=True)
+        ds_config = get_train_ds_config(
+            offload=args.deepspeed_offload_param,
+            adam_offload=args.deepspeed_offload_optimizer,
+            stage=args.deepspeed_stage,
+            bf16=True,
+            zpg=args.deepspeed_zpg,
+        )
         ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
         ds_config["gradient_accumulation_steps"] = 1
         # @vwxyzjn: MAGIC: it's actually needed to initialize this `dschf`, so
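
For orientation: get_train_ds_config is defined in open_instruct's DeepSpeed utilities and its body is not shown in this diff, so the sketch below is only a guess at the shape of the ZeRO config these new flags control, using DeepSpeed's documented JSON fields; the zpg -> zero_hpz_partition_size mapping is an assumption carried over from similar helpers.

def sketch_train_ds_config(offload: bool, adam_offload: bool, stage: int, bf16: bool, zpg: int) -> dict:
    # Hypothetical sketch; field names follow DeepSpeed's ZeRO / ZeRO++ schema,
    # not necessarily the actual open_instruct helper.
    return {
        "zero_optimization": {
            "stage": stage,  # --deepspeed_stage
            # --deepspeed_offload_param / --deepspeed_offload_optimizer
            "offload_param": {"device": "cpu" if offload else "none"},
            "offload_optimizer": {"device": "cpu" if adam_offload else "none", "pin_memory": True},
            # ZeRO++ hpZ: zpg=1 disables the secondary intra-node weight copy
            # (least memory, slower gathers); zpg = GPUs per node trades a bit
            # of memory for much faster weight gathers.
            "zero_hpz_partition_size": zpg,
        },
        "bf16": {"enabled": bf16},
    }
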
@@ -844,7 +875,7 @@ def load(self, path: str, map_location=None):
 
         # reference model
         ds_config = get_eval_ds_config(
-            offload=False,
+            offload=args.deepspeed_offload_param,
             # inference model only has stage 3 (sharding) or stage 0 (no sharding)
             # stage 2 is optimizer sharding which doesn't apply to inference
             stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
@@ -948,7 +979,7 @@ def setup_model_update_group(self, vllm_engines):
             group_name="openrlhf",
             timeout=timedelta(minutes=self.args.backend_timeout),
         )
-        ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=60)
+        ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=600)
         torch.distributed.barrier()
 
     def broadcast_to_vllm(self):
@@ -1259,6 +1290,8 @@ def train(
                 args.masked_mean_denominator,
             )
             loss = loss / accumulation_steps
+            # Clear CUDA cache before backward pass to free memory for reduce_scatter operations
+            torch.cuda.empty_cache()
             self.model.backward(loss)
             if (local_step + 1) % accumulation_steps == 0:
                 self.model.step()
@@ -1633,6 +1666,8 @@ class BatchStatistics:
     filtered_prompts_zero: int
     filtered_prompts_solved: int
     filtered_prompts_nonzero: int
+    percent_solved_mean: float
+    no_resampled_prompts: int
 
 
 def accumulate_inference_batches(
@@ -1647,6 +1682,8 @@ def accumulate_inference_batches(
     actor_manager=None,
     timeout: float | None = None,
     filter_zero_std_samples: bool = False,
+    no_resampling_pass_rate: float | None = None,
+    iter_dataloader: ShufflingIterator | None = None,
 ) -> tuple[GenerationResult, Batch, dict, BatchStatistics]:
     """Accumulate multiple inference results into a single training batch.
 
@@ -1657,6 +1694,10 @@ def accumulate_inference_batches(
         generation_config: Generation config containing n (number of samples per prompt)
         num_prompts: Number of prompts to accumulate
         timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely.
+        filter_zero_std_samples: Whether to filter out samples with zero reward std and keep sampling
+        no_resampling_pass_rate: Optional threshold; prompts solved at a rate at or above it are
+            excluded from future sampling
+        iter_dataloader: Optional ShufflingIterator, required when no_resampling_pass_rate is set
 
     Raises:
         queue.Empty: If timeout is specified and no data is available within timeout.
@@ -1673,10 +1714,12 @@
     all_decoded_responses = []
     all_reward_metrics = []
     all_scores = []
+    all_percent_solved = []
     total_filtered_prompts = 0
     filtered_prompt_zero = 0
     filtered_prompt_solved = 0
     filtered_prompt_nonzero = 0
+    total_no_resampled = 0
     progress_bar = tqdm(
         total=num_prompts,
         desc=f"Accumulating Responses and Rewarding {num_prompts} prompts",
@@ -1725,6 +1768,15 @@
             )
         )
 
+        percent_solved = np.mean(scores).item() / args.max_possible_score
+        # Don't resample prompts solved at a rate >= no_resampling_pass_rate
+        if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate:
+            iter_dataloader.exclude_index(result.dataset_index)
+            total_no_resampled += 1
+            logging.debug(
+                f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}"
+            )
+
         # Filter out zero std prompts
         if filter_zero_std_samples and np.std(scores) == 0:
             total_filtered_prompts += 1
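
A worked example of the threshold, using the binary max_possible_score = 1.0 that the updated test mock assumes (the numbers are hypothetical):

import numpy as np

scores = [1.0, 0.0, 1.0, 1.0]  # rewards for 4 rollouts of one prompt
percent_solved = np.mean(scores).item() / 1.0  # 0.75 with max_possible_score = 1.0
assert percent_solved >= 0.6  # meets --no_resampling_pass_rate 0.6, so the
                              # prompt's dataset index is excluded from future epochs

Because this check runs before the zero-std filter, a prompt solved by every rollout (reward std 0) is both marked for no-resampling and, when --active_sampling is on, filtered out of the current batch.
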
@@ -1747,6 +1799,7 @@
         all_decoded_responses.extend(decoded_responses)
         all_scores.extend(scores)
         all_reward_metrics.append(reward_metrics)
+        all_percent_solved.append(percent_solved)
         progress_bar.update(1)
 
     # Combine all results into a single GenerationResult
@@ -1841,6 +1894,7 @@
     )
 
     combined_reward_metrics = combine_reward_metrics(all_reward_metrics)
+    percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0
 
     batch_stats = BatchStatistics(
         prompt_lengths=prompt_lengths,
@@ -1849,6 +1903,8 @@
         filtered_prompts_zero=filtered_prompt_zero,
         filtered_prompts_solved=filtered_prompt_solved,
         filtered_prompts_nonzero=filtered_prompt_nonzero,
+        percent_solved_mean=percent_solved_mean,
+        no_resampled_prompts=total_no_resampled,
     )
     logging.info(
         f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds"
@@ -1867,6 +1923,7 @@ def data_preparation_thread(
     num_training_steps: int,
     generation_config,
     resume_training_step: int,
+    iter_dataloader: ShufflingIterator,
     actor_manager=None,
     model_dims: utils.ModelDims = None,
 ):
@@ -1884,6 +1941,8 @@
             reward_fn=reward_fn,
             actor_manager=actor_manager,
             filter_zero_std_samples=args.active_sampling,
+            no_resampling_pass_rate=args.no_resampling_pass_rate,
+            iter_dataloader=iter_dataloader,
         )
         if isinstance(result, ShutdownSentinel):
             logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
@@ -2580,6 +2639,8 @@ def maybe_evaluate(
         actor_manager=actor_manager,
         timeout=timeout,
         filter_zero_std_samples=False,
+        no_resampling_pass_rate=None,
+        iter_dataloader=None,
     )
 
     logger.info("[Main Thread] 📊 Evaluation responses received")
@@ -2890,6 +2951,7 @@ def run_training(
         args.num_training_steps,
         generation_configs["train"],
         resume_training_step,
+        iter_dataloader,
         actor_manager,
         model_dims,
     )

open_instruct/test_grpo_fast.py

Lines changed: 1 addition & 0 deletions
@@ -138,6 +138,7 @@ def create_mock_args(self, num_engines=4, num_samples=1):
         mock_args.vllm_tensor_parallel_size = 1
         mock_args.num_samples_per_prompt_rollout = num_samples
         mock_args.verbose = False
+        mock_args.max_possible_score = 1.0
         return mock_args
 
     def create_mock_model_dims(self):

open_instruct/vllm_utils.py

Lines changed: 3 additions & 1 deletion
@@ -881,6 +881,8 @@ def create_vllm_engines(
         )
     )
 
-    ray_get_with_progress([engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=300)
+    ray_get_with_progress(
+        [engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=1200
+    )
 
     return vllm_engines

scripts/eval/oe-eval.sh

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ NEXT_MODEL_DEV=(
 
     # Reasoning
     "bbh:cot::hamish_zs_reasoning_deepseek_v2" # OLD: "bbh:cot::hamish_zs_reasoning_deepseek"
-    "gpqa:0shot_cot::hamish_zs_reasoning_deepseek"
+    "gpqa:0shot_cot::qwen3-instruct"
    "zebralogic::hamish_zs_reasoning_deepseek"
    "agi_eval_english:0shot_cot::hamish_zs_reasoning_deepseek"

scripts/train/debug/single_gpu_integration_test.sh

Lines changed: 3 additions & 1 deletion
@@ -33,7 +33,7 @@ uv run python mason.py \
     --per_device_train_batch_size 1 \
     --num_unique_prompts_rollout 8 \
     --num_samples_per_prompt_rollout 4 \
-    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --model_name_or_path Qwen/Qwen2.5-1.5B \
     --stop_strings "</answer>" \
     --apply_r1_style_format_reward \
     --apply_verifiable_reward true \
@@ -57,4 +57,6 @@ uv run python mason.py \
     --push_to_hub false \
     --active_sampling \
     --async_steps 8 \
+    --no_resampling_pass_rate 0.6 \
+    --verbose \
     --single_gpu_mode
