@@ -270,6 +270,8 @@ class Args:
270270
271271 active_sampling : bool = False
272272 """Whether to continue sampling responses until you get a full batch."""
273+ no_resampling_pass_rate : float | None = None
274+ """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again"""
273275
274276 record_entropy : bool = False
275277 """whether to record the entropy of the policy during training. Uses extra memory."""
@@ -349,6 +351,12 @@ class Args:
349351 """vLLM top p for nucleus sampling"""
350352 deepspeed_stage : int = 0
351353 """the deepspeed stage"""
354+ deepspeed_zpg : int = 8
355+ """the deepspeed zpg value. Higher values are more memory efficient but slower. Set to 1 to disable zpg, which uses less memory but is significantly slower. Ideally is set to the number of GPUs per node (usually 8, default)."""
356+ deepspeed_offload_param : bool = False
357+ """whether to offload parameters to CPU (reduces GPU memory usage)"""
358+ deepspeed_offload_optimizer : bool = False
359+ """whether to offload optimizer states to CPU (reduces GPU memory usage)"""
352360 gather_whole_model : bool = True
353361 """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)"""
354362 enable_queue_dashboard : bool = True
@@ -670,16 +678,17 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None):
670678 self .epoch_number = 0
671679 self .rng = np .random .default_rng (seed )
672680 self .rng .shuffle (self .data )
681+ self .exclude_list = []
673682
674- # Ensure the effective dataset size is divisible by batch_size
675- self .effective_size = len (self .data ) - (len (self .data ) % batch_size )
683+ self ._update_effective_size ()
676684
677685 def __iter__ (self ) -> Iterator [list [int ]]:
678686 return self
679687
680688 def __next__ (self ) -> list [int ]:
681689 if self .index >= self .effective_size :
682690 self .index = 0
691+ self ._update_effective_size ()
683692 self .epoch_number += 1
684693 self .rng .shuffle (self .data )
685694
@@ -696,6 +705,7 @@ def get_state(self) -> dict[str, Any]:
696705 "epoch_number" : self .epoch_number ,
697706 "data" : self .data .copy (),
698707 "rng_state" : self .rng .bit_generator .state ,
708+ "exclude_list" : self .exclude_list .copy (),
699709 }
700710
701711 def set_state (self , state : dict [str , Any ]) -> None :
@@ -704,6 +714,21 @@ def set_state(self, state: dict[str, Any]) -> None:
704714 self .epoch_number = state .get ("epoch_number" , 0 )
705715 self .data = state ["data" ].copy ()
706716 self .rng .bit_generator .state = state ["rng_state" ]
717+ self .exclude_list = state .get ("exclude_list" , [])
718+ self ._update_effective_size ()
719+
720+ def exclude_index (self , index : int ) -> None :
721+ """Exclude provided data points from future sampling."""
722+ self .exclude_list .append (index )
723+
724+ def _update_effective_size (self ) -> None :
725+ """Ensure the effective dataset size is divisible by batch_size and filter out all the indices excluded in the last epoch"""
726+ if self .exclude_list :
727+ mask = ~ np .isin (self .data , self .exclude_list )
728+ self .data = self .data [mask ]
729+ self .exclude_list = []
730+
731+ self .effective_size = len (self .data ) - (len (self .data ) % self .batch_size )
707732
708733
709734@ray .remote (num_gpus = 1 )
@@ -749,7 +774,13 @@ def load(self, path: str, map_location=None):
749774
750775 deepspeed .init_distributed (timeout = timedelta (minutes = args .backend_timeout ))
751776
752- ds_config = get_train_ds_config (offload = False , adam_offload = False , stage = args .deepspeed_stage , bf16 = True )
777+ ds_config = get_train_ds_config (
778+ offload = args .deepspeed_offload_param ,
779+ adam_offload = args .deepspeed_offload_optimizer ,
780+ stage = args .deepspeed_stage ,
781+ bf16 = True ,
782+ zpg = args .deepspeed_zpg ,
783+ )
753784 ds_config ["train_micro_batch_size_per_gpu" ] = args .per_device_train_batch_size
754785 ds_config ["gradient_accumulation_steps" ] = 1
755786 # @vwxyzjn: MAGIC: it's actually needed to initialize this `dschf`, so
@@ -844,7 +875,7 @@ def load(self, path: str, map_location=None):
844875
845876 # reference model
846877 ds_config = get_eval_ds_config (
847- offload = False ,
878+ offload = args . deepspeed_offload_param ,
848879 # inference model only has stage 3 (sharding) or stage 0 (no sharding)
849880 # stage 2 is optimizer sharding which doesn't apply to inference
850881 stage = args .deepspeed_stage if args .deepspeed_stage == 3 else 0 ,
@@ -948,7 +979,7 @@ def setup_model_update_group(self, vllm_engines):
948979 group_name = "openrlhf" ,
949980 timeout = timedelta (minutes = self .args .backend_timeout ),
950981 )
951- ray_get_with_progress (refs , desc = "Initializing vLLM process groups" , timeout = 60 )
982+ ray_get_with_progress (refs , desc = "Initializing vLLM process groups" , timeout = 600 )
952983 torch .distributed .barrier ()
953984
954985 def broadcast_to_vllm (self ):
@@ -1259,6 +1290,8 @@ def train(
12591290 args .masked_mean_denominator ,
12601291 )
12611292 loss = loss / accumulation_steps
1293+ # Clear CUDA cache before backward pass to free memory for reduce_scatter operations
1294+ torch .cuda .empty_cache ()
12621295 self .model .backward (loss )
12631296 if (local_step + 1 ) % accumulation_steps == 0 :
12641297 self .model .step ()
@@ -1633,6 +1666,8 @@ class BatchStatistics:
16331666 filtered_prompts_zero : int
16341667 filtered_prompts_solved : int
16351668 filtered_prompts_nonzero : int
1669+ percent_solved_mean : float
1670+ no_resampled_prompts : int
16361671
16371672
16381673def accumulate_inference_batches (
@@ -1647,6 +1682,8 @@ def accumulate_inference_batches(
16471682 actor_manager = None ,
16481683 timeout : float | None = None ,
16491684 filter_zero_std_samples : bool = False ,
1685+ no_resampling_pass_rate : float | None = None ,
1686+ iter_dataloader : ShufflingIterator | None = None ,
16501687) -> tuple [GenerationResult , Batch , dict , BatchStatistics ]:
16511688 """Accumulate multiple inference results into a single training batch.
16521689
@@ -1657,6 +1694,10 @@ def accumulate_inference_batches(
16571694 generation_config: Generation config containing n (number of samples per prompt)
16581695 num_prompts: Number of prompts to accumulate
16591696 timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely.
1697+ filter_zero_std_samples: Whether to filter samples with zero reward std and continue sampling
1698+ no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate
1699+ and exclude them from further sampling
1700+ iter_dataloader: Optional, used for no_resampling_pass_rate
16601701
16611702 Raises:
16621703 queue.Empty: If timeout is specified and no data is available within timeout.
@@ -1673,10 +1714,12 @@ def accumulate_inference_batches(
16731714 all_decoded_responses = []
16741715 all_reward_metrics = []
16751716 all_scores = []
1717+ all_percent_solved = []
16761718 total_filtered_prompts = 0
16771719 filtered_prompt_zero = 0
16781720 filtered_prompt_solved = 0
16791721 filtered_prompt_nonzero = 0
1722+ total_no_resampled = 0
16801723 progress_bar = tqdm (
16811724 total = num_prompts ,
16821725 desc = f"Accumulating Responses and Rewarding { num_prompts } prompts" ,
@@ -1725,6 +1768,15 @@ def accumulate_inference_batches(
17251768 )
17261769 )
17271770
1771+ percent_solved = np .mean (scores ).item () / args .max_possible_score
1772+ # Don't resample prompt that was solved at more than no_resample_positive_rate
1773+ if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate :
1774+ iter_dataloader .exclude_index (result .dataset_index )
1775+ total_no_resampled += 1
1776+ logging .debug (
1777+ f"[Data Preparation Thread] Prompt solved at { percent_solved } , will be excluded from resampling, total no resampled: { total_no_resampled } "
1778+ )
1779+
17281780 # Filter out zero std prompts
17291781 if filter_zero_std_samples and np .std (scores ) == 0 :
17301782 total_filtered_prompts += 1
@@ -1747,6 +1799,7 @@ def accumulate_inference_batches(
17471799 all_decoded_responses .extend (decoded_responses )
17481800 all_scores .extend (scores )
17491801 all_reward_metrics .append (reward_metrics )
1802+ all_percent_solved .append (percent_solved )
17501803 progress_bar .update (1 )
17511804
17521805 # Combine all results into a single GenerationResult
@@ -1841,6 +1894,7 @@ def accumulate_inference_batches(
18411894 )
18421895
18431896 combined_reward_metrics = combine_reward_metrics (all_reward_metrics )
1897+ percent_solved_mean = np .mean (all_percent_solved ) if all_percent_solved else 0.0
18441898
18451899 batch_stats = BatchStatistics (
18461900 prompt_lengths = prompt_lengths ,
@@ -1849,6 +1903,8 @@ def accumulate_inference_batches(
18491903 filtered_prompts_zero = filtered_prompt_zero ,
18501904 filtered_prompts_solved = filtered_prompt_solved ,
18511905 filtered_prompts_nonzero = filtered_prompt_nonzero ,
1906+ percent_solved_mean = percent_solved_mean ,
1907+ no_resampled_prompts = total_no_resampled ,
18521908 )
18531909 logging .info (
18541910 f"[Data Preparation Thread] Calculating rewards took { combined_reward_metrics ['time/reward' ]} seconds"
@@ -1867,6 +1923,7 @@ def data_preparation_thread(
18671923 num_training_steps : int ,
18681924 generation_config ,
18691925 resume_training_step : int ,
1926+ iter_dataloader : ShufflingIterator ,
18701927 actor_manager = None ,
18711928 model_dims : utils .ModelDims = None ,
18721929):
@@ -1884,6 +1941,8 @@ def data_preparation_thread(
18841941 reward_fn = reward_fn ,
18851942 actor_manager = actor_manager ,
18861943 filter_zero_std_samples = args .active_sampling ,
1944+ no_resampling_pass_rate = args .no_resampling_pass_rate ,
1945+ iter_dataloader = iter_dataloader ,
18871946 )
18881947 if isinstance (result , ShutdownSentinel ):
18891948 logger .info ("[Data Preparation Thread] Received shutdown sentinel, exiting" )
@@ -2580,6 +2639,8 @@ def maybe_evaluate(
25802639 actor_manager = actor_manager ,
25812640 timeout = timeout ,
25822641 filter_zero_std_samples = False ,
2642+ no_resampling_pass_rate = None ,
2643+ iter_dataloader = None ,
25832644 )
25842645
25852646 logger .info ("[Main Thread] 📊 Evaluation responses received" )
@@ -2890,6 +2951,7 @@ def run_training(
28902951 args .num_training_steps ,
28912952 generation_configs ["train" ],
28922953 resume_training_step ,
2954+ iter_dataloader ,
28932955 actor_manager ,
28942956 model_dims ,
28952957 )
0 commit comments