
Commit d18452e

Changes for 32B (#1164)
* new 32b script
* new 32b script
* beaker eval freq not upstreaming
* new var
* longer timeout on capturing cuda
* longer timeout on capturing cuda
* update params
* reduce more
* no optim
* working script
* zpg inc
* newer changes
* higher zpg
* changes
* fix
* zpg as arg
* debug
* update
* update
* del tmp script
1 parent f799155 commit d18452e


3 files changed: +84 −50 lines changed


open_instruct/grpo_fast.py

Lines changed: 17 additions & 3 deletions
@@ -351,6 +351,12 @@ class Args:
     """vLLM top p for nucleus sampling"""
     deepspeed_stage: int = 0
     """the deepspeed stage"""
+    deepspeed_zpg: int = 8
+    """the deepspeed zpg value. Higher values are more memory efficient but slower. Set to 1 to disable zpg, which uses less memory but is significantly slower. Ideally is set to the number of GPUs per node (usually 8, default)."""
+    deepspeed_offload_param: bool = False
+    """whether to offload parameters to CPU (reduces GPU memory usage)"""
+    deepspeed_offload_optimizer: bool = False
+    """whether to offload optimizer states to CPU (reduces GPU memory usage)"""
     gather_whole_model: bool = True
     """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)"""
     enable_queue_dashboard: bool = True
@@ -766,7 +772,13 @@ def load(self, path: str, map_location=None):
 
         deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout))
 
-        ds_config = get_train_ds_config(offload=False, adam_offload=False, stage=args.deepspeed_stage, bf16=True)
+        ds_config = get_train_ds_config(
+            offload=args.deepspeed_offload_param,
+            adam_offload=args.deepspeed_offload_optimizer,
+            stage=args.deepspeed_stage,
+            bf16=True,
+            zpg=args.deepspeed_zpg,
+        )
         ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
         ds_config["gradient_accumulation_steps"] = 1
         # @vwxyzjn: MAGIC: it's actually needed to initialize this `dschf`, so
@@ -861,7 +873,7 @@ def load(self, path: str, map_location=None):
 
         # reference model
         ds_config = get_eval_ds_config(
-            offload=False,
+            offload=args.deepspeed_offload_param,
             # inference model only has stage 3 (sharding) or stage 0 (no sharding)
             # stage 2 is optimizer sharding which doesn't apply to inference
             stage=args.deepspeed_stage if args.deepspeed_stage == 3 else 0,
@@ -965,7 +977,7 @@ def setup_model_update_group(self, vllm_engines):
                 group_name="openrlhf",
                 timeout=timedelta(minutes=self.args.backend_timeout),
             )
-        ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=60)
+        ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=600)
         torch.distributed.barrier()
 
     def broadcast_to_vllm(self):
@@ -1276,6 +1288,8 @@ def train(
                     args.masked_mean_denominator,
                 )
                 loss = loss / accumulation_steps
+                # Clear CUDA cache before backward pass to free memory for reduce_scatter operations
+                torch.cuda.empty_cache()
                 self.model.backward(loss)
                 if (local_step + 1) % accumulation_steps == 0:
                     self.model.step()
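
For reference, here is a rough sketch of the kind of DeepSpeed ZeRO config these new arguments feed into. The actual key assembly happens in the repo's get_train_ds_config helper, which is not part of this diff; the function below is an illustration only, using the public DeepSpeed config schema rather than the repo's real output.

# Illustrative sketch only: not the repo's get_train_ds_config implementation.
# Key names follow the public DeepSpeed config schema (zero_optimization,
# offload_param, offload_optimizer, zero_hpz_partition_size, bf16).
def sketch_train_ds_config(offload_param: bool, offload_optimizer: bool, stage: int, zpg: int) -> dict:
    config = {
        "zero_optimization": {
            "stage": stage,
            # ZeRO++ hierarchical partitioning ("zpg"): the secondary parameter
            # copy is sharded across this many ranks, typically one node's GPUs.
            "zero_hpz_partition_size": zpg,
        },
        "bf16": {"enabled": True},
    }
    if offload_param:
        config["zero_optimization"]["offload_param"] = {"device": "cpu", "pin_memory": True}
    if offload_optimizer:
        config["zero_optimization"]["offload_optimizer"] = {"device": "cpu", "pin_memory": True}
    return config

Under this reading, --deepspeed_zpg 8 keeps the secondary partition within a single 8-GPU node, while --deepspeed_offload_param and --deepspeed_offload_optimizer trade GPU memory for CPU transfer overhead.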

open_instruct/vllm_utils.py

Lines changed: 3 additions & 1 deletion
@@ -881,6 +881,8 @@ def create_vllm_engines(
             )
         )
 
-    ray_get_with_progress([engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=300)
+    ray_get_with_progress(
+        [engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=1200
+    )
 
     return vllm_engines
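
ray_get_with_progress is defined elsewhere in open_instruct and is not shown in this diff; the longer timeouts (60 -> 600 seconds for the vLLM process groups, 300 -> 1200 seconds for engine startup) fit a 32B model, where engine and NCCL group initialization is much slower. A minimal sketch of the behavior the timeout argument implies, assuming a tqdm-style progress bar over Ray object refs (the repo's actual helper may differ):

import time
import ray
from tqdm import tqdm

# Hypothetical sketch: wait for all refs with a progress bar and an overall timeout.
def ray_get_with_progress(refs, desc, timeout):
    results = {}
    pending = list(refs)
    deadline = time.monotonic() + timeout
    with tqdm(total=len(refs), desc=desc) as pbar:
        while pending:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(f"{desc} did not finish within {timeout}s")
            # Wait for the next ref to finish (or until the deadline passes).
            done, pending = ray.wait(pending, num_returns=1, timeout=remaining)
            for ref in done:
                results[ref] = ray.get(ref)
                pbar.update(1)
    return [results[ref] for ref in refs]
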
Lines changed: 64 additions & 46 deletions
@@ -1,68 +1,86 @@
 #!/bin/bash
-# Note: This was originally a script that Saurabh came up to run some experiments.
-# Finbarr has been using it a lot for testing, so we thought we'd check it in.
-num_prompts=25376
-exp_name=rlvr_ace_fn_and_og_ocr_stdio_from_base_with_perf_penalty
-BEAKER_IMAGE="${1:-${BEAKER_USER}/open-instruct-integration-test}"
-uv run python mason.py \
-    --cluster ai2/augusta \
-    --image "$BEAKER_IMAGE" \
-    --pure_docker_mode \
-    --workspace ai2/open-instruct-dev \
-    --gs_model_name "stego32" \
-    --priority urgent \
-    --preemptible \
-    --num_nodes 4 \
-    --description "Large (multi-node) test script." \
-    --timeout 3600 \
-    --max_retries 0 \
-    --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
-    --budget ai2/oe-adapt \
-    --gpus 8 -- source configs/beaker_configs/ray_node_setup.sh \&\& source configs/beaker_configs/code_api_setup.sh \&\&python open_instruct/grpo_fast.py \
+
+
+export exp_name=test_olmo3_32b_rl_run_${RANDOM}
+export data_mix="hamishivi/math_rlvr_mixture_dpo 1.0 hamishivi/code_rlvr_mixture_dpo 1.0 hamishivi/IF_multi_constraints_upto5_filtered_dpo_0625_filter 30186 allenai/rlvr_general_mix-keyword-filtered 21387"
+export beaker_image=hamishivi/open_instruct_rl32_test10
+export model_path=/weka/oe-adapt-default/hamishi/model_checkpoints/olmo3-merge-32b-1e-4-5e-5/olmo3-merge-32b-1e-4-5e-5/
+
+
+python mason.py \
+    --budget ai2/oe-adapt \
+    --cluster ai2/augusta \
+    --image ${beaker_image} \
+    --pure_docker_mode \
+    --workspace ai2/olmo-instruct \
+    --priority urgent \
+    --gs_model_name "sft_olmo3_32b_rl_run_testing" \
+    --preemptible \
+    --num_nodes 18 \
+    --gpus 8 \
+    --max_retries 0 \
+    --env VLLM_ALLOW_LONG_MAX_MODEL_LEN=1 \
+    --env LD_LIBRARY_PATH=/var/lib/tcpxo/lib64 \
+    --env NCCL_LIB_DIR=/var/lib/tcpxo/lib64 \
+    --env HOSTED_VLLM_API_BASE=http://ceres-cs-aus-447.reviz.ai2.in:8001/v1 \
+    -- source configs/beaker_configs/ray_node_setup.sh \&\& source configs/beaker_configs/code_api_setup.sh \&\& python open_instruct/grpo_fast.py \
     --exp_name ${exp_name} \
     --beta 0.0 \
-    --num_samples_per_prompt_rollout 16 \
+    --num_samples_per_prompt_rollout 8 \
     --num_unique_prompts_rollout 64 \
     --num_mini_batches 1 \
     --num_epochs 1 \
-    --learning_rate 5e-7 \
+    --learning_rate 1e-6 \
     --per_device_train_batch_size 1 \
+    --output_dir /output \
     --kl_estimator kl3 \
-    --dataset_mixer_list saurabh5/rlvr_acecoder_filtered ${num_prompts} saurabh5/open-code-reasoning-rlvr-stdio ${num_prompts} \
+    --dataset_mixer_list ${data_mix} \
     --dataset_mixer_list_splits train \
-    --dataset_mixer_eval_list saurabh5/rlvr_acecoder_filtered 8 saurabh5/open-code-reasoning-rlvr-stdio 8 \
+    --dataset_mixer_eval_list hamishivi/omega-combined 8 allenai/IF_multi_constraints_upto5 8 saurabh5/rlvr_acecoder_filtered 8 hamishivi/tulu_3_rewritten_400k_string_f1_only_v2_nocode_all_filtered_qwen2_5_openthoughts2 4 hamishivi/virtuoussy_multi_subject_rlvr 4 \
     --dataset_mixer_eval_list_splits train \
     --max_prompt_token_length 2048 \
-    --response_length 4096 \
-    --pack_length 20480 \
-    --model_name_or_path "/weka/oe-adapt-default/finbarrt/stego32/step358000-hf" \
-    --tokenizer_name_or_path "allenai/OLMo-2-1124-7B" \
-    --chat_template_name tulu_thinker \
-    --inflight_updates True \
-    --stop_strings "</answer>" \
+    --response_length 32768 \
+    --pack_length 35840 \
+    --model_name_or_path ${model_path} \
+    --chat_template_name olmo_thinker \
     --non_stop_penalty False \
+    --mask_truncated_completions False \
     --temperature 1.0 \
-    --verbose False \
     --ground_truths_key ground_truth \
     --sft_messages_key messages \
-    --total_episodes 10240 \
-    --gather_whole_model False \
+    --total_episodes 10000000 \
     --deepspeed_stage 3 \
-    --num_learners_per_node 8 8 8 \
-    --vllm_num_engines 2 \
-    --vllm_tensor_parallel_size 4 \
+    --num_learners_per_node 8 8 8 8 8 8 8 8 8 8 8 8 \
+    --vllm_num_engines 6 \
+    --gather_whole_model False \
+    --vllm_tensor_parallel_size 8 \
     --lr_scheduler_type constant \
     --apply_verifiable_reward true \
-    --code_api_url \$CODE_API_URL/test_program \
     --seed 1 \
-    --local_eval_every 1 \
-    --add_bos \
-    --gradient_checkpointing \
+    --local_eval_every 50 \
+    --save_freq 25 \
+    --eval_priority urgent \
     --try_launch_beaker_eval_jobs_on_weka True \
+    --gradient_checkpointing \
     --with_tracking \
-    --update_progress_every 1 \
-    --vllm_enable_prefix_caching \
+    --llm_judge_model hosted_vllm/Qwen/Qwen3-32B \
+    --llm_judge_timeout 600 \
+    --llm_judge_max_tokens 2048 \
+    --llm_judge_max_context_length 32768 \
+    --clip_higher 0.272 \
+    --allow_world_padding False \
+    --use_fp8_kv_cache False \
+    --code_api_url https://p9f1719l7f.execute-api.us-west-2.amazonaws.com/prod/test_program \
+    --code_pass_rate_reward_threshold 0.99 \
     --oe_eval_max_length 32768 \
-    --oe_eval_tasks "codex_humanevalplus:0-shot-chat-v1::tulu-thinker,mbppplus:0-shot-chat::tulu-thinker,livecodebench_codegeneration::tulu-thinker" \
-    --dataset_skip_cache True \
-    --push_to_hub False
+    --checkpoint_state_freq 100 \
+    --backend_timeout 1200 \
+    --inflight_updates true \
+    --async_steps 8 \
+    --active_sampling \
+    --advantage_normalization_type centered \
+    --truncated_importance_sampling_ratio_cap 2.0 \
+    --oe_eval_beaker_image oe-eval-beaker/oe_eval_olmo2_retrofit_auto \
+    --oe_eval_tasks mmlu:cot::hamish_zs_reasoning_deepseek,bbh:cot::hamish_zs_reasoning_deepseek_v2,gpqa:0shot_cot::qwen3-instruct,zebralogic::hamish_zs_reasoning_deepseek,agi_eval_english:0shot_cot::hamish_zs_reasoning_deepseek,omega_500:0-shot-chat_deepseek,aime:zs_cot_r1::pass_at_32_2024_deepseek,aime:zs_cot_r1::pass_at_32_2025_deepseek,codex_humanevalplus:0-shot-chat::tulu-thinker_deepseek,mbppplus:0-shot-chat::tulu-thinker_deepseek,livecodebench_codegeneration::tulu-thinker_deepseek,alpaca_eval_v3::hamish_zs_reasoning_deepseek,ifeval::hamish_zs_reasoning_deepseek \
+    --vllm_enforce_eager \
+    --deepspeed_zpg 32
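
Assuming the usual reading of these flags (learner actors and vLLM engines share the single 18-node, 8-GPU-per-node allocation requested through mason.py), the GPU budget of the new script adds up exactly; a quick sanity check in Python:

# Quick arithmetic check of the GPU layout implied by the flags above (illustrative).
num_nodes, gpus_per_node = 18, 8      # --num_nodes 18, --gpus 8
learner_gpus = sum([8] * 12)          # --num_learners_per_node 8 (x12) -> 96 learner GPUs
vllm_gpus = 6 * 8                     # --vllm_num_engines 6 * --vllm_tensor_parallel_size 8 -> 48 GPUs
assert learner_gpus + vllm_gpus == num_nodes * gpus_per_node   # 96 + 48 == 144

Note also that --deepspeed_zpg 32 is larger than the per-node default of 8 described in the new Args docstring, which that docstring characterizes as trading speed for lower memory use on this 32B run.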
