6666import torch .utils .data
6767import vllm
6868import wandb
69+ from datasets import Dataset
6970from huggingface_hub import HfApi
7071from peft import PeftModel , get_peft_model_state_dict
7172from ray .util import queue as ray_queue
@@ -539,20 +540,6 @@ def __post_init__(self):
539540 )
540541
541542
542- def next_batch (dataset_indices : list [int ], dataset : datasets .Dataset ) -> Batch :
543- """Extract next batch of data based on indices."""
544- data_next = dataset [dataset_indices ]
545- return Batch (
546- queries = data_next [INPUT_IDS_PROMPT_KEY ],
547- ground_truths = data_next [GROUND_TRUTHS_KEY ],
548- datasets = data_next [VERIFIER_SOURCE_KEY ],
549- raw_queries = data_next [RAW_PROMPT_KEY ],
550- indices = dataset_indices ,
551- decoded_responses = None ,
552- scores = None ,
553- )
554-
555-
556543def masked_mean (
557544 values : torch .Tensor , mask : torch .Tensor , axis : int | None = None , denominator : float | None = None
558545) -> torch .Tensor :
@@ -673,7 +660,8 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None):
673660 def __iter__ (self ) -> Iterator [list [int ]]:
674661 return self
675662
676- def __next__ (self ) -> list [int ]:
663+ def __next__ (self ) -> list [int ] | int :
664+ """Return the next batch of indices as a list, or a single index when batch_size == 1."""
677665 if self .index >= self .effective_size :
678666 self .index = 0
679667 self ._update_effective_size ()
@@ -682,6 +670,8 @@ def __next__(self) -> list[int]:
682670
683671 end_index = self .index + self .batch_size
684672 batch = self .data [self .index : end_index ].tolist ()
673+ if self .batch_size == 1 :
674+ batch = batch [0 ]
685675 self .index = end_index
686676
687677 return batch
@@ -1672,8 +1662,12 @@ def accumulate_inference_batches(
16721662 timeout : float | None = None ,
16731663 active_sampling : bool = False ,
16741664 filter_zero_std_samples : bool = False ,
1665+ replenish_prompts : bool = False ,
16751666 no_resampling_pass_rate : float | None = None ,
16761667 iter_dataloader : ShufflingIterator | None = None ,
1668+ prompt_dataset : Dataset = None ,
1669+ param_prompt_Q : ray_queue .Queue | None = None ,
1670+ training_step : int = None ,
16771671) -> tuple [GenerationResult , Batch , dict , BatchStatistics ]:
16781672 """Accumulate multiple inference results into a single training batch.
16791673
@@ -1686,9 +1680,11 @@ def accumulate_inference_batches(
16861680 timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely.
16871681 active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std
16881682 filter_zero_std_samples: Whether to filter samples with zero reward std
1683+ replenish_prompts: Whether to enqueue a fresh prompt onto param_prompt_Q after each finished result is received
16891684 no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate
16901685 and exclude them from further sampling
16911686 iter_dataloader: Optional, used for no_resampling_pass_rate
1687+ param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts
1687+ prompt_dataset: Dataset from which replacement prompts are drawn when replenish_prompts is True
1687+ training_step: Current training step, attached to replenished PromptRequests
16921688
16931689 Raises:
16941690 queue.Empty: If timeout is specified and no data is available within timeout.
@@ -1697,6 +1693,14 @@ def accumulate_inference_batches(
16971693 Tuple of (combined_result, combined_batch, metrics dict, batch_statistics),
16981694 or (ShutdownSentinel, None, None, None) if a shutdown signal is received.
16991695 """
1696+ if no_resampling_pass_rate is not None :
1697+ assert iter_dataloader is not None , "no_resampling requires the iter_dataloader passed"
1698+
1699+ if replenish_prompts :
1700+ assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None , (
1701+ "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset"
1702+ )
1703+
17001704 results = []
17011705 all_queries = []
17021706 all_ground_truths = []
@@ -1731,7 +1735,21 @@ def accumulate_inference_batches(
17311735 f"Dataset index: { result .dataset_index } , Epoch: { result .epoch_number } "
17321736 )
17331737
1734- query , ground_truth , dataset , raw_query = pending_queries_map .pop (result .dataset_index )
1738+ query , ground_truth , dataset_name , raw_query = pending_queries_map .pop (result .dataset_index )
1739+
1740+ # Replenish generation queue with new prompt
1741+ if replenish_prompts :
1742+ dataset_index = next (iter_dataloader )
1743+ add_prompt_to_generator (
1744+ prompt_dataset [dataset_index ],
1745+ dataset_index ,
1746+ iter_dataloader .epoch_number ,
1747+ training_step ,
1748+ pending_queries_map ,
1749+ param_prompt_Q ,
1750+ generation_config ,
1751+ is_eval = False ,
1752+ )
17351753
17361754 # TODO(finbarrtimbers): Move this to LLMRayActor.
17371755 for i in range (len (result .finish_reasons )):
@@ -1745,7 +1763,7 @@ def accumulate_inference_batches(
17451763 # TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method.
17461764 k_queries = repeat_each ([query ], generation_config .n )
17471765 k_ground_truths = repeat_each ([ground_truth ], generation_config .n )
1748- k_datasets = repeat_each ([dataset ], generation_config .n )
1766+ k_datasets = repeat_each ([dataset_name ], generation_config .n )
17491767 k_raw_queries = repeat_each ([raw_query ], generation_config .n )
17501768
17511769 scores , reward_metrics = asyncio .run (
@@ -1916,6 +1934,7 @@ def accumulate_inference_batches(
19161934def data_preparation_thread (
19171935 reward_fn : Callable ,
19181936 inference_results_Q : ray_queue .Queue , # Ray queue
1937+ param_prompt_Q : ray_queue .Queue ,
19191938 packed_sequences_Q : Queue ,
19201939 pending_queries_map : dict ,
19211940 args : Args ,
@@ -1924,6 +1943,7 @@ def data_preparation_thread(
19241943 generation_config ,
19251944 resume_training_step : int ,
19261945 iter_dataloader : ShufflingIterator ,
1946+ train_dataset : Dataset ,
19271947 actor_manager = None ,
19281948 model_dims : utils .ModelDims = None ,
19291949):
@@ -1942,8 +1962,12 @@ def data_preparation_thread(
19421962 actor_manager = actor_manager ,
19431963 active_sampling = args .active_sampling ,
19441964 filter_zero_std_samples = args .filter_zero_std_samples ,
1965+ replenish_prompts = True ,
19451966 no_resampling_pass_rate = args .no_resampling_pass_rate ,
19461967 iter_dataloader = iter_dataloader ,
1968+ prompt_dataset = train_dataset ,
1969+ param_prompt_Q = param_prompt_Q ,
1970+ training_step = training_step ,
19471971 )
19481972 if isinstance (result , ShutdownSentinel ):
19491973 logger .info ("[Data Preparation Thread] Received shutdown sentinel, exiting" )
@@ -2366,8 +2390,9 @@ def create_generation_configs(args: Args):
23662390 return {"train" : generation_config , "eval" : eval_generation_config }
23672391
23682392
2369- def split_and_insert_batch (
2370- batch : Batch ,
2393+ def add_prompt_to_generator (
2394+ example : dict [str , Any ],
2395+ example_index : int ,
23712396 epoch_number : int ,
23722397 training_step : int ,
23732398 pending_queries_map : PendingQueriesMap ,
@@ -2376,20 +2401,22 @@ def split_and_insert_batch(
23762401 is_eval : bool ,
23772402) -> None :
23782403 """Insert a single prompt into the pending-queries map and enqueue it to the generator queue."""
2379- for idx , query , ground_truth , dataset , raw_query in zip (
2380- batch .indices , batch .queries , batch .ground_truths , batch .datasets , batch .raw_queries
2381- ):
2382- pending_queries_map .insert (idx , query , ground_truth , dataset , raw_query )
2383- param_prompt_Q .put (
2384- PromptRequest (
2385- prompt = query ,
2386- generation_config = generation_config ,
2387- epoch_number = epoch_number ,
2388- training_step = training_step ,
2389- dataset_index = idx ,
2390- is_eval = is_eval ,
2391- )
2404+ query = example [INPUT_IDS_PROMPT_KEY ]
2405+ ground_truth = example [GROUND_TRUTHS_KEY ]
2406+ dataset_name = example [VERIFIER_SOURCE_KEY ]
2407+ raw_query = example [RAW_PROMPT_KEY ]
2408+ pending_queries_map .insert (example_index , query , ground_truth , dataset_name , raw_query )
2409+
2410+ param_prompt_Q .put (
2411+ PromptRequest (
2412+ prompt = query ,
2413+ generation_config = generation_config ,
2414+ epoch_number = epoch_number ,
2415+ training_step = training_step ,
2416+ dataset_index = example_index ,
2417+ is_eval = is_eval ,
23922418 )
2419+ )
23932420
23942421
23952422def load_data_from_packing_thread (
@@ -2641,8 +2668,7 @@ def maybe_evaluate(
26412668 timeout = timeout ,
26422669 active_sampling = False ,
26432670 filter_zero_std_samples = False ,
2644- no_resampling_pass_rate = None ,
2645- iter_dataloader = None ,
2671+ replenish_prompts = False ,
26462672 )
26472673
26482674 logger .info ("[Main Thread] 📊 Evaluation responses received" )
@@ -2896,7 +2922,7 @@ def run_training(
28962922 args ,
28972923 tokenizer ,
28982924 train_dataset ,
2899- eval_batch ,
2925+ eval_dataset ,
29002926 policy_group ,
29012927 vllm_engines ,
29022928 generation_configs ,
@@ -2946,6 +2972,7 @@ def run_training(
29462972 data_preparation_thread ,
29472973 reward_fn ,
29482974 inference_results_Q ,
2975+ param_prompt_Q ,
29492976 packed_sequences_Q ,
29502977 pending_queries_map ,
29512978 args ,
@@ -2954,6 +2981,7 @@ def run_training(
29542981 generation_configs ["train" ],
29552982 resume_training_step ,
29562983 iter_dataloader ,
2984+ train_dataset ,
29572985 actor_manager ,
29582986 model_dims ,
29592987 )
@@ -2967,11 +2995,11 @@ def health_check_fn():
29672995 )
29682996
29692997 # Send initial data to ensure we have a N-step offset.
2970- for _ in range (args .async_steps ):
2971- dataset_indices = next (iter_dataloader )
2972- batch = next_batch ( dataset_indices , train_dataset )
2973- split_and_insert_batch (
2974- batch ,
2998+ for _ in range (args .async_steps * args . num_unique_prompts_rollout ):
2999+ dataset_index = next (iter_dataloader )
3000+ add_prompt_to_generator (
3001+ train_dataset [ dataset_index ],
3002+ dataset_index ,
29753003 iter_dataloader .epoch_number ,
29763004 resume_training_step ,
29773005 pending_queries_map ,
@@ -2985,7 +3013,6 @@ def health_check_fn():
29853013 else :
29863014 num_total_tokens = 0
29873015
2988- num_prompts_to_refill = 0
29893016 training_start_time = time .perf_counter () # Track overall training start time
29903017 for training_step in range (resume_training_step , args .num_training_steps + 1 ):
29913018 start_time = time .perf_counter ()
@@ -3017,35 +3044,22 @@ def health_check_fn():
30173044 num_filtered_prompts ,
30183045 ) = load_data_from_packing_thread (packed_sequences_Q , num_total_tokens , stop_event , health_check_fn )
30193046
3020- num_prompts_to_refill += args .num_unique_prompts_rollout + num_filtered_prompts
3021-
3022- while num_prompts_to_refill >= args .num_unique_prompts_rollout :
3023- batch = next_batch (next (iter_dataloader ), train_dataset )
3024- split_and_insert_batch (
3025- batch ,
3026- iter_dataloader .epoch_number ,
3027- training_step ,
3028- pending_queries_map ,
3029- param_prompt_Q ,
3030- generation_configs ["train" ],
3031- is_eval = False ,
3032- )
3033- num_prompts_to_refill -= args .num_unique_prompts_rollout
3034-
30353047 if (
30363048 training_step % args .local_eval_every == 0
3037- and eval_batch is not None
3049+ and eval_dataset is not None
30383050 and (args .eval_on_step_0 or training_step > 1 )
30393051 ):
3040- split_and_insert_batch (
3041- eval_batch ,
3042- iter_dataloader .epoch_number ,
3043- training_step ,
3044- eval_pending_queries_map ,
3045- param_prompt_Q ,
3046- generation_configs ["eval" ],
3047- is_eval = True ,
3048- )
3052+ for eval_index , eval_example in enumerate (eval_dataset ):
3053+ add_prompt_to_generator (
3054+ eval_example ,
3055+ eval_index ,
3056+ iter_dataloader .epoch_number ,
3057+ training_step ,
3058+ eval_pending_queries_map ,
3059+ param_prompt_Q ,
3060+ generation_configs ["eval" ],
3061+ is_eval = True ,
3062+ )
30493063 if collated_data is None :
30503064 continue
30513065
@@ -3122,7 +3136,7 @@ def health_check_fn():
31223136 eval_pending_queries_map ,
31233137 generation_configs ["eval" ],
31243138 generate_metrics_Q ,
3125- len (eval_batch . queries ) if eval_batch else 0 ,
3139+ len (eval_dataset ) if eval_dataset else 0 ,
31263140 model_dims ,
31273141 actor_manager ,
31283142 )
@@ -3199,7 +3213,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
31993213 logger .info (f"Restored episode count: { episode } " )
32003214
32013215 train_dataset_idxs = np .arange (len (train_dataset ))
3202- iter_dataloader = ShufflingIterator (train_dataset_idxs , args . num_unique_prompts_rollout , seed = args .seed )
3216+ iter_dataloader = ShufflingIterator (train_dataset_idxs , 1 , seed = args .seed )
32033217
32043218 if checkpoint_state and "shuffling_iterator_state" in checkpoint_state :
32053219 iter_dataloader .set_state (checkpoint_state ["shuffling_iterator_state" ])
@@ -3212,11 +3226,6 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
32123226 generate_metrics_Q = Queue (maxsize = args .async_steps )
32133227 weight_sync_metrics_Q = Queue (maxsize = args .async_steps )
32143228
3215- if eval_dataset is None :
3216- eval_batch = None
3217- else :
3218- eval_dataset_indices = list (range (len (eval_dataset )))
3219- eval_batch = next_batch (eval_dataset_indices , eval_dataset )
32203229 reward_fn = make_reward_fn (args )
32213230
32223231 stop_event = threading .Event ()
@@ -3227,7 +3236,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
32273236 args ,
32283237 tokenizer ,
32293238 train_dataset ,
3230- eval_batch ,
3239+ eval_dataset ,
32313240 policy_group ,
32323241 vllm_engines ,
32333242 generation_configs ,
0 commit comments