From ca8370d88e1d99a23edeaa819e967c2d81266343 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 09:03:35 -0700 Subject: [PATCH 01/96] now we have an impl --- open_instruct/grpo_fast.py | 116 +++++++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 42 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 13ee119ee..af1983c6d 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1013,18 +1013,52 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def train( + def setup_dataloader( self, - collated_query_responses, - collated_tool_masks, - collated_attention_masks, - collated_position_ids, - collated_advantages, - collated_response_masks, - collated_vllm_logprobs, - pad_token_id: int, - num_mini_batches: int, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + pending_queries_map: dict, + generation_config, + resume_training_step: int, + num_training_steps: int, + worker_data_queues: list, + main_thread_metrics_queue: ray_queue.Queue, + actor_manager=None, + model_dims: utils.ModelDims = None, ): + from open_instruct.streaming_dataloader import StreamingDataLoader + + self.dataloader = StreamingDataLoader( + dataset=dataset, + rank=self.local_rank, + world_size=self.args.world_size, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + pending_queries_map=pending_queries_map, + tokenizer=self.tokenizer, + args=self.args, + generation_config=generation_config, + resume_training_step=resume_training_step, + num_training_steps=num_training_steps, + actor_manager=actor_manager, + model_dims=model_dims, + worker_data_queues=worker_data_queues, + main_thread_metrics_queue=main_thread_metrics_queue, + ) + + def train(self, pad_token_id: int, num_mini_batches: int): + batch_data = next(self.dataloader) + collated_query_responses = batch_data["collated_query_responses"] + collated_tool_masks = batch_data["collated_tool_masks"] + collated_attention_masks = batch_data["collated_attention_masks"] + collated_position_ids = batch_data["collated_position_ids"] + collated_advantages = batch_data["collated_advantages"] + collated_response_masks = batch_data["collated_response_masks"] + collated_vllm_logprobs = batch_data["collated_vllm_logprobs"] + args = self.args to_device_inplace(collated_query_responses, self.device) to_device_inplace(collated_tool_masks, self.device) @@ -2522,7 +2556,6 @@ def weight_sync_thread( def one_training_step( args: Args, policy_group: ModelGroup, - collated_data: list[dict[str, list[torch.Tensor]]], tokenizer: PreTrainedTokenizer, data_thread_metrics: dict[str, Any], episode: int, @@ -2546,7 +2579,7 @@ def one_training_step( metrics_list, _ = ray_get_with_progress( [ policy_group.models[i].train.remote( - **collated_data[i], pad_token_id=tokenizer.pad_token_id, num_mini_batches=args.num_mini_batches + pad_token_id=tokenizer.pad_token_id, num_mini_batches=args.num_mini_batches ) for i in range(args.world_size) ], @@ -2967,27 +3000,32 @@ def run_training( [engine.ready.remote() for engine in vllm_engines], "Checking engines are ready to work", timeout=300 ) - logger.info("======== ✅ data preparation thread starts =========") - packing_future = executor.submit( - data_preparation_thread, - reward_fn, - inference_results_Q, - param_prompt_Q, - packed_sequences_Q, - pending_queries_map, - args, - tokenizer, - 
args.num_training_steps, - generation_configs["train"], - resume_training_step, - iter_dataloader, - train_dataset, - actor_manager, - model_dims, + logger.info("======== ✅ Setting up dataloaders =========") + worker_data_queues = [ray_queue.Queue(maxsize=args.async_steps) for _ in range(args.world_size)] + main_thread_metrics_queue = ray_queue.Queue(maxsize=args.async_steps) + ray_get_with_progress( + [ + policy_group.models[i].setup_dataloader.remote( + dataset=train_dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + pending_queries_map=pending_queries_map, + generation_config=generation_configs["train"], + resume_training_step=resume_training_step, + num_training_steps=args.num_training_steps, + worker_data_queues=worker_data_queues, + main_thread_metrics_queue=main_thread_metrics_queue, + actor_manager=actor_manager, + model_dims=model_dims, + ) + for i in range(args.world_size) + ], + desc="Setting up dataloaders for all workers", ) def health_check_fn(): - [f.result() for f in [packing_future, weight_sync_thread_future] if f.done()] + [f.result() for f in [weight_sync_thread_future] if f.done()] ray_get_with_progress( [engine.check_background_threads.remote() for engine in vllm_engines], desc="Checking vLLM engine health", @@ -3034,15 +3072,11 @@ def health_check_fn(): health_check_fn() health_check_time = time.perf_counter() - health_check_start - ( - collated_data, - data_thread_metrics, - num_total_tokens, - num_step_tokens, - prompt_lengths, - response_lengths, - num_filtered_prompts, - ) = load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn) + batch_metadata = main_thread_metrics_queue.get() + num_step_tokens = batch_metadata["num_new_tokens"] + num_total_tokens += num_step_tokens + prompt_lengths = batch_metadata["prompt_lengths"] + response_lengths = batch_metadata["response_lengths"] if ( training_step % args.local_eval_every == 0 @@ -3060,11 +3094,10 @@ def health_check_fn(): generation_configs["eval"], is_eval=True, ) - if collated_data is None: - continue episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + data_thread_metrics = batch_metadata["metrics"] for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]: try: data_thread_metrics |= metrics_Q.get_nowait() @@ -3076,7 +3109,6 @@ def health_check_fn(): one_training_step( args, policy_group, - collated_data, tokenizer, data_thread_metrics, episode, From aba89488777014ba3325ab83dbf2d09ac5fca638 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 09:56:12 -0700 Subject: [PATCH 02/96] Cleaned up code --- open_instruct/grpo_fast.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index af1983c6d..3b3ca7329 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1023,8 +1023,8 @@ def setup_dataloader( generation_config, resume_training_step: int, num_training_steps: int, - worker_data_queues: list, - main_thread_metrics_queue: ray_queue.Queue, + work_dir: str, + global_batch_size: int, actor_manager=None, model_dims: utils.ModelDims = None, ): @@ -1032,8 +1032,6 @@ def setup_dataloader( self.dataloader = StreamingDataLoader( dataset=dataset, - rank=self.local_rank, - world_size=self.args.world_size, reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, @@ -1041,12 +1039,15 @@ def setup_dataloader( 
tokenizer=self.tokenizer, args=self.args, generation_config=generation_config, + work_dir=work_dir, + global_batch_size=global_batch_size, resume_training_step=resume_training_step, num_training_steps=num_training_steps, actor_manager=actor_manager, model_dims=model_dims, - worker_data_queues=worker_data_queues, - main_thread_metrics_queue=main_thread_metrics_queue, + dp_world_size=self.args.world_size, + dp_rank=self.local_rank, + fs_local_rank=self.local_rank, ) def train(self, pad_token_id: int, num_mini_batches: int): @@ -3001,8 +3002,6 @@ def run_training( ) logger.info("======== ✅ Setting up dataloaders =========") - worker_data_queues = [ray_queue.Queue(maxsize=args.async_steps) for _ in range(args.world_size)] - main_thread_metrics_queue = ray_queue.Queue(maxsize=args.async_steps) ray_get_with_progress( [ policy_group.models[i].setup_dataloader.remote( @@ -3014,8 +3013,8 @@ def run_training( generation_config=generation_configs["train"], resume_training_step=resume_training_step, num_training_steps=args.num_training_steps, - worker_data_queues=worker_data_queues, - main_thread_metrics_queue=main_thread_metrics_queue, + work_dir=args.output_dir, + global_batch_size=args.num_unique_prompts_rollout, actor_manager=actor_manager, model_dims=model_dims, ) @@ -3072,12 +3071,6 @@ def health_check_fn(): health_check_fn() health_check_time = time.perf_counter() - health_check_start - batch_metadata = main_thread_metrics_queue.get() - num_step_tokens = batch_metadata["num_new_tokens"] - num_total_tokens += num_step_tokens - prompt_lengths = batch_metadata["prompt_lengths"] - response_lengths = batch_metadata["response_lengths"] - if ( training_step % args.local_eval_every == 0 and eval_dataset is not None @@ -3097,7 +3090,7 @@ def health_check_fn(): episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout - data_thread_metrics = batch_metadata["metrics"] + data_thread_metrics = {} for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]: try: data_thread_metrics |= metrics_Q.get_nowait() @@ -3106,6 +3099,10 @@ def health_check_fn(): data_thread_metrics["time/health_check"] = health_check_time + num_step_tokens = 0 + prompt_lengths = [] + response_lengths = [] + one_training_step( args, policy_group, From 180272e54ae3b9ac236b0225c9b571ed65782847 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 10:23:57 -0700 Subject: [PATCH 03/96] updated code --- open_instruct/grpo_fast.py | 45 ++++++++++++++++++-------------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3b3ca7329..3e7c8097e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -718,7 +718,7 @@ def from_pretrained( beaker_config: BeakerRuntimeConfig, wandb_url: str, tokenizer: PreTrainedTokenizer, - ): + ) -> int: # ------------------------------------------------------------ # Monkey patch to load checkpoints with `weights_only=False` # otherwise it errors out with: @@ -1013,41 +1013,30 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def setup_dataloader( + def build_dataloader( self, + data_loader_config, dataset: Dataset, reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, pending_queries_map: dict, generation_config, - resume_training_step: int, - num_training_steps: int, - work_dir: str, - global_batch_size: int, - actor_manager=None, - model_dims: utils.ModelDims = 
None, + actor_manager, + model_dims: utils.ModelDims, ): - from open_instruct.streaming_dataloader import StreamingDataLoader - - self.dataloader = StreamingDataLoader( + self.dataloader = data_loader_config.build( dataset=dataset, reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, pending_queries_map=pending_queries_map, tokenizer=self.tokenizer, - args=self.args, generation_config=generation_config, - work_dir=work_dir, - global_batch_size=global_batch_size, - resume_training_step=resume_training_step, - num_training_steps=num_training_steps, - actor_manager=actor_manager, - model_dims=model_dims, - dp_world_size=self.args.world_size, dp_rank=self.local_rank, fs_local_rank=self.local_rank, + actor_manager=actor_manager, + model_dims=model_dims, ) def train(self, pad_token_id: int, num_mini_batches: int): @@ -3002,19 +2991,27 @@ def run_training( ) logger.info("======== ✅ Setting up dataloaders =========") + from open_instruct.streaming_data_loader import StreamingDataLoaderConfig + + data_loader_config = StreamingDataLoaderConfig( + work_dir=args.output_dir, + global_batch_size=args.num_unique_prompts_rollout, + dp_world_size=args.world_size, + resume_training_step=resume_training_step, + num_training_steps=args.num_training_steps, + args=args, + ) + ray_get_with_progress( [ - policy_group.models[i].setup_dataloader.remote( + policy_group.models[i].build_dataloader.remote( + data_loader_config=data_loader_config, dataset=train_dataset, reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, pending_queries_map=pending_queries_map, generation_config=generation_configs["train"], - resume_training_step=resume_training_step, - num_training_steps=args.num_training_steps, - work_dir=args.output_dir, - global_batch_size=args.num_unique_prompts_rollout, actor_manager=actor_manager, model_dims=model_dims, ) From 86964108e219e983522a31c1f59c9e794c3955fc Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 10:28:51 -0700 Subject: [PATCH 04/96] Added streaming data loader file. --- open_instruct/streaming_data_loader.py | 559 +++++++++++++++++++++++++ 1 file changed, 559 insertions(+) create mode 100644 open_instruct/streaming_data_loader.py diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py new file mode 100644 index 000000000..1bf016d72 --- /dev/null +++ b/open_instruct/streaming_data_loader.py @@ -0,0 +1,559 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import threading +from abc import abstractmethod +from collections.abc import Callable, Iterable +from dataclasses import asdict, dataclass +from pathlib import Path +from queue import Queue as StdQueue +from typing import Any + +import numpy as np +import torch +from datasets import Dataset +from ray.util import queue as ray_queue +from transformers import PreTrainedTokenizer + +from open_instruct import utils +from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences + +logger = logging.getLogger(__name__) + +PathOrStr = Path | str + + +@dataclass +class StreamingDataLoaderConfig: + work_dir: PathOrStr + global_batch_size: int + dp_world_size: int + resume_training_step: int + num_training_steps: int + args: Any + + def build( + self, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + pending_queries_map: dict, + tokenizer: PreTrainedTokenizer, + generation_config: Any, + dp_rank: int, + fs_local_rank: int, + actor_manager=None, + model_dims: utils.ModelDims | None = None, + ) -> "StreamingDataLoader": + return StreamingDataLoader( + dataset=dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + pending_queries_map=pending_queries_map, + tokenizer=tokenizer, + args=self.args, + generation_config=generation_config, + work_dir=self.work_dir, + global_batch_size=self.global_batch_size, + resume_training_step=self.resume_training_step, + num_training_steps=self.num_training_steps, + actor_manager=actor_manager, + model_dims=model_dims, + dp_world_size=self.dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + +class DataLoaderBase: + def __init__( + self, + *, + work_dir: PathOrStr, + global_batch_size: int, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + self.work_dir = Path(work_dir) + self._global_batch_size = global_batch_size + self.dp_world_size = dp_world_size + self.dp_rank = dp_rank + self.fs_local_rank = fs_local_rank + self.batches_processed = 0 + self.epoch: int | None = None + + @property + def global_batch_size(self) -> int: + return self._global_batch_size + + @global_batch_size.setter + def global_batch_size(self, value: int): + self._global_batch_size = value + + @property + def rank_batch_size(self) -> int: + return self.global_batch_size // self.dp_world_size + + @property + @abstractmethod + def total_batches(self) -> int | None: + pass + + @abstractmethod + def state_dict(self) -> dict[str, Any]: + pass + + @abstractmethod + def load_state_dict(self, state_dict: dict[str, Any]): + pass + + @abstractmethod + def reshuffle(self, epoch: int | None = None, **kwargs): + pass + + @abstractmethod + def _iter_batches(self) -> Iterable[dict[str, Any]]: + pass + + @abstractmethod + def get_mock_batch(self) -> dict[str, Any]: + pass + + def __iter__(self): + return self._iter_batches() + + def __next__(self): + if not hasattr(self, "_iterator"): + self._iterator = self._iter_batches() + return next(self._iterator) + + def reset(self): + if hasattr(self, "_iterator"): + del self._iterator + self.batches_processed = 0 + + +class TextDataLoaderBase(DataLoaderBase): + def __init__( + self, + *, + work_dir: PathOrStr, + global_batch_size: int, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + super().__init__( + work_dir=work_dir, + global_batch_size=global_batch_size, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) 
+ self.tokens_processed: int = 0 + + def reset(self): + super().reset() + self.tokens_processed = 0 + + def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int | None: + del batch + return self.global_batch_size + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.epoch_number = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + self.exclude_list = [] + + self._update_effective_size() + + def __iter__(self): + return self + + def __next__(self) -> list[int] | int: + if self.index >= self.effective_size: + self.index = 0 + self._update_effective_size() + self.epoch_number += 1 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + if self.batch_size == 1: + batch = batch[0] + self.index = end_index + + return batch + + def get_state(self) -> dict[str, Any]: + return { + "index": self.index, + "epoch_number": self.epoch_number, + "data": self.data.copy(), + "rng_state": self.rng.bit_generator.state, + "exclude_list": self.exclude_list.copy(), + } + + def set_state(self, state: dict[str, Any]) -> None: + self.index = state["index"] + self.epoch_number = state.get("epoch_number", 0) + self.data = state["data"].copy() + self.rng.bit_generator.state = state["rng_state"] + self.exclude_list = state.get("exclude_list", []) + self._update_effective_size() + + def exclude_index(self, index: int) -> None: + self.exclude_list.append(index) + + def _update_effective_size(self) -> None: + if self.exclude_list: + mask = ~np.isin(self.data, self.exclude_list) + self.data = self.data[mask] + self.exclude_list = [] + + self.effective_size = len(self.data) - (len(self.data) % self.batch_size) + + +class StreamingDataLoader(TextDataLoaderBase): + def __init__( + self, + *, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + pending_queries_map: dict, + tokenizer: PreTrainedTokenizer, + args: Any, + generation_config: Any, + work_dir: PathOrStr, + global_batch_size: int, + resume_training_step: int = 0, + num_training_steps: int = 0, + actor_manager=None, + model_dims: utils.ModelDims = None, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + super().__init__( + work_dir=work_dir, + global_batch_size=global_batch_size, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + self.dataset = dataset + self.reward_fn = reward_fn + self.inference_results_Q = inference_results_Q + self.param_prompt_Q = param_prompt_Q + self.pending_queries_map = pending_queries_map + self.tokenizer = tokenizer + self.args = args + self.generation_config = generation_config + self.num_training_steps = num_training_steps + self.actor_manager = actor_manager + self.model_dims = model_dims + + self.training_step = resume_training_step + self.current_epoch = 0 + + dataset_indices = np.arange(len(dataset)) + self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=args.seed + dp_rank) + + self.local_queue = StdQueue(maxsize=args.async_steps) + self.background_thread = None + self.shutdown_requested = False + + @property + def total_batches(self) -> int | None: + return self.num_training_steps + + def state_dict(self) -> dict[str, Any]: + return { + "training_step": self.training_step, + "current_epoch": self.current_epoch, + "iter_dataloader_state": 
self.iter_dataloader.get_state(), + } + + def load_state_dict(self, state_dict: dict[str, Any]): + self.training_step = state_dict["training_step"] + self.current_epoch = state_dict.get("current_epoch", 0) + self.iter_dataloader.set_state(state_dict["iter_dataloader_state"]) + + def reshuffle(self, epoch: int | None = None, **kwargs): + if epoch is not None: + self.current_epoch = epoch + + def get_mock_batch(self) -> dict[str, Any]: + dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) + dummy_tool_mask = torch.zeros_like(dummy_qr) + dummy_attention = torch.tensor([1, 1], dtype=torch.long) + dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) + dummy_response_mask = torch.zeros_like(dummy_qr) + dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) + + return { + "collated_query_responses": [dummy_qr], + "collated_tool_masks": [dummy_tool_mask], + "collated_attention_masks": [dummy_attention], + "collated_position_ids": [dummy_position_ids], + "collated_advantages": [dummy_advantage], + "collated_response_masks": [dummy_response_mask], + "collated_vllm_logprobs": [torch.zeros_like(dummy_qr, dtype=torch.float)], + } + + def _iter_batches(self) -> Iterable[dict[str, Any]]: + if self.background_thread is None: + self._start_background_thread() + + while self.training_step < self.num_training_steps: + batch_data = self.local_queue.get() + self.training_step += 1 + yield batch_data + + def _start_background_thread(self): + self.shutdown_requested = False + self.background_thread = threading.Thread( + target=self._data_preparation_loop, daemon=True, name=f"DataLoader-Worker-Rank{self.dp_rank}" + ) + self.background_thread.start() + + def _data_preparation_loop(self): + from open_instruct.grpo_fast import accumulate_inference_batches + from open_instruct.queue_types import ShutdownSentinel + + for training_step in range(self.training_step, self.num_training_steps): + if self.shutdown_requested: + logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") + return + + with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: + result, batch, reward_metrics, batch_stats = accumulate_inference_batches( + self.inference_results_Q, + self.pending_queries_map, + self.args, + self.generation_config, + num_prompts=self.rank_batch_size, + model_dims=self.model_dims, + tokenizer=self.tokenizer, + reward_fn=self.reward_fn, + actor_manager=self.actor_manager, + active_sampling=self.args.active_sampling, + filter_zero_std_samples=self.args.filter_zero_std_samples, + replenish_prompts=True, + no_resampling_pass_rate=self.args.no_resampling_pass_rate, + iter_dataloader=self.iter_dataloader, + prompt_dataset=self.dataset, + param_prompt_Q=self.param_prompt_Q, + training_step=training_step, + ) + if isinstance(result, ShutdownSentinel): + logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") + return + + getting_response_time = timer.duration + scores = np.array(batch.scores) + + good_outputs = [ + len(result.request_info.tool_outputs[i]) > 0 + and result.request_info.tool_calleds[i] + and not result.request_info.timeouts[i] + and not result.request_info.tool_errors[i] + for i in range(len(result.request_info.tool_outputs)) + ] + scores_per_prompt = scores.reshape(-1, self.args.num_samples_per_prompt_rollout) + mean_grouped_rewards = scores_per_prompt.mean(axis=-1) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.args.num_samples_per_prompt_rollout, axis=0) + 
std_grouped_rewards = scores_per_prompt.std(axis=-1) + std_grouped_rewards = np.repeat(std_grouped_rewards, self.args.num_samples_per_prompt_rollout, axis=0) + if self.args.advantage_normalization_type == "standard": + advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + elif self.args.advantage_normalization_type == "centered": + advantages = scores - mean_grouped_rewards + else: + raise ValueError(f"Invalid advantage normalization type: {self.args.advantage_normalization_type}") + + if self.args.mask_truncated_completions: + stop_idxes = torch.tensor( + [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] + ) + num_truncated = len(result.finish_reasons) - len(stop_idxes) + if num_truncated > 0: + logger.info( + f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " + f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" + ) + scores = scores[stop_idxes] + advantages = advantages[stop_idxes] + batch = batch[stop_idxes.tolist()] + result.responses = [result.responses[i] for i in stop_idxes] + result.masks = [result.masks[i] for i in stop_idxes] + result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] + result.logprobs = [result.logprobs[i] for i in stop_idxes] + + with Timer("📦 [Data Preparation Thread] Packing sequences"): + packed_sequences = pack_sequences( + queries=batch.queries, + responses=result.responses, + masks=result.masks, + pack_length=self.args.pack_length, + pad_token_id=self.tokenizer.pad_token_id, + vllm_logprobs=result.logprobs, + ) + lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) + lookup_advantages[1:] = advantages + packed_advantages = [ + torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) + for packed_mask in packed_sequences.response_masks + ] + packed_sequences.advantages = packed_advantages + + collated_data = self._prepare_collated_data_for_self(packed_sequences) + + if len(result.responses) == 0: + metrics = {} + logger.warning(f"No responses in batch {training_step}.") + else: + real_num_responses = len(result.responses) + expected_num_responses = ( + self.args.num_samples_per_prompt_rollout * self.args.num_unique_prompts_rollout + ) + + unsolved_num_responses = (scores < self.args.max_possible_score).sum() + sequence_lengths = np.array([len(response) for response in result.responses]) + sequence_length_solved = ( + np.array([]) + if np.all(scores == 0) + else np.array(sequence_lengths[scores == self.args.max_possible_score]) + ) + sequence_length_unsolved = ( + np.array([]) + if np.all(scores == self.args.max_possible_score) + else np.array(sequence_lengths[scores == 0]) + ) + stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( + result.finish_reasons + ) + + batch_metrics = asdict(batch_stats) + batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} + + metrics = { + "scores": scores.mean(), + "real_batch_size_ratio": real_num_responses / expected_num_responses, + "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, + "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, + "val/solve_rate_hist": None, + "val/total_reward_groups": real_num_responses / self.args.num_samples_per_prompt_rollout, + "val/sequence_lengths": sequence_lengths.mean(), + "val/sequence_lengths_min": sequence_lengths.min(), + "val/sequence_lengths_max": sequence_lengths.max(), + 
"val/sequence_lengths_unsolved": ( + 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() + ), + "val/sequence_lengths_solved": ( + 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() + ), + "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, + "val/sequence_lengths_solved_hist": sequence_length_solved, + "val/stop_rate": stop_rate, + "val/advantages_mean": advantages.mean(), + "val/advantages_min": advantages.min(), + "val/advantages_max": advantages.max(), + "val/advantages_hist": advantages, + "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), + "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), + "val/tool_errors_rate": np.array( + [len(item) > 0 for item in result.request_info.tool_errors] + ).mean(), + "val/good_outputs_rate": np.array(good_outputs).mean(), + "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), + "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), + "time/getting_response": getting_response_time, + **reward_metrics, + **batch_metrics_prefixed, + } + + total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens + metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time + + self.local_queue.put(collated_data) + + def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: + from open_instruct.grpo_fast import collate_fn + + per_device_packed_query_responses = packed_sequences.query_responses + per_device_packed_tool_masks = packed_sequences.tool_masks + per_device_packed_attention_masks = packed_sequences.attention_masks + per_device_packed_position_ids = packed_sequences.position_ids + per_device_packed_advantages = packed_sequences.advantages + per_device_packed_response_masks = packed_sequences.response_masks + per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs + + b_inds = np.random.permutation(len(per_device_packed_query_responses)) + collated_query_responses = [] + collated_tool_masks = [] + collated_attention_masks = [] + collated_position_ids = [] + collated_response_masks = [] + collated_advantages = [] + collated_vllm_logprobs = [] + for j in range(0, len(per_device_packed_query_responses), self.args.per_device_train_batch_size): + micro_range = b_inds[j : j + self.args.per_device_train_batch_size] + collated_query_responses.append( + collate_fn( + [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True + ) + ) + collated_tool_masks.append(collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True)) + collated_attention_masks.append( + collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) + ) + collated_position_ids.append( + collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, True) + ) + collated_response_masks.append( + collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, True) + ) + collated_advantages.append(collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, True)) + collated_vllm_logprobs.append( + collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) + ) + + return { + "collated_query_responses": collated_query_responses, + "collated_tool_masks": collated_tool_masks, + "collated_attention_masks": collated_attention_masks, + "collated_position_ids": collated_position_ids, + 
"collated_advantages": collated_advantages, + "collated_response_masks": collated_response_masks, + "collated_vllm_logprobs": collated_vllm_logprobs, + } + + def shutdown(self): + self.shutdown_requested = True + if self.background_thread is not None: + self.background_thread.join(timeout=5.0) From 9cc8cabc84a1d6d4f27f300a895b3dd6e43a2964 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 11:12:53 -0700 Subject: [PATCH 05/96] Cleaned up init order --- open_instruct/grpo_fast.py | 228 ++++++++++++++++++++++++------------- 1 file changed, 147 insertions(+), 81 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3e7c8097e..1a7cfe94a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -36,7 +36,7 @@ with contextlib.suppress(Exception): import deepspeed -from open_instruct import utils +from open_instruct import streaming_data_loader, utils # isort: on import asyncio @@ -711,6 +711,39 @@ def _update_effective_size(self) -> None: @ray.remote(num_gpus=1) class PolicyTrainerRayProcess(RayProcess): + def __init__( + self, + world_size: int, + rank: int, + local_rank: int, + master_addr: str | None, + master_port: int | None, + data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + pending_queries_map: dict, + tokenizer: PreTrainedTokenizer, + generation_config, + actor_manager, + model_dims: utils.ModelDims, + ): + super().__init__(world_size, rank, local_rank, master_addr, master_port) + self.dataloader = data_loader_config.build( + dataset=dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + pending_queries_map=pending_queries_map, + tokenizer=tokenizer, + generation_config=generation_config, + dp_rank=self.local_rank, + fs_local_rank=self.local_rank, + actor_manager=actor_manager, + model_dims=model_dims, + ) + def from_pretrained( self, args: Args, @@ -1013,32 +1046,6 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def build_dataloader( - self, - data_loader_config, - dataset: Dataset, - reward_fn: Callable, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, - pending_queries_map: dict, - generation_config, - actor_manager, - model_dims: utils.ModelDims, - ): - self.dataloader = data_loader_config.build( - dataset=dataset, - reward_fn=reward_fn, - inference_results_Q=inference_results_Q, - param_prompt_Q=param_prompt_Q, - pending_queries_map=pending_queries_map, - tokenizer=self.tokenizer, - generation_config=generation_config, - dp_rank=self.local_rank, - fs_local_rank=self.local_rank, - actor_manager=actor_manager, - model_dims=model_dims, - ) - def train(self, pad_token_id: int, num_mini_batches: int): batch_data = next(self.dataloader) collated_query_responses = batch_data["collated_query_responses"] @@ -1487,7 +1494,21 @@ def launch_ai2_evals_on_weka_wrapper(self, step_dir, leaderboard_name, wandb_url class ModelGroup: def __init__( - self, pg: PlacementGroup, ray_process_cls: RayProcess, num_gpus_per_node: list[int], single_gpu_mode: bool + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_node: list[int], + single_gpu_mode: bool, + data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + dataset: Dataset, + reward_fn: Callable, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: 
ray_queue.Queue, + pending_queries_map: dict, + tokenizer: PreTrainedTokenizer, + generation_config, + actor_manager, + model_dims: utils.ModelDims, ): self.pg = pg self.ray_process_cls = ray_process_cls @@ -1502,7 +1523,23 @@ def __init__( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=self.pg, placement_group_bundle_index=0 ), - ).remote(world_size, 0, 0, None, None) + ).remote( + world_size, + 0, + 0, + None, + None, + data_loader_config, + dataset, + reward_fn, + inference_results_Q, + param_prompt_Q, + pending_queries_map, + tokenizer, + generation_config, + actor_manager, + model_dims, + ) self.models.append(master_policy) results, _ = ray_get_with_progress( @@ -1535,7 +1572,23 @@ def get_bundle_index(rank, num_gpus_per_node): num_cpus=self.num_cpus_per_actor, num_gpus=self.num_gpus_per_actor, scheduling_strategy=scheduling_strategy, - ).remote(world_size, rank, 0, master_addr, master_port) + ).remote( + world_size, + rank, + 0, + master_addr, + master_port, + data_loader_config, + dataset, + reward_fn, + inference_results_Q, + param_prompt_Q, + pending_queries_map, + tokenizer, + generation_config, + actor_manager, + model_dims, + ) self.models.append(worker_policy) @@ -2294,19 +2347,17 @@ def create_model_and_optimizer( inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, evaluation_inference_results_Q: ray_queue.Queue, -) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]: + data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + train_dataset: Dataset, + reward_fn: Callable, + pending_queries_map: dict, + generation_config, +) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int, ActorManager, utils.ModelDims]: """Create the model, optimizer, and vLLM engines.""" # Create placement group bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node] pg = placement_group(bundles, strategy="STRICT_SPREAD") ray_get_with_progress([pg.ready()], desc="Waiting for placement group") - inits = [] - policy_group = ModelGroup(pg, PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode) - wandb_url = wandb.run.get_url() if args.with_tracking else None - inits.extend( - model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) - for model in policy_group.models - ) # Set up tools max_len = args.max_prompt_token_length + args.response_length @@ -2367,10 +2418,9 @@ def create_model_and_optimizer( inflight_updates=args.inflight_updates, ) - results, _ = ray_get_with_progress(inits, desc="Initializing models") - resume_training_step = results[0] + 1 - episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout - logger.info("======== ✅ all models and vLLM engines initialized =========") + # Get model dimensions from vLLM engine + model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) + logger.info("======== ✅ vLLM engines and actor_manager initialized =========") # Get and set KV cache max concurrency from the first engine (all engines have the same config) # fp8 kv cache for now forces v0 engine and breaks this. 
@@ -2381,13 +2431,42 @@ def create_model_and_optimizer( # dummy value ray.get(actor_manager.set_kv_cache_max_concurrency.remote(-1)) + # Now create policy actors with all dependencies + wandb_url = wandb.run.get_url() if args.with_tracking else None + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + args.num_learners_per_node, + args.single_gpu_mode, + data_loader_config=data_loader_config, + dataset=train_dataset, + reward_fn=reward_fn, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + pending_queries_map=pending_queries_map, + tokenizer=tokenizer, + generation_config=generation_config, + actor_manager=actor_manager, + model_dims=model_dims, + ) + + inits = [ + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) + for model in policy_group.models + ] + + results, _ = ray_get_with_progress(inits, desc="Initializing models") + resume_training_step = results[0] + 1 + episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + logger.info("======== ✅ all models initialized =========") + ray_get_with_progress( [m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models], desc="Setting up model update group", ) logger.info("======== ✅ model update group setup successfully =========") - return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager + return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims def create_generation_configs(args: Args): @@ -2990,35 +3069,7 @@ def run_training( [engine.ready.remote() for engine in vllm_engines], "Checking engines are ready to work", timeout=300 ) - logger.info("======== ✅ Setting up dataloaders =========") - from open_instruct.streaming_data_loader import StreamingDataLoaderConfig - - data_loader_config = StreamingDataLoaderConfig( - work_dir=args.output_dir, - global_batch_size=args.num_unique_prompts_rollout, - dp_world_size=args.world_size, - resume_training_step=resume_training_step, - num_training_steps=args.num_training_steps, - args=args, - ) - - ray_get_with_progress( - [ - policy_group.models[i].build_dataloader.remote( - data_loader_config=data_loader_config, - dataset=train_dataset, - reward_fn=reward_fn, - inference_results_Q=inference_results_Q, - param_prompt_Q=param_prompt_Q, - pending_queries_map=pending_queries_map, - generation_config=generation_configs["train"], - actor_manager=actor_manager, - model_dims=model_dims, - ) - for i in range(args.world_size) - ], - desc="Setting up dataloaders for all workers", - ) + logger.info("======== ✅ Dataloaders already initialized in actors =========") def health_check_fn(): [f.result() for f in [weight_sync_thread_future] if f.done()] @@ -3208,7 +3259,25 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): # We don't care if we ever hit the max, so we let the queue be unbounded. 
evaluation_inference_results_Q = ray_queue.Queue() - policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager = ( + # Create dataloader dependencies before model creation + pending_queries_map = PendingQueriesMap() + reward_fn = make_reward_fn(args) + generation_configs = create_generation_configs(args) + + # Get resume_training_step to create data_loader_config + # We need to temporarily estimate this; it will be corrected after model init + resume_training_step_estimate = 0 + + data_loader_config = streaming_data_loader.StreamingDataLoaderConfig( + work_dir=args.output_dir, + global_batch_size=args.num_unique_prompts_rollout, + dp_world_size=args.world_size, + resume_training_step=resume_training_step_estimate, + num_training_steps=args.num_training_steps, + args=args, + ) + + (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( create_model_and_optimizer( args, tc, @@ -3219,14 +3288,14 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): inference_results_Q, param_prompt_Q, evaluation_inference_results_Q, + data_loader_config, + train_dataset, + reward_fn, + pending_queries_map, + generation_configs["train"], ) ) - # Get the model dimensions from one of the engines without loading weights - model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) - - generation_configs = create_generation_configs(args) - checkpoint_state = None if args.checkpoint_state_dir and os.path.exists(args.checkpoint_state_dir): # Try to load the checkpoint state from the first rank @@ -3247,13 +3316,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): # Create additional queues (main queues already created above) packed_sequences_Q = Queue(maxsize=args.async_steps) - pending_queries_map = PendingQueriesMap() eval_pending_queries_map = PendingQueriesMap() generate_metrics_Q = Queue(maxsize=args.async_steps) weight_sync_metrics_Q = Queue(maxsize=args.async_steps) - reward_fn = make_reward_fn(args) - stop_event = threading.Event() executor = futures.ThreadPoolExecutor(max_workers=3, thread_name_prefix="grpo") From 4a68d0cc215b5a6f8f970287673f3f74bdaca697 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 11:24:52 -0700 Subject: [PATCH 06/96] Clean up --- open_instruct/grpo_fast.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1a7cfe94a..f32f14980 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -730,6 +730,9 @@ def __init__( model_dims: utils.ModelDims, ): super().__init__(world_size, rank, local_rank, master_addr, master_port) + self.tokenizer = tokenizer + self.pad_token_id = tokenizer.pad_token_id + self.num_mini_batches = data_loader_config.args.num_mini_batches self.dataloader = data_loader_config.build( dataset=dataset, reward_fn=reward_fn, @@ -1046,7 +1049,7 @@ def update_ref_policy(self): else: ref_param.data.mul_(1.0 - self.args.alpha).add_(param.data, alpha=self.args.alpha) - def train(self, pad_token_id: int, num_mini_batches: int): + def step(self): batch_data = next(self.dataloader) collated_query_responses = batch_data["collated_query_responses"] collated_tool_masks = batch_data["collated_tool_masks"] @@ -1065,7 +1068,7 @@ def train(self, pad_token_id: int, num_mini_batches: int): to_device_inplace(collated_response_masks, self.device) to_device_inplace(collated_vllm_logprobs, self.device) # accumulation steps should always 
be at least 1 - accumulation_steps = max(math.ceil(len(collated_query_responses) / num_mini_batches - 0.5), 1) + accumulation_steps = max(math.ceil(len(collated_query_responses) / self.num_mini_batches - 0.5), 1) leftover = len(collated_query_responses) % accumulation_steps if leftover > 0: collated_query_responses = collated_query_responses[0:-leftover] @@ -1075,7 +1078,7 @@ def train(self, pad_token_id: int, num_mini_batches: int): collated_advantages = collated_advantages[0:-leftover] collated_response_masks = collated_response_masks[0:-leftover] collated_vllm_logprobs = collated_vllm_logprobs[0:-leftover] - logger.warning(f"{leftover} samples are dropped due to batch size {num_mini_batches}") + logger.warning(f"{leftover} samples are dropped due to batch size {self.num_mini_batches}") # recalculate the "real" number of mini-batches num_mini_batches = len(collated_query_responses) // accumulation_steps @@ -1094,7 +1097,7 @@ def train(self, pad_token_id: int, num_mini_batches: int): query_response, attention_mask, position_id, - pad_token_id, + self.pad_token_id, args.temperature, return_entropy=False, ) @@ -1124,7 +1127,7 @@ def train(self, pad_token_id: int, num_mini_batches: int): query_response, attention_mask, position_id, - pad_token_id, + self.pad_token_id, args.temperature, return_entropy=False, ) @@ -1176,7 +1179,7 @@ def train(self, pad_token_id: int, num_mini_batches: int): mb_query_responses, mb_attention_mask, mb_position_id, - pad_token_id, + self.pad_token_id, args.temperature, return_entropy=args.record_entropy, ) @@ -2646,12 +2649,7 @@ def one_training_step( update_ref_policy_future = [] with Timer("[Main Thread] 🗡️ Training") as train_timer: metrics_list, _ = ray_get_with_progress( - [ - policy_group.models[i].train.remote( - pad_token_id=tokenizer.pad_token_id, num_mini_batches=args.num_mini_batches - ) - for i in range(args.world_size) - ], + [policy_group.models[i].step.remote() for i in range(args.world_size)], desc=f"Running training step {training_step}", ) if ( From ee012abc780f15c196c2da4867692f5c44a20be8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 12:55:12 -0700 Subject: [PATCH 07/96] Cleaned up PR --- open_instruct/grpo_fast.py | 506 +------------------------ open_instruct/queue_types.py | 4 + open_instruct/streaming_data_loader.py | 444 ++++++++++++++++++++-- 3 files changed, 425 insertions(+), 529 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index f32f14980..2eee942ba 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -37,6 +37,13 @@ import deepspeed from open_instruct import streaming_data_loader, utils +from open_instruct.streaming_data_loader import ( + PendingQueriesMap, + ShufflingIterator, + accumulate_inference_batches, + add_prompt_to_generator, + collate_fn, +) # isort: on import asyncio @@ -73,17 +80,13 @@ from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from rich.pretty import pprint -from tqdm import tqdm from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler from transformers.integrations import HfDeepSpeedConfig from open_instruct import logger_utils, vllm_utils from open_instruct.actor_manager import ActorManager from open_instruct.dataset_transformation import ( - GROUND_TRUTHS_KEY, INPUT_IDS_PROMPT_KEY, - RAW_PROMPT_KEY, - VERIFIER_SOURCE_KEY, TokenizerConfig, get_cached_dataset_tulu, visualize_token, @@ -94,7 +97,6 
@@ soft_format_reward_func, ) from open_instruct.model_utils import ( - Batch, ModelConfig, apply_verifiable_reward, disable_dropout_in_model, @@ -105,7 +107,7 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics +from open_instruct.queue_types import ShutdownSentinel from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences from open_instruct.utils import ( ArgumentParserPlus, @@ -114,7 +116,6 @@ _z3_params_to_fetch, calibrate_checkpoint_state_dir, clean_last_n_checkpoints_deepspeed, - combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, get_eval_ds_config, @@ -128,7 +129,6 @@ maybe_use_ai2_hf_entity, maybe_use_ai2_wandb_entity, ray_get_with_progress, - repeat_each, sync_gs_bucket, ) @@ -137,10 +137,6 @@ INVALID_LOGPROB = 1.0 -class ShutdownSentinel: - """Sentinel value to signal thread shutdown via queue.""" - - @dataclass class Args: # Dataset @@ -549,13 +545,6 @@ def masked_mean( return (numerator / denom).mean() -def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: - padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) - if pin_memory: - padded_tensor = padded_tensor.pin_memory() - return padded_tensor - - @Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker") def prepare_collated_data_for_workers( packed_sequences: PackedSequences, @@ -645,70 +634,6 @@ def to_device_inplace(tensors_list: list[torch.Tensor], device: torch.device): tensors_list[i] = tensors_list[i].to(device, non_blocking=True) -class ShufflingIterator: - def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None): - self.data = data.copy() - self.batch_size = batch_size - self.index = 0 - self.epoch_number = 0 - self.rng = np.random.default_rng(seed) - self.rng.shuffle(self.data) - self.exclude_list = [] - - self._update_effective_size() - - def __iter__(self) -> Iterator[list[int]]: - return self - - def __next__(self) -> list[int] | int: - """Return a list of next indices or a single index if batch size is 1""" - if self.index >= self.effective_size: - self.index = 0 - self._update_effective_size() - self.epoch_number += 1 - self.rng.shuffle(self.data) - - end_index = self.index + self.batch_size - batch = self.data[self.index : end_index].tolist() - if self.batch_size == 1: - batch = batch[0] - self.index = end_index - - return batch - - def get_state(self) -> dict[str, Any]: - """Get the current state of the iterator for checkpointing.""" - return { - "index": self.index, - "epoch_number": self.epoch_number, - "data": self.data.copy(), - "rng_state": self.rng.bit_generator.state, - "exclude_list": self.exclude_list.copy(), - } - - def set_state(self, state: dict[str, Any]) -> None: - """Restore the iterator state from a checkpoint.""" - self.index = state["index"] - self.epoch_number = state.get("epoch_number", 0) - self.data = state["data"].copy() - self.rng.bit_generator.state = state["rng_state"] - self.exclude_list = state.get("exclude_list", []) - self._update_effective_size() - - def exclude_index(self, index: int) -> None: - """Exclude provided data points from future sampling.""" - self.exclude_list.append(index) - - def _update_effective_size(self) -> None: - """Ensure the effective dataset size is divisible by batch_size and filter out all the indices excluded in the last epoch""" - if self.exclude_list: - mask = 
~np.isin(self.data, self.exclude_list) - self.data = self.data[mask] - self.exclude_list = [] - - self.effective_size = len(self.data) - (len(self.data) % self.batch_size) - - @ray.remote(num_gpus=1) class PolicyTrainerRayProcess(RayProcess): def __init__( @@ -1595,70 +1520,6 @@ def get_bundle_index(rank, num_gpus_per_node): self.models.append(worker_policy) -class PendingQueriesMap: - """Thread-safe map for tracking pending queries with reference counting.""" - - def __init__(self): - self._map = {} # dataset_idx -> (query, ground_truth, dataset, count) - self._lock = threading.Lock() - - def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): - """Insert or increment count for a dataset index.""" - with self._lock: - if dataset_idx in self._map: - # Already exists - just increment count - existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ - dataset_idx - ] - self._map[dataset_idx] = ( - existing_query, - existing_ground_truth, - existing_dataset, - existing_raw_query, - count + 1, - ) - else: - # New entry - count starts at 1 - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) - - def pop(self, dataset_idx): - """Retrieve data and decrement count. Removes entry when count reaches 0.""" - with self._lock: - if dataset_idx not in self._map: - raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") - - query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] - - if count > 1: - # More results expected - just decrement - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) - else: - # Last result - remove entry - del self._map[dataset_idx] - - return query, ground_truth, dataset, raw_query - - def __len__(self): - """Return the number of entries in the map.""" - with self._lock: - return len(self._map) - - def __contains__(self, dataset_idx): - """Check if a dataset index is in the map.""" - with self._lock: - return dataset_idx in self._map - - def __getitem__(self, dataset_idx): - """Get the value for a dataset index.""" - with self._lock: - return self._map[dataset_idx] - - def keys(self): - """Return a view of the keys in the map.""" - with self._lock: - return list(self._map.keys()) - - def calculate_utilization_metrics( model_dims: utils.ModelDims, prompt_lengths: list[int], @@ -1716,299 +1577,6 @@ def calculate_utilization_metrics( return utilization_metrics -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - no_resampled_prompts: int - total_prompts: int - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - pending_queries_map: PendingQueriesMap, - args: Args, - generation_config: vllm.SamplingParams, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - reward_fn: Callable, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, - iter_dataloader: ShufflingIterator | None = None, - prompt_dataset: Dataset = None, - param_prompt_Q: ray_queue.Queue | None = None, - training_step: int = None, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - """Accumulate multiple inference results into a single training batch. 
- - Args: - inference_results_Q: Queue containing individual GenerationResult objects (one per prompt) - pending_queries_map: PendingQueriesMap instance for thread-safe query tracking - args: Arguments containing vllm_num_engines and batch size info - generation_config: Generation config containing n (number of samples per prompt) - num_prompts: Number of prompts to accumulate - timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. - active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std - filter_zero_std_samples: Whether to filter samples with zero reward std - replenish_prompts: Add a prompt back onto the prompt_Q after receiving a finished result - no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate - and exclude them from further sampling - iter_dataloader: Optional, used for no_resampling_pass_rate - param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts - - Raises: - queue.Empty: If timeout is specified and no data is available within timeout. - - Returns: - Tuple of (combined_result, Batch with queries, ground_truths, datasets, prompt_lengths, response_lengths) - or (ShutdownSentinel, None, None, None) if shutdown signal received - """ - if no_resampling_pass_rate is not None: - assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" - - if replenish_prompts: - assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( - "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset" - ) - - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] - all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not args.verbose, - ) - num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - # Validate that each individual result has the expected number of responses - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. " - f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" - ) - - query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index) - - # Replenish generation queue with new prompt - if replenish_prompts: - dataset_index = next(iter_dataloader) - add_prompt_to_generator( - prompt_dataset[dataset_index], - dataset_index, - iter_dataloader.epoch_number, - training_step, - pending_queries_map, - param_prompt_Q, - generation_config, - is_eval=False, - ) - - # TODO(finbarrtimbers): Move this to LLMRayActor. 
- for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - # TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method. - k_queries = repeat_each([query], generation_config.n) - k_ground_truths = repeat_each([ground_truth], generation_config.n) - k_datasets = repeat_each([dataset_name], generation_config.n) - k_raw_queries = repeat_each([raw_query], generation_config.n) - - scores, reward_metrics = asyncio.run( - reward_fn( - result.responses, - decoded_responses, - k_ground_truths, - k_datasets, - result.finish_reasons, - result.request_info, - k_raw_queries, - ) - ) - - percent_solved = np.mean(scores).item() / args.max_possible_score - # Don't resample prompt that was solved at more than no_resample_positive_rate - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - iter_dataloader.exclude_index(result.dataset_index) - total_no_resampled += 1 - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" - ) - - # Filter out zero std prompts - if filter_zero_std_samples and np.std(scores) == 0: - # If we're not active sampling, still count this as a sample - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if scores[0] == 0: - filtered_prompt_zero += 1 - elif scores[0] == args.max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - all_queries.extend(k_queries) - all_ground_truths.extend(k_ground_truths) - all_datasets.extend(k_datasets) - all_raw_queries.extend(k_raw_queries) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(scores) - all_reward_metrics.append(reward_metrics) - all_percent_solved.append(percent_solved) - - # Combine all results into a single GenerationResult - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * 
generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - # Use the maximum generation time across engines since they work in parallel - # This avoids including queue overhead and accumulation time in MFU/MBU calculations - total_generation_time = max_generation_time - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=total_generation_time, - earliest_start_time=earliest_start_time, - ) - - # Create combined RequestInfo - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - # Create combined GenerationResult - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - epoch_number=results[0].epoch_number, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - # Note: We don't have dataset_indices here, but they're not needed for the returned batch - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, # Not meaningful for combined results - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - filtered_prompts=total_filtered_prompts, - filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - logging.info( - f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" - ) - - return combined_result, batch, combined_reward_metrics, batch_stats def data_preparation_thread( @@ -2496,35 +2064,6 @@ def create_generation_configs(args: Args): return {"train": generation_config, "eval": eval_generation_config} -def add_prompt_to_generator( - example: dict[str, Any], - example_index: int, - epoch_number: int, - training_step: int, - pending_queries_map: PendingQueriesMap, - param_prompt_Q: ray_queue.Queue, - generation_config, - is_eval: bool, -) -> None: - """Split a batch into multiple inference batches and insert individual prompts into queues and mapping.""" - query = example[INPUT_IDS_PROMPT_KEY] - ground_truth = example[GROUND_TRUTHS_KEY] - dataset_name = example[VERIFIER_SOURCE_KEY] - raw_query = example[RAW_PROMPT_KEY] - pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query) - - param_prompt_Q.put( - PromptRequest( - prompt=query, - generation_config=generation_config, - epoch_number=epoch_number, - 
training_step=training_step, - dataset_index=example_index, - is_eval=is_eval, - ) - ) - - def load_data_from_packing_thread( packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None] ) -> tuple[list[dict[str, list[torch.Tensor]]] | None, dict[str, Any], int, int, list[int] | None, list[int] | None]: @@ -3222,7 +2761,12 @@ def health_check_fn(): save_final_model(args, policy_group, tokenizer, training_step, wandb_url, tc.chat_template_name) -def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): +def main( + args: Args, + tc: TokenizerConfig, + model_config: ModelConfig, + streaming_config: streaming_data_loader.StreamingDataLoaderConfig, +): tokenizer = make_tokenizer(tc, model_config) args = setup_runtime_variables(args) @@ -3262,19 +2806,6 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): reward_fn = make_reward_fn(args) generation_configs = create_generation_configs(args) - # Get resume_training_step to create data_loader_config - # We need to temporarily estimate this; it will be corrected after model init - resume_training_step_estimate = 0 - - data_loader_config = streaming_data_loader.StreamingDataLoaderConfig( - work_dir=args.output_dir, - global_batch_size=args.num_unique_prompts_rollout, - dp_world_size=args.world_size, - resume_training_step=resume_training_step_estimate, - num_training_steps=args.num_training_steps, - args=args, - ) - (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( create_model_and_optimizer( args, @@ -3286,7 +2817,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): inference_results_Q, param_prompt_Q, evaluation_inference_results_Q, - data_loader_config, + streaming_config, train_dataset, reward_fn, pending_queries_map, @@ -3386,10 +2917,11 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): if __name__ == "__main__": utils.check_oe_eval_internal() - parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig)) - args, tokenizer_config, model_config = parser.parse_args_into_dataclasses() + parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig, streaming_data_loader.StreamingDataLoaderConfig)) + args, tokenizer_config, model_config, streaming_config = parser.parse_args_into_dataclasses() assert isinstance(args, Args) assert isinstance(tokenizer_config, TokenizerConfig) assert isinstance(model_config, ModelConfig) + assert isinstance(streaming_config, streaming_data_loader.StreamingDataLoaderConfig) - main(args, tokenizer_config, model_config) + main(args, tokenizer_config, model_config, streaming_config) diff --git a/open_instruct/queue_types.py b/open_instruct/queue_types.py index 0cc047bca..8267f2f44 100644 --- a/open_instruct/queue_types.py +++ b/open_instruct/queue_types.py @@ -2,6 +2,10 @@ from typing import Any +class ShutdownSentinel: + """Sentinel value to signal thread shutdown via queue.""" + + @dataclass class TokenStatistics: """Container for token statistics from inference.""" diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 1bf016d72..da7804a43 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import asyncio import logging import threading from abc import abstractmethod @@ -23,26 +24,45 @@ import numpy as np import torch +import vllm from datasets import Dataset from ray.util import queue as ray_queue +from tqdm import tqdm from transformers import PreTrainedTokenizer from open_instruct import utils +from open_instruct.dataset_transformation import ( + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + RAW_PROMPT_KEY, + VERIFIER_SOURCE_KEY, +) +from open_instruct.model_utils import Batch +from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences +from open_instruct.utils import combine_reward_metrics, repeat_each logger = logging.getLogger(__name__) -PathOrStr = Path | str - @dataclass class StreamingDataLoaderConfig: - work_dir: PathOrStr + work_dir: Path | str global_batch_size: int dp_world_size: int - resume_training_step: int num_training_steps: int - args: Any + seed: int + async_steps: int + num_samples_per_prompt_rollout: int + active_sampling: bool + filter_zero_std_samples: bool + no_resampling_pass_rate: float + advantage_normalization_type: str + mask_truncated_completions: bool + pack_length: int + max_possible_score: float + per_device_train_batch_size: int + verbose: bool def build( self, @@ -65,11 +85,10 @@ def build( param_prompt_Q=param_prompt_Q, pending_queries_map=pending_queries_map, tokenizer=tokenizer, - args=self.args, + config=self, generation_config=generation_config, work_dir=self.work_dir, global_batch_size=self.global_batch_size, - resume_training_step=self.resume_training_step, num_training_steps=self.num_training_steps, actor_manager=actor_manager, model_dims=model_dims, @@ -83,7 +102,7 @@ class DataLoaderBase: def __init__( self, *, - work_dir: PathOrStr, + work_dir: Path | str, global_batch_size: int, dp_world_size: int = 1, dp_rank: int = 0, @@ -152,7 +171,7 @@ class TextDataLoaderBase(DataLoaderBase): def __init__( self, *, - work_dir: PathOrStr, + work_dir: Path | str, global_batch_size: int, dp_world_size: int = 1, dp_rank: int = 0, @@ -245,11 +264,10 @@ def __init__( param_prompt_Q: ray_queue.Queue, pending_queries_map: dict, tokenizer: PreTrainedTokenizer, - args: Any, + config: StreamingDataLoaderConfig, generation_config: Any, - work_dir: PathOrStr, + work_dir: Path | str, global_batch_size: int, - resume_training_step: int = 0, num_training_steps: int = 0, actor_manager=None, model_dims: utils.ModelDims = None, @@ -271,19 +289,19 @@ def __init__( self.param_prompt_Q = param_prompt_Q self.pending_queries_map = pending_queries_map self.tokenizer = tokenizer - self.args = args + self.config = config self.generation_config = generation_config self.num_training_steps = num_training_steps self.actor_manager = actor_manager self.model_dims = model_dims - self.training_step = resume_training_step + self.training_step = 0 self.current_epoch = 0 dataset_indices = np.arange(len(dataset)) - self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=args.seed + dp_rank) + self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=config.seed + dp_rank) - self.local_queue = StdQueue(maxsize=args.async_steps) + self.local_queue = StdQueue(maxsize=config.async_steps) self.background_thread = None self.shutdown_requested = False @@ -342,9 +360,6 @@ def _start_background_thread(self): self.background_thread.start() def _data_preparation_loop(self): - from open_instruct.grpo_fast import accumulate_inference_batches 
- from open_instruct.queue_types import ShutdownSentinel - for training_step in range(self.training_step, self.num_training_steps): if self.shutdown_requested: logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") @@ -354,21 +369,22 @@ def _data_preparation_loop(self): result, batch, reward_metrics, batch_stats = accumulate_inference_batches( self.inference_results_Q, self.pending_queries_map, - self.args, self.generation_config, num_prompts=self.rank_batch_size, model_dims=self.model_dims, tokenizer=self.tokenizer, reward_fn=self.reward_fn, actor_manager=self.actor_manager, - active_sampling=self.args.active_sampling, - filter_zero_std_samples=self.args.filter_zero_std_samples, + active_sampling=self.config.active_sampling, + filter_zero_std_samples=self.config.filter_zero_std_samples, replenish_prompts=True, - no_resampling_pass_rate=self.args.no_resampling_pass_rate, + no_resampling_pass_rate=self.config.no_resampling_pass_rate, iter_dataloader=self.iter_dataloader, prompt_dataset=self.dataset, param_prompt_Q=self.param_prompt_Q, training_step=training_step, + verbose=self.config.verbose, + max_possible_score=self.config.max_possible_score, ) if isinstance(result, ShutdownSentinel): logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") @@ -384,19 +400,19 @@ def _data_preparation_loop(self): and not result.request_info.tool_errors[i] for i in range(len(result.request_info.tool_outputs)) ] - scores_per_prompt = scores.reshape(-1, self.args.num_samples_per_prompt_rollout) + scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.args.num_samples_per_prompt_rollout, axis=0) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, self.args.num_samples_per_prompt_rollout, axis=0) - if self.args.advantage_normalization_type == "standard": + std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + if self.config.advantage_normalization_type == "standard": advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif self.args.advantage_normalization_type == "centered": + elif self.config.advantage_normalization_type == "centered": advantages = scores - mean_grouped_rewards else: - raise ValueError(f"Invalid advantage normalization type: {self.args.advantage_normalization_type}") + raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") - if self.args.mask_truncated_completions: + if self.config.mask_truncated_completions: stop_idxes = torch.tensor( [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] ) @@ -419,7 +435,7 @@ def _data_preparation_loop(self): queries=batch.queries, responses=result.responses, masks=result.masks, - pack_length=self.args.pack_length, + pack_length=self.config.pack_length, pad_token_id=self.tokenizer.pad_token_id, vllm_logprobs=result.logprobs, ) @@ -438,20 +454,18 @@ def _data_preparation_loop(self): logger.warning(f"No responses in batch {training_step}.") else: real_num_responses = len(result.responses) - expected_num_responses = ( - self.args.num_samples_per_prompt_rollout * self.args.num_unique_prompts_rollout - ) + expected_num_responses = 
self.config.num_samples_per_prompt_rollout * self.global_batch_size - unsolved_num_responses = (scores < self.args.max_possible_score).sum() + unsolved_num_responses = (scores < self.config.max_possible_score).sum() sequence_lengths = np.array([len(response) for response in result.responses]) sequence_length_solved = ( np.array([]) if np.all(scores == 0) - else np.array(sequence_lengths[scores == self.args.max_possible_score]) + else np.array(sequence_lengths[scores == self.config.max_possible_score]) ) sequence_length_unsolved = ( np.array([]) - if np.all(scores == self.args.max_possible_score) + if np.all(scores == self.config.max_possible_score) else np.array(sequence_lengths[scores == 0]) ) stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( @@ -467,7 +481,7 @@ def _data_preparation_loop(self): "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, "val/solve_rate_hist": None, - "val/total_reward_groups": real_num_responses / self.args.num_samples_per_prompt_rollout, + "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, "val/sequence_lengths": sequence_lengths.mean(), "val/sequence_lengths_min": sequence_lengths.min(), "val/sequence_lengths_max": sequence_lengths.max(), @@ -503,8 +517,6 @@ def _data_preparation_loop(self): self.local_queue.put(collated_data) def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: - from open_instruct.grpo_fast import collate_fn - per_device_packed_query_responses = packed_sequences.query_responses per_device_packed_tool_masks = packed_sequences.tool_masks per_device_packed_attention_masks = packed_sequences.attention_masks @@ -521,8 +533,8 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> collated_response_masks = [] collated_advantages = [] collated_vllm_logprobs = [] - for j in range(0, len(per_device_packed_query_responses), self.args.per_device_train_batch_size): - micro_range = b_inds[j : j + self.args.per_device_train_batch_size] + for j in range(0, len(per_device_packed_query_responses), self.config.per_device_train_batch_size): + micro_range = b_inds[j : j + self.config.per_device_train_batch_size] collated_query_responses.append( collate_fn( [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True @@ -557,3 +569,351 @@ def shutdown(self): self.shutdown_requested = True if self.background_thread is not None: self.background_thread.join(timeout=5.0) + + +def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: + padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) + if pin_memory: + padded_tensor = padded_tensor.pin_memory() + return padded_tensor + + +@dataclass +class BatchStatistics: + prompt_lengths: list[int] + response_lengths: list[int] + filtered_prompts: int + filtered_prompts_zero: int + filtered_prompts_solved: int + filtered_prompts_nonzero: int + percent_solved_mean: float + no_resampled_prompts: int + total_prompts: int + + +class PendingQueriesMap: + def __init__(self): + self._map = {} + self._lock = threading.Lock() + + def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): + with self._lock: + if dataset_idx in self._map: + existing_query, existing_ground_truth, existing_dataset, existing_raw_query, 
count = self._map[ + dataset_idx + ] + self._map[dataset_idx] = ( + existing_query, + existing_ground_truth, + existing_dataset, + existing_raw_query, + count + 1, + ) + else: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) + + def pop(self, dataset_idx): + with self._lock: + if dataset_idx not in self._map: + raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") + + query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] + + if count > 1: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) + else: + del self._map[dataset_idx] + + return query, ground_truth, dataset, raw_query + + def __len__(self): + with self._lock: + return len(self._map) + + def __contains__(self, dataset_idx): + with self._lock: + return dataset_idx in self._map + + def __getitem__(self, dataset_idx): + with self._lock: + return self._map[dataset_idx] + + def keys(self): + with self._lock: + return list(self._map.keys()) + + +def add_prompt_to_generator( + example: dict[str, Any], + example_index: int, + epoch_number: int, + training_step: int, + pending_queries_map: PendingQueriesMap, + param_prompt_Q: ray_queue.Queue, + generation_config, + is_eval: bool, +) -> None: + query = example[INPUT_IDS_PROMPT_KEY] + ground_truth = example[GROUND_TRUTHS_KEY] + dataset_name = example[VERIFIER_SOURCE_KEY] + raw_query = example[RAW_PROMPT_KEY] + pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query) + + param_prompt_Q.put( + PromptRequest( + prompt=query, + generation_config=generation_config, + epoch_number=epoch_number, + training_step=training_step, + dataset_index=example_index, + is_eval=is_eval, + ) + ) + + +def accumulate_inference_batches( + inference_results_Q: ray_queue.Queue, + pending_queries_map: PendingQueriesMap, + generation_config: vllm.SamplingParams, + num_prompts: int, + model_dims: utils.ModelDims, + tokenizer: PreTrainedTokenizer, + reward_fn: Callable, + actor_manager=None, + timeout: float | None = None, + active_sampling: bool = False, + filter_zero_std_samples: bool = False, + replenish_prompts: bool = False, + no_resampling_pass_rate: float | None = None, + iter_dataloader: ShufflingIterator | None = None, + prompt_dataset: Dataset = None, + param_prompt_Q: ray_queue.Queue | None = None, + training_step: int = None, + verbose: bool = False, + max_possible_score: float = 1.0, +) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: + import ray + + if no_resampling_pass_rate is not None: + assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" + + if replenish_prompts: + assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( + "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset" + ) + + results = [] + all_queries = [] + all_ground_truths = [] + all_datasets = [] + all_raw_queries = [] + all_decoded_responses = [] + all_reward_metrics = [] + all_scores = [] + all_percent_solved = [] + total_filtered_prompts = 0 + filtered_prompt_zero = 0 + filtered_prompt_solved = 0 + filtered_prompt_nonzero = 0 + total_no_resampled = 0 + progress_bar = tqdm( + total=num_prompts, + desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", + bar_format="{l_bar}{bar}{r_bar}\n", + disable=not verbose, + ) + num_prompts_sampled = 0 + while num_prompts_sampled < num_prompts: + result = inference_results_Q.get(timeout=timeout) + + if isinstance(result, ShutdownSentinel): + 
return result, None, None, None + + assert len(result.responses) == generation_config.n, ( + f"Mismatch: individual prompt result has {len(result.responses)} responses " + f"but expected {generation_config.n} samples per prompt. " + f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" + ) + + query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index) + + if replenish_prompts: + dataset_index = next(iter_dataloader) + add_prompt_to_generator( + prompt_dataset[dataset_index], + dataset_index, + iter_dataloader.epoch_number, + training_step, + pending_queries_map, + param_prompt_Q, + generation_config, + is_eval=False, + ) + + for i in range(len(result.finish_reasons)): + if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: + result.responses[i].append(tokenizer.eos_token_id) + result.masks[i].append(1) + result.logprobs[i].append(float("nan")) + + decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) + + k_queries = repeat_each([query], generation_config.n) + k_ground_truths = repeat_each([ground_truth], generation_config.n) + k_datasets = repeat_each([dataset_name], generation_config.n) + k_raw_queries = repeat_each([raw_query], generation_config.n) + + scores, reward_metrics = asyncio.run( + reward_fn( + result.responses, + decoded_responses, + k_ground_truths, + k_datasets, + result.finish_reasons, + result.request_info, + k_raw_queries, + ) + ) + + percent_solved = np.mean(scores).item() / max_possible_score + if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: + iter_dataloader.exclude_index(result.dataset_index) + total_no_resampled += 1 + logging.debug( + f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" + ) + + if filter_zero_std_samples and np.std(scores) == 0: + if not active_sampling: + num_prompts_sampled += 1 + progress_bar.update(1) + + total_filtered_prompts += 1 + if scores[0] == 0: + filtered_prompt_zero += 1 + elif scores[0] == max_possible_score: + filtered_prompt_solved += 1 + else: + filtered_prompt_nonzero += 1 + logging.debug( + f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" + ) + continue + else: + num_prompts_sampled += 1 + progress_bar.update(1) + + results.append(result) + all_queries.extend(k_queries) + all_ground_truths.extend(k_ground_truths) + all_datasets.extend(k_datasets) + all_raw_queries.extend(k_raw_queries) + all_decoded_responses.extend(decoded_responses) + all_scores.extend(scores) + all_reward_metrics.append(reward_metrics) + all_percent_solved.append(percent_solved) + + combined_responses = [] + combined_finish_reasons = [] + combined_masks = [] + combined_num_calls = [] + combined_timeouts = [] + combined_tool_errors = [] + combined_tool_outputs = [] + combined_tool_runtimes = [] + combined_tool_calleds = [] + combined_logprobs = [] + + earliest_start_time = float("inf") + prompt_lengths = [] + response_lengths = [] + + total_prompt_tokens = 0 + total_response_tokens = 0 + max_generation_time = 0 + + for i, result in enumerate(results): + combined_responses.extend(result.responses) + combined_finish_reasons.extend(result.finish_reasons) + combined_masks.extend(result.masks) + combined_num_calls.extend(result.request_info.num_calls) + combined_timeouts.extend(result.request_info.timeouts) + combined_tool_errors.extend(result.request_info.tool_errors) + 
combined_tool_outputs.extend(result.request_info.tool_outputs) + combined_tool_runtimes.extend(result.request_info.tool_runtimes) + combined_tool_calleds.extend(result.request_info.tool_calleds) + + combined_logprobs.extend(result.logprobs) + + earliest_start_time = min(earliest_start_time, result.start_time) + + prompt_lengths.append(len(all_queries[i * generation_config.n])) + + for response in result.responses: + response_lengths.append(len(response)) + + total_prompt_tokens += result.token_statistics.num_prompt_tokens + total_response_tokens += result.token_statistics.num_response_tokens + max_generation_time = max(max_generation_time, result.token_statistics.generation_time) + + total_generation_time = max_generation_time + + accumulated_stats = TokenStatistics( + num_prompt_tokens=total_prompt_tokens, + num_response_tokens=total_response_tokens, + generation_time=total_generation_time, + earliest_start_time=earliest_start_time, + ) + + combined_request_info = RequestInfo( + num_calls=combined_num_calls, + timeouts=combined_timeouts, + tool_errors=combined_tool_errors, + tool_outputs=combined_tool_outputs, + tool_runtimes=combined_tool_runtimes, + tool_calleds=combined_tool_calleds, + ) + + combined_result = GenerationResult( + responses=combined_responses, + finish_reasons=combined_finish_reasons, + masks=combined_masks, + request_info=combined_request_info, + dataset_index=None, + epoch_number=results[0].epoch_number, + token_statistics=accumulated_stats, + logprobs=combined_logprobs, + ) + + if actor_manager is not None: + ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) + + batch = Batch( + queries=all_queries, + ground_truths=all_ground_truths, + datasets=all_datasets, + raw_queries=all_raw_queries, + decoded_responses=all_decoded_responses, + indices=None, + scores=all_scores, + ) + + combined_reward_metrics = combine_reward_metrics(all_reward_metrics) + percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 + + batch_stats = BatchStatistics( + prompt_lengths=prompt_lengths, + response_lengths=response_lengths, + filtered_prompts=total_filtered_prompts, + filtered_prompts_zero=filtered_prompt_zero, + filtered_prompts_solved=filtered_prompt_solved, + filtered_prompts_nonzero=filtered_prompt_nonzero, + percent_solved_mean=percent_solved_mean, + no_resampled_prompts=total_no_resampled, + total_prompts=len(results), + ) + logging.info( + f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" + ) + + return combined_result, batch, combined_reward_metrics, batch_stats From 09411639d99f0963940ae4f62364f0003c44f721 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 13:21:50 -0700 Subject: [PATCH 08/96] Cleaned up code. deleted unused code. 
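The data_preparation_thread deleted below is superseded by StreamingDataLoader._data_preparation_loop from the streaming_data_loader.py patch: each rank's background thread calls the relocated accumulate_inference_batches with its own queues and ShufflingIterator, packs the sequences, and puts the collated micro-batches on the loader's local queue. A rough per-step sketch of that call (illustration only; the names are the loader attributes set up during build(), not new APIs):

    result, batch, reward_metrics, batch_stats = accumulate_inference_batches(
        self.inference_results_Q,
        self.pending_queries_map,
        self.generation_config,
        num_prompts=self.rank_batch_size,
        model_dims=self.model_dims,
        tokenizer=self.tokenizer,
        reward_fn=self.reward_fn,
        replenish_prompts=True,
        iter_dataloader=self.iter_dataloader,
        prompt_dataset=self.dataset,
        param_prompt_Q=self.param_prompt_Q,
        training_step=training_step,
    )
    if isinstance(result, ShutdownSentinel):
        return  # generator signalled shutdown; the loop exits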
--- open_instruct/grpo_fast.py | 229 +------------------------------------ 1 file changed, 1 insertion(+), 228 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 2eee942ba..dd0bc33b4 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -47,7 +47,6 @@ # isort: on import asyncio -import json import logging import math import random @@ -108,7 +107,7 @@ push_folder_to_hub, ) from open_instruct.queue_types import ShutdownSentinel -from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences +from open_instruct.rl_utils import PackedSequences, Timer from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, @@ -1577,232 +1576,6 @@ def calculate_utilization_metrics( return utilization_metrics - - -def data_preparation_thread( - reward_fn: Callable, - inference_results_Q: ray_queue.Queue, # Ray queue - param_prompt_Q: ray_queue.Queue, - packed_sequences_Q: Queue, - pending_queries_map: dict, - args: Args, - tokenizer: PreTrainedTokenizer, - num_training_steps: int, - generation_config, - resume_training_step: int, - iter_dataloader: ShufflingIterator, - train_dataset: Dataset, - actor_manager=None, - model_dims: utils.ModelDims = None, -): - for training_step in range(resume_training_step, num_training_steps + 1): - # Streaming accumulation: collect results as they arrive - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - inference_results_Q, - pending_queries_map, - args, - generation_config, - num_prompts=args.num_unique_prompts_rollout, - model_dims=model_dims, - tokenizer=tokenizer, - reward_fn=reward_fn, - actor_manager=actor_manager, - active_sampling=args.active_sampling, - filter_zero_std_samples=args.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=args.no_resampling_pass_rate, - iter_dataloader=iter_dataloader, - prompt_dataset=train_dataset, - param_prompt_Q=param_prompt_Q, - training_step=training_step, - ) - if isinstance(result, ShutdownSentinel): - logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") - return - - getting_response_time = timer.duration - scores = np.array(batch.scores) - - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - if args.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif args.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}") - - if args.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 0: - logger.info( - f"[Truncated completions 
filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - with Timer("📦 [Data Preparation Thread] Packing sequences"): - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=args.pack_length, - pad_token_id=tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - ) - num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses) - # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value - # and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - # if we have less batches than world size, we need to pad out so each world is fine - # ideally, you should avoid this since its wasting computation. - if args.allow_world_padding: - with Timer("🤺 [Data Preparation Thread] Padding sequences for world size"): - shortfall = args.world_size - len(packed_sequences.query_responses) - if shortfall > 0: - logger.warning( - f"Padding {shortfall} sequences for world size. In future, you should adjust your compute this." - ) - # construct "dummy" sequences for padding out the world size - dummy_qr = torch.tensor([tokenizer.pad_token_id, tokenizer.eos_token_id], dtype=torch.long) - dummy_tool_mask = torch.zeros_like(dummy_qr) - dummy_attention = torch.tensor([1, 1], dtype=torch.long) - dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) - dummy_response_mask = torch.zeros_like(dummy_qr) - dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) - # pad out the world size - for _ in range(shortfall): - packed_sequences.query_responses.append(dummy_qr) - packed_sequences.tool_masks.append(dummy_tool_mask) - packed_sequences.attention_masks.append(dummy_attention) - packed_sequences.position_ids.append(dummy_position_ids) - packed_sequences.response_masks.append(dummy_response_mask) - packed_sequences.advantages.append(dummy_advantage) - - collated_data = prepare_collated_data_for_workers( - packed_sequences, args.world_size, args.per_device_train_batch_size, tokenizer.pad_token_id - ) - B = len(packed_sequences.query_responses) // args.world_size - - # Create a result package with metrics and data - if len(result.responses) == 0: - # Handle empty responses case - # in this case, we won't log metrics, so it should be fine. 
- metrics = {} - logger.warning(f"No responses in batch {training_step}.") - else: - real_num_responses = len(result.responses) - expected_num_responses = args.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout - - unsolved_num_responses = (scores < args.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) if np.all(scores == 0) else np.array(sequence_lengths[scores == args.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) if np.all(scores == args.max_possible_score) else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) - - batch_metrics = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} - - metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": None, - "val/total_reward_groups": real_num_responses / args.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - "val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array([len(item) > 0 for item in result.request_info.tool_errors]).mean(), - "val/good_outputs_rate": np.array(good_outputs).mean(), - "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - "time/getting_response": getting_response_time, - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - if args.save_traces: - traces = { - "scores": scores.tolist(), - "finish_reasons": result.finish_reasons, - "responses": result.responses, - "training_step": training_step, - **asdict(batch), # Unpack all batch fields - **reward_metrics, - } - os.makedirs(args.output_dir, exist_ok=True) - with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f: - json.dump(traces, f) - f.write("\n") - - # Put the packed sequences and metrics into the output queue - packed_sequences_Q.put( - { - "packed_sequences": packed_sequences, # for debugging purposes - "collated_data": collated_data, - "metrics": metrics, - "responses_count": len(result.responses), - "num_new_tokens": num_new_tokens, - "B": B, - "prompt_lengths": batch_stats.prompt_lengths, - "response_lengths": 
batch_stats.response_lengths, - "num_filtered_prompts": batch_stats.filtered_prompts, - } - ) - - def setup_runtime_variables(args: Args) -> Args: """Set up runtime variables for the experiment.""" args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" From 0c5a98d4c8c247ecd708f6f34133df0ca7304a6b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 13:57:17 -0700 Subject: [PATCH 09/96] deleted code --- open_instruct/grpo_fast.py | 39 -------------------------------------- 1 file changed, 39 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index dd0bc33b4..a0a0467c5 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1837,45 +1837,6 @@ def create_generation_configs(args: Args): return {"train": generation_config, "eval": eval_generation_config} -def load_data_from_packing_thread( - packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None] -) -> tuple[list[dict[str, list[torch.Tensor]]] | None, dict[str, Any], int, int, list[int] | None, list[int] | None]: - """Get the packed sequences with advantages from the packing thread.""" - with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer: - while True: - if stop_event.is_set(): - logger.warning("[Main Thread] Stop event detected while waiting for packed sequences") - return None, {}, num_total_tokens, 0, None, None, 0 - try: - packed_data = packed_sequences_Q.get(timeout=30.0) - break - except Empty: - health_check_fn() - logger.warning("[Main Thread] Timeout waiting for packed sequences. Retrying...") - data_thread_metrics = packed_data["metrics"] - B = packed_data["B"] - collated_data = packed_data["collated_data"] - num_step_tokens = packed_data["num_new_tokens"] - num_total_tokens += num_step_tokens - prompt_lengths = packed_data["prompt_lengths"] - response_lengths = packed_data["response_lengths"] - num_filtered_prompts = packed_data["num_filtered_prompts"] - - data_thread_metrics["time/trainer_idling"] = timer.duration - if B == 0: - logger.warning("[Main Thread] 🤡 After packing, there is not enough data to train") - return None, data_thread_metrics, num_total_tokens, 0, None, None, 0 - return ( - collated_data, - data_thread_metrics, - num_total_tokens, - num_step_tokens, - prompt_lengths, - response_lengths, - num_filtered_prompts, - ) - - def weight_sync_thread( args: Args, stop_event: threading.Event, From 3271f4f1e53dcdf25ccae27790d86458752e9a44 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 14:42:34 -0700 Subject: [PATCH 10/96] Updated code --- open_instruct/grpo_fast.py | 62 +++++---------------- open_instruct/streaming_data_loader.py | 74 ++++++++++++++++++-------- 2 files changed, 65 insertions(+), 71 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a0a0467c5..14c1e849a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -39,7 +39,6 @@ from open_instruct import streaming_data_loader, utils from open_instruct.streaming_data_loader import ( PendingQueriesMap, - ShufflingIterator, accumulate_inference_batches, add_prompt_to_generator, collate_fn, @@ -56,7 +55,7 @@ import time from argparse import Namespace from collections import defaultdict -from collections.abc import Callable, Iterator +from collections.abc import Callable from dataclasses import asdict, dataclass, field from datetime import timedelta from queue import Empty, Full, Queue @@ -244,8 +243,6 @@ class Args: """Enable 
immediate stopping of request processing when should_stop is set, allowing for quick pausing and resumption""" kl_estimator: Literal["kl1", "kl2", "kl3", "kl4"] = "kl3" """the KL estimator to use""" - pack_length: int = 512 - """the length of the pack (you should prob set to the max length of the model)""" masked_mean_axis: int | None = None """the axis to compute the mean of the masked values""" masked_mean_denominator: float | None = None @@ -257,19 +254,6 @@ class Args: """ ref_policy_update_freq: int | None = None """How many training steps to take before updating the reference policy.""" - advantage_normalization_type: Literal["standard", "centered"] = "standard" - """The type of advantage normalization to use. Standard normalization is the default: it subtracts the mean and - divides by the standard deviation. Centered normalization is the same but subtracts the mean only (e.g., used in - DR.GRPO https://arxiv.org/pdf/2503.20783).""" - mask_truncated_completions: bool = False - """Whether to mask out truncated completions. Also called overlong filtering, from DAPO (https://arxiv.org/abs/2503.14476).""" - - active_sampling: bool = False - """Whether to continue sampling responses until you get a full batch.""" - filter_zero_std_samples: bool = True - """Whether to filter out prompts with zero reward std (all samples have the same score).""" - no_resampling_pass_rate: float | None = None - """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" @@ -642,6 +626,7 @@ def __init__( local_rank: int, master_addr: str | None, master_port: int | None, + args: Args, data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, dataset: Dataset, reward_fn: Callable, @@ -656,7 +641,7 @@ def __init__( super().__init__(world_size, rank, local_rank, master_addr, master_port) self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id - self.num_mini_batches = data_loader_config.args.num_mini_batches + self.num_mini_batches = args.num_mini_batches self.dataloader = data_loader_config.build( dataset=dataset, reward_fn=reward_fn, @@ -667,6 +652,12 @@ def __init__( generation_config=generation_config, dp_rank=self.local_rank, fs_local_rank=self.local_rank, + num_training_steps=args.num_training_steps, + seed=args.seed, + async_steps=args.async_steps, + num_samples_per_prompt_rollout=args.num_samples_per_prompt_rollout, + per_device_train_batch_size=args.per_device_train_batch_size, + verbose=args.verbose, actor_manager=actor_manager, model_dims=model_dims, ) @@ -1426,6 +1417,7 @@ def __init__( ray_process_cls: RayProcess, num_gpus_per_node: list[int], single_gpu_mode: bool, + args: Args, data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, dataset: Dataset, reward_fn: Callable, @@ -1456,6 +1448,7 @@ def __init__( 0, None, None, + args, data_loader_config, dataset, reward_fn, @@ -1505,6 +1498,7 @@ def get_bundle_index(rank, num_gpus_per_node): 0, master_addr, master_port, + args, data_loader_config, dataset, reward_fn, @@ -1782,6 +1776,7 @@ def create_model_and_optimizer( PolicyTrainerRayProcess, args.num_learners_per_node, args.single_gpu_mode, + args=args, data_loader_config=data_loader_config, dataset=train_dataset, reward_fn=reward_fn, @@ -1916,7 +1911,6 @@ def one_training_step( prompt_lengths: list[int], response_lengths: list[int], actor_manager: ActorManager | None = None, - iter_dataloader: Iterator | 
None = None, ) -> None: """Train the model for one step.""" update_ref_policy_future = [] @@ -2299,7 +2293,6 @@ def run_training( policy_group, vllm_engines, generation_configs, - iter_dataloader, reward_fn, resume_training_step, episode, @@ -2350,19 +2343,6 @@ def health_check_fn(): enable=False, ) - # Send initial data to ensure we have a N-step offset. - for _ in range(args.async_steps * args.num_unique_prompts_rollout): - dataset_index = next(iter_dataloader) - add_prompt_to_generator( - train_dataset[dataset_index], - dataset_index, - iter_dataloader.epoch_number, - resume_training_step, - pending_queries_map, - param_prompt_Q, - generation_configs["train"], - is_eval=False, - ) if checkpoint_state and "num_total_tokens" in checkpoint_state: num_total_tokens = checkpoint_state["num_total_tokens"] logger.info(f"Restored num_total_tokens: {num_total_tokens}") @@ -2399,7 +2379,7 @@ def health_check_fn(): add_prompt_to_generator( eval_example, eval_index, - iter_dataloader.epoch_number, + 0, training_step, eval_pending_queries_map, param_prompt_Q, @@ -2440,7 +2420,6 @@ def health_check_fn(): prompt_lengths, response_lengths, actor_manager, - iter_dataloader, ) logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}") @@ -2454,17 +2433,12 @@ def health_check_fn(): and args.checkpoint_state_dir is not None ): with Timer("[Main Thread] 🗡️ Saving checkpoint state"): - # Save comprehensive client state including ShufflingIterator state client_state = { "training_step": training_step, "episode": episode, "num_total_tokens": num_total_tokens, } - # Save ShufflingIterator state - if iter_dataloader is not None: - client_state["shuffling_iterator_state"] = iter_dataloader.get_state() - ray_get_with_progress( [ policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state) @@ -2570,13 +2544,6 @@ def main( episode = checkpoint_state["episode"] logger.info(f"Restored episode count: {episode}") - train_dataset_idxs = np.arange(len(train_dataset)) - iter_dataloader = ShufflingIterator(train_dataset_idxs, 1, seed=args.seed) - - if checkpoint_state and "shuffling_iterator_state" in checkpoint_state: - iter_dataloader.set_state(checkpoint_state["shuffling_iterator_state"]) - logger.info("Restored ShufflingIterator state from checkpoint") - # Create additional queues (main queues already created above) packed_sequences_Q = Queue(maxsize=args.async_steps) eval_pending_queries_map = PendingQueriesMap() @@ -2595,7 +2562,6 @@ def main( policy_group, vllm_engines, generation_configs, - iter_dataloader, reward_fn, resume_training_step, episode, diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index da7804a43..333db2e7e 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -50,19 +50,13 @@ class StreamingDataLoaderConfig: work_dir: Path | str global_batch_size: int dp_world_size: int - num_training_steps: int - seed: int - async_steps: int - num_samples_per_prompt_rollout: int - active_sampling: bool - filter_zero_std_samples: bool - no_resampling_pass_rate: float - advantage_normalization_type: str - mask_truncated_completions: bool - pack_length: int max_possible_score: float - per_device_train_batch_size: int - verbose: bool + active_sampling: bool = False + filter_zero_std_samples: bool = True + no_resampling_pass_rate: float | None = None + advantage_normalization_type: str = "standard" + mask_truncated_completions: bool = False + pack_length: int = 512 def build( self, 
@@ -75,6 +69,12 @@ def build( generation_config: Any, dp_rank: int, fs_local_rank: int, + num_training_steps: int, + seed: int, + async_steps: int, + num_samples_per_prompt_rollout: int, + per_device_train_batch_size: int, + verbose: bool, actor_manager=None, model_dims: utils.ModelDims | None = None, ) -> "StreamingDataLoader": @@ -89,7 +89,12 @@ def build( generation_config=generation_config, work_dir=self.work_dir, global_batch_size=self.global_batch_size, - num_training_steps=self.num_training_steps, + num_training_steps=num_training_steps, + seed=seed, + async_steps=async_steps, + num_samples_per_prompt_rollout=num_samples_per_prompt_rollout, + per_device_train_batch_size=per_device_train_batch_size, + verbose=verbose, actor_manager=actor_manager, model_dims=model_dims, dp_world_size=self.dp_world_size, @@ -269,6 +274,11 @@ def __init__( work_dir: Path | str, global_batch_size: int, num_training_steps: int = 0, + seed: int, + async_steps: int, + num_samples_per_prompt_rollout: int, + per_device_train_batch_size: int, + verbose: bool, actor_manager=None, model_dims: utils.ModelDims = None, dp_world_size: int = 1, @@ -295,13 +305,18 @@ def __init__( self.actor_manager = actor_manager self.model_dims = model_dims + self.async_steps = async_steps + self.num_samples_per_prompt_rollout = num_samples_per_prompt_rollout + self.per_device_train_batch_size = per_device_train_batch_size + self.verbose = verbose + self.training_step = 0 self.current_epoch = 0 dataset_indices = np.arange(len(dataset)) - self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=config.seed + dp_rank) + self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=seed + dp_rank) - self.local_queue = StdQueue(maxsize=config.async_steps) + self.local_queue = StdQueue(maxsize=async_steps) self.background_thread = None self.shutdown_requested = False @@ -360,6 +375,19 @@ def _start_background_thread(self): self.background_thread.start() def _data_preparation_loop(self): + for _ in range(self.async_steps * self.global_batch_size // self.dp_world_size): + dataset_index = next(self.iter_dataloader) + add_prompt_to_generator( + self.dataset[dataset_index], + dataset_index, + self.iter_dataloader.epoch_number, + self.training_step, + self.pending_queries_map, + self.param_prompt_Q, + self.generation_config, + is_eval=False, + ) + for training_step in range(self.training_step, self.num_training_steps): if self.shutdown_requested: logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") @@ -383,7 +411,7 @@ def _data_preparation_loop(self): prompt_dataset=self.dataset, param_prompt_Q=self.param_prompt_Q, training_step=training_step, - verbose=self.config.verbose, + verbose=self.verbose, max_possible_score=self.config.max_possible_score, ) if isinstance(result, ShutdownSentinel): @@ -400,11 +428,11 @@ def _data_preparation_loop(self): and not result.request_info.tool_errors[i] for i in range(len(result.request_info.tool_outputs)) ] - scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) + scores_per_prompt = scores.reshape(-1, self.num_samples_per_prompt_rollout) mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.num_samples_per_prompt_rollout, axis=0) std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, 
self.config.num_samples_per_prompt_rollout, axis=0) + std_grouped_rewards = np.repeat(std_grouped_rewards, self.num_samples_per_prompt_rollout, axis=0) if self.config.advantage_normalization_type == "standard": advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) elif self.config.advantage_normalization_type == "centered": @@ -454,7 +482,7 @@ def _data_preparation_loop(self): logger.warning(f"No responses in batch {training_step}.") else: real_num_responses = len(result.responses) - expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size + expected_num_responses = self.num_samples_per_prompt_rollout * self.global_batch_size unsolved_num_responses = (scores < self.config.max_possible_score).sum() sequence_lengths = np.array([len(response) for response in result.responses]) @@ -481,7 +509,7 @@ def _data_preparation_loop(self): "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, "val/solve_rate_hist": None, - "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, + "val/total_reward_groups": real_num_responses / self.num_samples_per_prompt_rollout, "val/sequence_lengths": sequence_lengths.mean(), "val/sequence_lengths_min": sequence_lengths.min(), "val/sequence_lengths_max": sequence_lengths.max(), @@ -533,8 +561,8 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> collated_response_masks = [] collated_advantages = [] collated_vllm_logprobs = [] - for j in range(0, len(per_device_packed_query_responses), self.config.per_device_train_batch_size): - micro_range = b_inds[j : j + self.config.per_device_train_batch_size] + for j in range(0, len(per_device_packed_query_responses), self.per_device_train_batch_size): + micro_range = b_inds[j : j + self.per_device_train_batch_size] collated_query_responses.append( collate_fn( [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True From 504d13fe0f499acd55a8b0444b047816692be41b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 15:09:31 -0700 Subject: [PATCH 11/96] Fix StreamingDataLoaderConfig to pass runtime values as parameters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moved work_dir, global_batch_size, dp_world_size, and max_possible_score from StreamingDataLoaderConfig fields to build() method parameters. These values are computed at runtime from Args and should not be CLI arguments. 
- work_dir comes from args.output_dir - global_batch_size comes from args.num_unique_prompts_rollout - dp_world_size comes from the actual world_size (number of PolicyTrainerRayProcess instances) - max_possible_score is computed in Args.__post_init__ 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 4 ++++ open_instruct/streaming_data_loader.py | 17 ++++++++++------- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 14c1e849a..626bb68c0 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -658,6 +658,10 @@ def __init__( num_samples_per_prompt_rollout=args.num_samples_per_prompt_rollout, per_device_train_batch_size=args.per_device_train_batch_size, verbose=args.verbose, + work_dir=args.output_dir, + global_batch_size=args.num_unique_prompts_rollout, + dp_world_size=world_size, + max_possible_score=args.max_possible_score, actor_manager=actor_manager, model_dims=model_dims, ) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 333db2e7e..c9bb0e8e1 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -47,10 +47,6 @@ @dataclass class StreamingDataLoaderConfig: - work_dir: Path | str - global_batch_size: int - dp_world_size: int - max_possible_score: float active_sampling: bool = False filter_zero_std_samples: bool = True no_resampling_pass_rate: float | None = None @@ -75,6 +71,10 @@ def build( num_samples_per_prompt_rollout: int, per_device_train_batch_size: int, verbose: bool, + work_dir: Path | str, + global_batch_size: int, + dp_world_size: int, + max_possible_score: float, actor_manager=None, model_dims: utils.ModelDims | None = None, ) -> "StreamingDataLoader": @@ -87,17 +87,18 @@ def build( tokenizer=tokenizer, config=self, generation_config=generation_config, - work_dir=self.work_dir, - global_batch_size=self.global_batch_size, + work_dir=work_dir, + global_batch_size=global_batch_size, num_training_steps=num_training_steps, seed=seed, async_steps=async_steps, num_samples_per_prompt_rollout=num_samples_per_prompt_rollout, per_device_train_batch_size=per_device_train_batch_size, verbose=verbose, + max_possible_score=max_possible_score, actor_manager=actor_manager, model_dims=model_dims, - dp_world_size=self.dp_world_size, + dp_world_size=dp_world_size, dp_rank=dp_rank, fs_local_rank=fs_local_rank, ) @@ -279,6 +280,7 @@ def __init__( num_samples_per_prompt_rollout: int, per_device_train_batch_size: int, verbose: bool, + max_possible_score: float, actor_manager=None, model_dims: utils.ModelDims = None, dp_world_size: int = 1, @@ -300,6 +302,7 @@ def __init__( self.pending_queries_map = pending_queries_map self.tokenizer = tokenizer self.config = config + self.config.max_possible_score = max_possible_score self.generation_config = generation_config self.num_training_steps = num_training_steps self.actor_manager = actor_manager From 2fa3a111dcf6b38a3e65a20d607fc884ff1f7764 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 20:19:00 -0700 Subject: [PATCH 12/96] Move pack_length validation to StreamingDataLoaderConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moved max_prompt_token_length and response_length to StreamingDataLoaderConfig and added __post_init__ to validate pack_length assertion there instead of in Args. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 3 --- open_instruct/streaming_data_loader.py | 7 +++++++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 626bb68c0..3bde7b660 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -458,9 +458,6 @@ def __post_init__(self): if self.inference_batch_size is None: total_prompts = self.num_samples_per_prompt_rollout * self.num_unique_prompts_rollout self.inference_batch_size = max(1, math.ceil(total_prompts / self.vllm_num_engines)) - assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( - "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" - ) if self.checkpoint_state_freq > 0 and self.checkpoint_state_dir is None: raise ValueError("`checkpoint_state_dir` must be provided if `checkpoint_state_freq` is greater than 0!") if self.checkpoint_state_dir is not None and self.checkpoint_state_freq == -1: diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index c9bb0e8e1..49762440a 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -47,6 +47,8 @@ @dataclass class StreamingDataLoaderConfig: + max_prompt_token_length: int = 256 + response_length: int = 256 active_sampling: bool = False filter_zero_std_samples: bool = True no_resampling_pass_rate: float | None = None @@ -54,6 +56,11 @@ class StreamingDataLoaderConfig: mask_truncated_completions: bool = False pack_length: int = 512 + def __post_init__(self): + assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( + "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" + ) + def build( self, dataset: Dataset, From 7e049b375e33fdb5221ed39a0d519cd16410e452 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 20:34:32 -0700 Subject: [PATCH 13/96] Move max_prompt_token_length and response_length to StreamingDataLoaderConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removed these fields from Args to avoid argparse conflicts and updated all references to use streaming_config instead. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3bde7b660..6e1b713cf 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -160,8 +160,6 @@ class Args: """Whether to skip the cache.""" shuffle_eval_dataset: bool = False """Whether to shuffle the evaluation dataset.""" - max_prompt_token_length: int = 256 - """The maximum prompt token length to use for the dataset""" system_prompt_override_file: str | None = None """Path to a text file containing a system prompt to override the dataset's system prompts""" @@ -210,8 +208,6 @@ class Args: """Timeout for inference/training backends in minutes. 
Default is 2 hours (120 min).""" # Generation - response_length: int = 256 - """the length of the response""" temperature: float = 0.7 """the sampling temperature""" num_unique_prompts_rollout: int = 16 @@ -1625,7 +1621,9 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod return beaker_config, wandb_url -def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer): +def setup_datasets( + args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer, streaming_config: streaming_data_loader.StreamingDataLoaderConfig +): """Set up training and evaluation datasets.""" system_prompt_override = None if args.system_prompt_override_file is not None: @@ -1636,7 +1634,7 @@ def setup_datasets(args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokeniz transform_fn_args = [ {"system_prompt_override": system_prompt_override}, - {"max_prompt_token_length": args.max_prompt_token_length}, + {"max_prompt_token_length": streaming_config.max_prompt_token_length}, ] train_dataset = get_cached_dataset_tulu( dataset_mixer_list=args.dataset_mixer_list, @@ -1699,7 +1697,7 @@ def create_model_and_optimizer( ray_get_with_progress([pg.ready()], desc="Waiting for placement group") # Set up tools - max_len = args.max_prompt_token_length + args.response_length + max_len = data_loader_config.max_prompt_token_length + data_loader_config.response_length tool_objects = {} if args.tools: for tool in args.tools: @@ -1809,12 +1807,12 @@ def create_model_and_optimizer( return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims -def create_generation_configs(args: Args): +def create_generation_configs(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig): """Create generation configs for training and evaluation.""" generation_config = vllm.SamplingParams( temperature=args.temperature, top_p=args.vllm_top_p, # prevent rare out-of-vocab tokens with qwen - max_tokens=args.response_length, + max_tokens=streaming_config.response_length, include_stop_str_in_output=True, skip_special_tokens=False, n=args.num_samples_per_prompt_rollout, @@ -2486,7 +2484,7 @@ def main( beaker_config, wandb_url = setup_experiment_tracking(args, tc, model_config) - train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer) + train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer, streaming_config) if len(train_dataset) < (needed := max(args.async_steps, 1) * args.num_unique_prompts_rollout): raise ValueError( @@ -2513,7 +2511,7 @@ def main( # Create dataloader dependencies before model creation pending_queries_map = PendingQueriesMap() reward_fn = make_reward_fn(args) - generation_configs = create_generation_configs(args) + generation_configs = create_generation_configs(args, streaming_config) (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( create_model_and_optimizer( From c2b96fabe3ac3f35c812691120ec565571c599dd Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 17 Nov 2025 21:36:47 -0700 Subject: [PATCH 14/96] Move validation to StreamingDataLoaderConfig.__post_init__ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added async_steps and num_samples_per_prompt_rollout fields to StreamingDataLoaderConfig and moved the validation logic there. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 16 ---------------- open_instruct/streaming_data_loader.py | 18 ++++++++++++++++++ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 6e1b713cf..7e5c1adea 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -495,22 +495,6 @@ def __post_init__(self): if self.apply_r1_style_format_reward and self.additive_format_reward: self.max_possible_score += self.r1_style_format_reward - if self.active_sampling: - assert self.async_steps > 1, ( - "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " - "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " - "prompt will cause the trainer to stall waiting for more data . " - ) - assert self.filter_zero_std_samples, ( - "filter_zero_std_samples must be True when active_sampling is True. " - "Active sampling requires filtering to work correctly." - ) - if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: - raise ValueError( - "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " - "as the reward standard deviation will always be 0, causing all samples to be filtered." - ) - def masked_mean( values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 49762440a..f6920f887 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -49,6 +49,8 @@ class StreamingDataLoaderConfig: max_prompt_token_length: int = 256 response_length: int = 256 + async_steps: int = 1 + num_samples_per_prompt_rollout: int = 4 active_sampling: bool = False filter_zero_std_samples: bool = True no_resampling_pass_rate: float | None = None @@ -61,6 +63,22 @@ def __post_init__(self): "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" ) + if self.active_sampling: + assert self.async_steps > 1, ( + "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " + "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " + "prompt will cause the trainer to stall waiting for more data . " + ) + assert self.filter_zero_std_samples, ( + "filter_zero_std_samples must be True when active_sampling is True. " + "Active sampling requires filtering to work correctly." + ) + if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: + raise ValueError( + "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " + "as the reward standard deviation will always be 0, causing all samples to be filtered." + ) + def build( self, dataset: Dataset, From 56823d91cecf6ad416fba5282d4f31ffc8d9050b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 11:23:40 -0700 Subject: [PATCH 15/96] Move async_steps and num_samples_per_prompt_rollout to StreamingDataLoaderConfig MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactored to access these values directly from config instead of passing them as parameters. Updated function signatures to pass streaming_config where needed. 
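In sketch form, the signature change looks like the following (a minimal sketch with SimpleNamespace standing in for the real Args and StreamingDataLoaderConfig objects, and made-up numbers; the real function sets many more fields):

from types import SimpleNamespace


def setup_runtime_variables(args, streaming_config):
    # The per-prompt sample count now comes from the dataloader config,
    # not from Args.
    args.num_training_steps = args.total_episodes // (
        args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout
    )
    return args


args = SimpleNamespace(total_episodes=1024, num_unique_prompts_rollout=16, num_training_steps=None)
cfg = SimpleNamespace(num_samples_per_prompt_rollout=4)
print(setup_runtime_variables(args, cfg).num_training_steps)  # 1024 // (16 * 4) == 16
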
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 36 ++++++++++++-------------- open_instruct/streaming_data_loader.py | 26 +++++++------------ 2 files changed, 26 insertions(+), 36 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7e5c1adea..1babead29 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -212,8 +212,6 @@ class Args: """the sampling temperature""" num_unique_prompts_rollout: int = 16 """The number of unique prompts during rollout""" - num_samples_per_prompt_rollout: int = 4 - """the number of samples to generate per prompt during rollout, useful for easy-star""" stop_strings: list[str] | None = None """List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.""" @@ -221,8 +219,6 @@ class Args: """Whether to use fp8 kv cache. This is useful for larger models or olmo.""" # Algorithm - async_steps: int = 1 - """Number of steps ahead to generate responses. Set to 0 to make the code synchronous. Values greater than 0 learn from a policy up to async_steps old like Cleanba (https://arxiv.org/abs/2310.00036)""" num_epochs: int = 1 """the number of epochs to train""" num_mini_batches: int = 1 @@ -631,8 +627,6 @@ def __init__( fs_local_rank=self.local_rank, num_training_steps=args.num_training_steps, seed=args.seed, - async_steps=args.async_steps, - num_samples_per_prompt_rollout=args.num_samples_per_prompt_rollout, per_device_train_batch_size=args.per_device_train_batch_size, verbose=args.verbose, work_dir=args.output_dir, @@ -1551,7 +1545,7 @@ def calculate_utilization_metrics( return utilization_metrics -def setup_runtime_variables(args: Args) -> Args: +def setup_runtime_variables(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig) -> Args: """Set up runtime variables for the experiment.""" args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" args.output_dir = os.path.join(args.output_dir, args.run_name) @@ -1560,7 +1554,7 @@ def setup_runtime_variables(args: Args) -> Args: args.dataset_local_cache_dir = "/weka/oe-adapt-default/allennlp/deletable_open_instruct_dataset_cache" args.world_size = sum(args.num_learners_per_node) args.num_training_steps = args.total_episodes // ( - args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout ) args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job() if args.push_to_hub: @@ -1779,7 +1773,7 @@ def create_model_and_optimizer( results, _ = ray_get_with_progress(inits, desc="Initializing models") resume_training_step = results[0] + 1 - episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * data_loader_config.num_samples_per_prompt_rollout logger.info("======== ✅ all models initialized =========") ray_get_with_progress( @@ -1799,7 +1793,7 @@ def create_generation_configs(args: Args, streaming_config: streaming_data_loade max_tokens=streaming_config.response_length, include_stop_str_in_output=True, skip_special_tokens=False, - n=args.num_samples_per_prompt_rollout, + n=streaming_config.num_samples_per_prompt_rollout, stop=args.stop_strings, seed=args.seed, logprobs=1, # Enable logprobs to compare with local calculations @@ -1878,6 +1872,7 @@ def 
weight_sync_thread( def one_training_step( args: Args, + streaming_config: streaming_data_loader.StreamingDataLoaderConfig, policy_group: ModelGroup, tokenizer: PreTrainedTokenizer, data_thread_metrics: dict[str, Any], @@ -1949,7 +1944,7 @@ def one_training_step( prompt_lengths=prompt_lengths, response_lengths=response_lengths, total_generation_time=total_generation_time, - samples_per_prompt=args.num_samples_per_prompt_rollout, + samples_per_prompt=streaming_config.num_samples_per_prompt_rollout, num_engines=args.vllm_num_engines, num_gpus_per_engine=args.vllm_tensor_parallel_size, training_time=train_timer.duration, @@ -1962,7 +1957,7 @@ def one_training_step( "training_step": training_step, "val/num_total_tokens": num_total_tokens, "val/num_step_tokens": num_step_tokens, - "epoch": episode / args.num_samples_per_prompt_rollout / len(train_dataset), + "epoch": episode / streaming_config.num_samples_per_prompt_rollout / len(train_dataset), "learner_tokens_per_second_overall": num_total_tokens / total_training_time, "learner_tokens_per_second_step": num_step_tokens / step_time, "time/total": step_time, @@ -2270,6 +2265,7 @@ def cleanup_training_resources( def run_training( args, + streaming_config, tokenizer, train_dataset, eval_dataset, @@ -2370,7 +2366,7 @@ def health_check_fn(): is_eval=True, ) - episode += args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout + episode += args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout data_thread_metrics = {} for metrics_Q in [generate_metrics_Q, weight_sync_metrics_Q]: @@ -2387,6 +2383,7 @@ def health_check_fn(): one_training_step( args, + streaming_config, policy_group, tokenizer, data_thread_metrics, @@ -2459,7 +2456,7 @@ def main( streaming_config: streaming_data_loader.StreamingDataLoaderConfig, ): tokenizer = make_tokenizer(tc, model_config) - args = setup_runtime_variables(args) + args = setup_runtime_variables(args, streaming_config) if args.verbose: logging.getLogger().setLevel(logging.DEBUG) @@ -2470,7 +2467,7 @@ def main( train_dataset, eval_dataset = setup_datasets(args, tc, tokenizer, streaming_config) - if len(train_dataset) < (needed := max(args.async_steps, 1) * args.num_unique_prompts_rollout): + if len(train_dataset) < (needed := max(streaming_config.async_steps, 1) * args.num_unique_prompts_rollout): raise ValueError( f"Train dataset is too small! Is {len(train_dataset)} prompts, but {needed} are needed to have enough prompts for bsz and prefill. Try reducing async_steps or num_unique_prompts_rollout, or increasing the dataset size." ) @@ -2486,7 +2483,7 @@ def main( # Create Ray queues. # Since we now send/receive individual prompts, queue size should accommodate # all prompts from async_steps + 1 training steps - queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout + queue_size = (streaming_config.async_steps + 1) * args.num_unique_prompts_rollout inference_results_Q = ray_queue.Queue(maxsize=queue_size) param_prompt_Q = ray_queue.Queue(maxsize=queue_size) # We don't care if we ever hit the max, so we let the queue be unbounded. 
@@ -2528,10 +2525,10 @@ def main( logger.info(f"Restored episode count: {episode}") # Create additional queues (main queues already created above) - packed_sequences_Q = Queue(maxsize=args.async_steps) + packed_sequences_Q = Queue(maxsize=streaming_config.async_steps) eval_pending_queries_map = PendingQueriesMap() - generate_metrics_Q = Queue(maxsize=args.async_steps) - weight_sync_metrics_Q = Queue(maxsize=args.async_steps) + generate_metrics_Q = Queue(maxsize=streaming_config.async_steps) + weight_sync_metrics_Q = Queue(maxsize=streaming_config.async_steps) stop_event = threading.Event() executor = futures.ThreadPoolExecutor(max_workers=3, thread_name_prefix="grpo") @@ -2539,6 +2536,7 @@ def main( try: episode = run_training( args, + streaming_config, tokenizer, train_dataset, eval_dataset, diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index f6920f887..25c54f882 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -64,7 +64,7 @@ def __post_init__(self): ) if self.active_sampling: - assert self.async_steps > 1, ( + assert self.config.async_steps > 1, ( "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " "prompt will cause the trainer to stall waiting for more data . " @@ -73,7 +73,7 @@ def __post_init__(self): "filter_zero_std_samples must be True when active_sampling is True. " "Active sampling requires filtering to work correctly." ) - if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: + if self.config.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: raise ValueError( "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " "as the reward standard deviation will always be 0, causing all samples to be filtered." 
@@ -92,8 +92,6 @@ def build( fs_local_rank: int, num_training_steps: int, seed: int, - async_steps: int, - num_samples_per_prompt_rollout: int, per_device_train_batch_size: int, verbose: bool, work_dir: Path | str, @@ -116,8 +114,6 @@ def build( global_batch_size=global_batch_size, num_training_steps=num_training_steps, seed=seed, - async_steps=async_steps, - num_samples_per_prompt_rollout=num_samples_per_prompt_rollout, per_device_train_batch_size=per_device_train_batch_size, verbose=verbose, max_possible_score=max_possible_score, @@ -301,8 +297,6 @@ def __init__( global_batch_size: int, num_training_steps: int = 0, seed: int, - async_steps: int, - num_samples_per_prompt_rollout: int, per_device_train_batch_size: int, verbose: bool, max_possible_score: float, @@ -333,8 +327,6 @@ def __init__( self.actor_manager = actor_manager self.model_dims = model_dims - self.async_steps = async_steps - self.num_samples_per_prompt_rollout = num_samples_per_prompt_rollout self.per_device_train_batch_size = per_device_train_batch_size self.verbose = verbose @@ -344,7 +336,7 @@ def __init__( dataset_indices = np.arange(len(dataset)) self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=seed + dp_rank) - self.local_queue = StdQueue(maxsize=async_steps) + self.local_queue = StdQueue(maxsize=config.async_steps) self.background_thread = None self.shutdown_requested = False @@ -403,7 +395,7 @@ def _start_background_thread(self): self.background_thread.start() def _data_preparation_loop(self): - for _ in range(self.async_steps * self.global_batch_size // self.dp_world_size): + for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): dataset_index = next(self.iter_dataloader) add_prompt_to_generator( self.dataset[dataset_index], @@ -456,11 +448,11 @@ def _data_preparation_loop(self): and not result.request_info.tool_errors[i] for i in range(len(result.request_info.tool_outputs)) ] - scores_per_prompt = scores.reshape(-1, self.num_samples_per_prompt_rollout) + scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.num_samples_per_prompt_rollout, axis=0) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, self.num_samples_per_prompt_rollout, axis=0) + std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) if self.config.advantage_normalization_type == "standard": advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) elif self.config.advantage_normalization_type == "centered": @@ -510,7 +502,7 @@ def _data_preparation_loop(self): logger.warning(f"No responses in batch {training_step}.") else: real_num_responses = len(result.responses) - expected_num_responses = self.num_samples_per_prompt_rollout * self.global_batch_size + expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size unsolved_num_responses = (scores < self.config.max_possible_score).sum() sequence_lengths = np.array([len(response) for response in result.responses]) @@ -537,7 +529,7 @@ def _data_preparation_loop(self): "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, "val/solve_rate_hist": None, - 
"val/total_reward_groups": real_num_responses / self.num_samples_per_prompt_rollout, + "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, "val/sequence_lengths": sequence_lengths.mean(), "val/sequence_lengths_min": sequence_lengths.min(), "val/sequence_lengths_max": sequence_lengths.max(), From 6e562d83338537d2931ceb762480af4f439a0a55 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 11:25:27 -0700 Subject: [PATCH 16/96] Fix validation references in __post_init__ MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Move num_samples_per_prompt_rollout validation to StreamingDataLoaderConfig and fix references to use self instead of self.config. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 3 --- open_instruct/streaming_data_loader.py | 7 +++++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1babead29..7aebdb5a8 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -431,9 +431,6 @@ def __post_init__(self): assert self.masked_mean_denominator > 0, ( f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" ) - assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" - if self.num_samples_per_prompt_rollout == 1: - logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") assert self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty, ( "At least one reward must be applied!" ) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 25c54f882..82b23d1bb 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -62,9 +62,12 @@ def __post_init__(self): assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" ) + assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" + if self.num_samples_per_prompt_rollout == 1: + logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") if self.active_sampling: - assert self.config.async_steps > 1, ( + assert self.async_steps > 1, ( "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " "prompt will cause the trainer to stall waiting for more data . " @@ -73,7 +76,7 @@ def __post_init__(self): "filter_zero_std_samples must be True when active_sampling is True. " "Active sampling requires filtering to work correctly." ) - if self.config.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: + if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: raise ValueError( "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " "as the reward standard deviation will always be 0, causing all samples to be filtered." 
From 77ed3dda9f6151aed0a4ab0920cdf51f42239326 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 11:27:11 -0700 Subject: [PATCH 17/96] Move inference_batch_size computation to setup_runtime_variables MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Moved the inference_batch_size default calculation from Args.__post_init__ to setup_runtime_variables where we have access to streaming_config. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7aebdb5a8..bb3565685 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -444,9 +444,6 @@ def __post_init__(self): # Initialize stop_strings if None if self.stop_strings is None: self.stop_strings = [] - if self.inference_batch_size is None: - total_prompts = self.num_samples_per_prompt_rollout * self.num_unique_prompts_rollout - self.inference_batch_size = max(1, math.ceil(total_prompts / self.vllm_num_engines)) if self.checkpoint_state_freq > 0 and self.checkpoint_state_dir is None: raise ValueError("`checkpoint_state_dir` must be provided if `checkpoint_state_freq` is greater than 0!") if self.checkpoint_state_dir is not None and self.checkpoint_state_freq == -1: @@ -1553,6 +1550,9 @@ def setup_runtime_variables(args: Args, streaming_config: streaming_data_loader. args.num_training_steps = args.total_episodes // ( args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout ) + if args.inference_batch_size is None: + total_prompts = streaming_config.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout + args.inference_batch_size = max(1, math.ceil(total_prompts / args.vllm_num_engines)) args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job() if args.push_to_hub: if args.hf_repo_id is None: # auto-generate one From 3544474729aec10897e0aff73624ca7d7b358f57 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 11:28:13 -0700 Subject: [PATCH 18/96] Format code with ruff MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index bb3565685..16c6b5266 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1597,7 +1597,10 @@ def setup_experiment_tracking(args: Args, tc: TokenizerConfig, model_config: Mod def setup_datasets( - args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer, streaming_config: streaming_data_loader.StreamingDataLoaderConfig + args: Args, + tc: TokenizerConfig, + tokenizer: PreTrainedTokenizer, + streaming_config: streaming_data_loader.StreamingDataLoaderConfig, ): """Set up training and evaluation datasets.""" system_prompt_override = None @@ -1770,7 +1773,11 @@ def create_model_and_optimizer( results, _ = ray_get_with_progress(inits, desc="Initializing models") resume_training_step = results[0] + 1 - episode = (resume_training_step - 1) * args.num_unique_prompts_rollout * data_loader_config.num_samples_per_prompt_rollout + episode = ( + (resume_training_step - 1) + * args.num_unique_prompts_rollout + * 
data_loader_config.num_samples_per_prompt_rollout + ) logger.info("======== ✅ all models initialized =========") ray_get_with_progress( From 5273b09e92f9060847270c558847d14ecfc0f0c8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 12:10:41 -0700 Subject: [PATCH 19/96] fixed code. --- open_instruct/grpo_fast.py | 12 ------------ open_instruct/streaming_data_loader.py | 5 +---- open_instruct/test_grpo_fast.py | 3 --- 3 files changed, 1 insertion(+), 19 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 16c6b5266..72be9818d 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -599,7 +599,6 @@ def __init__( reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, - pending_queries_map: dict, tokenizer: PreTrainedTokenizer, generation_config, actor_manager, @@ -614,7 +613,6 @@ def __init__( reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, - pending_queries_map=pending_queries_map, tokenizer=tokenizer, generation_config=generation_config, dp_rank=self.local_rank, @@ -1392,7 +1390,6 @@ def __init__( reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, - pending_queries_map: dict, tokenizer: PreTrainedTokenizer, generation_config, actor_manager, @@ -1423,7 +1420,6 @@ def __init__( reward_fn, inference_results_Q, param_prompt_Q, - pending_queries_map, tokenizer, generation_config, actor_manager, @@ -1473,7 +1469,6 @@ def get_bundle_index(rank, num_gpus_per_node): reward_fn, inference_results_Q, param_prompt_Q, - pending_queries_map, tokenizer, generation_config, actor_manager, @@ -1665,7 +1660,6 @@ def create_model_and_optimizer( data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, train_dataset: Dataset, reward_fn: Callable, - pending_queries_map: dict, generation_config, ) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int, ActorManager, utils.ModelDims]: """Create the model, optimizer, and vLLM engines.""" @@ -1759,7 +1753,6 @@ def create_model_and_optimizer( reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, - pending_queries_map=pending_queries_map, tokenizer=tokenizer, generation_config=generation_config, actor_manager=actor_manager, @@ -2287,7 +2280,6 @@ def run_training( param_prompt_Q, evaluation_inference_results_Q, packed_sequences_Q, - pending_queries_map, eval_pending_queries_map, generate_metrics_Q, weight_sync_metrics_Q, @@ -2493,8 +2485,6 @@ def main( # We don't care if we ever hit the max, so we let the queue be unbounded. 
evaluation_inference_results_Q = ray_queue.Queue() - # Create dataloader dependencies before model creation - pending_queries_map = PendingQueriesMap() reward_fn = make_reward_fn(args) generation_configs = create_generation_configs(args, streaming_config) @@ -2512,7 +2502,6 @@ def main( streaming_config, train_dataset, reward_fn, - pending_queries_map, generation_configs["train"], ) ) @@ -2558,7 +2547,6 @@ def main( param_prompt_Q, evaluation_inference_results_Q, packed_sequences_Q, - pending_queries_map, eval_pending_queries_map, generate_metrics_Q, weight_sync_metrics_Q, diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 82b23d1bb..753c297c8 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -88,7 +88,6 @@ def build( reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, - pending_queries_map: dict, tokenizer: PreTrainedTokenizer, generation_config: Any, dp_rank: int, @@ -109,7 +108,6 @@ def build( reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, - pending_queries_map=pending_queries_map, tokenizer=tokenizer, config=self, generation_config=generation_config, @@ -292,7 +290,6 @@ def __init__( reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, - pending_queries_map: dict, tokenizer: PreTrainedTokenizer, config: StreamingDataLoaderConfig, generation_config: Any, @@ -321,7 +318,7 @@ def __init__( self.reward_fn = reward_fn self.inference_results_Q = inference_results_Q self.param_prompt_Q = param_prompt_Q - self.pending_queries_map = pending_queries_map + self.pending_queries_map = PendingQueriesMap() self.tokenizer = tokenizer self.config = config self.config.max_possible_score = max_possible_score diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index d39e9208d..149ce51ad 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -710,8 +710,6 @@ def test_accumulate_waits_for_all_engines(self): mock_result = self.create_mock_result(i, 1) inference_results_Q.put(mock_result) - mock_args = self.create_mock_args(num_engines) - completed = threading.Event() def run_accumulate(): @@ -724,7 +722,6 @@ def run_accumulate(): grpo_fast.accumulate_inference_batches( inference_results_Q, pending_queries_map, - mock_args, generation_config=mock_generation_config, num_prompts=num_prompts, model_dims=mock_model_dims, From f71f3d2e268de3b7b75894f5223c34ee02feb9f0 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 12:27:19 -0700 Subject: [PATCH 20/96] fixed bug in metrics --- open_instruct/grpo_fast.py | 5 ++++- open_instruct/streaming_data_loader.py | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 72be9818d..a8b640fc7 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -933,6 +933,7 @@ def update_ref_policy(self): def step(self): batch_data = next(self.dataloader) + batch_metrics = batch_data["metrics"] collated_query_responses = batch_data["collated_query_responses"] collated_tool_masks = batch_data["collated_tool_masks"] collated_attention_masks = batch_data["collated_attention_masks"] @@ -1245,6 +1246,8 @@ def step(self): if args.record_entropy: self.local_metrics["policy/entropy_avg"] = entropy_stats.mean() self.local_metrics["lr"] = self.scheduler.get_last_lr()[0] + for key, value in batch_metrics.items(): + 
self.local_metrics[key] = value return self.local_metrics.get_metrics_list() def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: @@ -1934,7 +1937,7 @@ def one_training_step( step_time = time.perf_counter() - start_time total_training_time = time.perf_counter() - training_start_time - total_generation_time = data_thread_metrics["time/getting_response"] + total_generation_time = average_metrics["time/getting_response"] utilization_metrics = calculate_utilization_metrics( model_dims=model_dims, diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 753c297c8..342e3531c 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -562,6 +562,7 @@ def _data_preparation_loop(self): total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time + collated_data["metrics"] = metrics self.local_queue.put(collated_data) def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: From fe2aec07567d5387dfc9ec28791aba988ad329c0 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 12:36:44 -0700 Subject: [PATCH 21/96] Fixed error --- open_instruct/grpo_fast.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a8b640fc7..edfb251fa 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1247,7 +1247,8 @@ def step(self): self.local_metrics["policy/entropy_avg"] = entropy_stats.mean() self.local_metrics["lr"] = self.scheduler.get_last_lr()[0] for key, value in batch_metrics.items(): - self.local_metrics[key] = value + if value is not None: + self.local_metrics[key] = value return self.local_metrics.get_metrics_list() def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: From cb9ae424eda1deebb5c5458eb78e6ac7fe414635 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 12:53:26 -0700 Subject: [PATCH 22/96] added percent_solved_hist --- open_instruct/streaming_data_loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 342e3531c..01498dbf3 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -528,7 +528,7 @@ def _data_preparation_loop(self): "real_batch_size_ratio": real_num_responses / expected_num_responses, "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": None, + "val/solve_rate_hist": batch_stats.percent_solved_hist, "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, "val/sequence_lengths": sequence_lengths.mean(), "val/sequence_lengths_min": sequence_lengths.min(), @@ -636,6 +636,7 @@ class BatchStatistics: filtered_prompts_solved: int filtered_prompts_nonzero: int percent_solved_mean: float + percent_solved_hist: np.ndarray no_resampled_prompts: int total_prompts: int @@ -958,6 +959,7 @@ def accumulate_inference_batches( filtered_prompts_solved=filtered_prompt_solved, filtered_prompts_nonzero=filtered_prompt_nonzero, percent_solved_mean=percent_solved_mean, + 
percent_solved_hist=np.array(all_percent_solved), no_resampled_prompts=total_no_resampled, total_prompts=len(results), ) From 4059fe53912977568689dc8e9e4de552a6294690 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 13:15:19 -0700 Subject: [PATCH 23/96] Fixed metrics --- open_instruct/grpo_fast.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index edfb251fa..bdf7b301e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1246,10 +1246,15 @@ def step(self): if args.record_entropy: self.local_metrics["policy/entropy_avg"] = entropy_stats.mean() self.local_metrics["lr"] = self.scheduler.get_last_lr()[0] + array_metrics = {} for key, value in batch_metrics.items(): - if value is not None: + if value is None: + continue + if isinstance(value, np.ndarray): + array_metrics[key] = value + else: self.local_metrics[key] = value - return self.local_metrics.get_metrics_list() + return self.local_metrics.get_metrics_list(), array_metrics def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: args = self.args @@ -1894,10 +1899,12 @@ def one_training_step( """Train the model for one step.""" update_ref_policy_future = [] with Timer("[Main Thread] 🗡️ Training") as train_timer: - metrics_list, _ = ray_get_with_progress( + results, _ = ray_get_with_progress( [policy_group.models[i].step.remote() for i in range(args.world_size)], desc=f"Running training step {training_step}", ) + metrics_list = [r[0] for r in results] + array_metrics_list = [r[1] for r in results] if ( args.ref_policy_update_freq is not None and training_step % args.ref_policy_update_freq == 0 @@ -1935,6 +1942,8 @@ def one_training_step( ray.get(actor_manager.report_training_step_time.remote(train_timer.duration)) average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]} + for key, value in array_metrics_list[0].items(): + average_metrics[key] = value step_time = time.perf_counter() - start_time total_training_time = time.perf_counter() - training_start_time From 346995c6c27ca8a0e4dfa6da9d6313484c80dddd Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 13:37:18 -0700 Subject: [PATCH 24/96] fixed metrics --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index bdf7b301e..07b5b2f71 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -806,7 +806,7 @@ def load(self, path: str, map_location=None): else: self.ref_policy.load_state_dict(state_dict) logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") - self.local_metrics = utils.MetricsTracker(device=self.device) + self.local_metrics = utils.MetricsTracker(max_metrics=64, device=self.device) return optimization_steps_done def forward( From 144a86fba89b24bd878a98307ba59f2a284abe92 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 15:29:34 -0700 Subject: [PATCH 25/96] updated code --- open_instruct/grpo_fast.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 07b5b2f71..3c45e8690 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1250,10 +1250,10 @@ def step(self): for key, value in batch_metrics.items(): if value is None: continue - if isinstance(value, 
np.ndarray): - array_metrics[key] = value - else: + if isinstance(value, (int, float, np.floating, np.integer)): self.local_metrics[key] = value + else: + array_metrics[key] = value return self.local_metrics.get_metrics_list(), array_metrics def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[str, Any]) -> None: From 19645a27ee644388c1e47dc6b091ca90215e74e3 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 15:34:29 -0700 Subject: [PATCH 26/96] updated the code to remove extra args param. --- open_instruct/grpo_fast.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3c45e8690..ba15e4e67 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2013,7 +2013,6 @@ def maybe_evaluate( eval_result, eval_batch, eval_reward_metrics, _ = accumulate_inference_batches( evaluation_inference_results_Q, eval_pending_queries_map, - args, eval_generation_config, num_prompts=num_eval_prompts, model_dims=model_dims, From e97335c1cbeaa247f72f034f5e4c5af72f5c9e41 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 16:25:53 -0700 Subject: [PATCH 27/96] fixed tests --- open_instruct/grpo_fast.py | 30 ++----- open_instruct/streaming_data_loader.py | 26 +++--- open_instruct/test_grpo_fast.py | 120 ++++++++++++------------- 3 files changed, 76 insertions(+), 100 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index ba15e4e67..6170718d1 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -37,12 +37,7 @@ import deepspeed from open_instruct import streaming_data_loader, utils -from open_instruct.streaming_data_loader import ( - PendingQueriesMap, - accumulate_inference_batches, - add_prompt_to_generator, - collate_fn, -) +from open_instruct.streaming_data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn # isort: on import asyncio @@ -1885,15 +1880,12 @@ def one_training_step( episode: int, training_step: int, num_total_tokens: int, - num_step_tokens: int, start_time: float, train_dataset: datasets.Dataset, training_start_time: float, wandb_url: str, chat_template_name: str, model_dims: utils.ModelDims, - prompt_lengths: list[int], - response_lengths: list[int], actor_manager: ActorManager | None = None, ) -> None: """Train the model for one step.""" @@ -1948,6 +1940,9 @@ def one_training_step( total_training_time = time.perf_counter() - training_start_time total_generation_time = average_metrics["time/getting_response"] + prompt_lengths = array_metrics_list[0]["batch/prompt_lengths"] + response_lengths = array_metrics_list[0]["batch/response_lengths"] + num_step_tokens = sum(prompt_lengths) + sum(response_lengths) utilization_metrics = calculate_utilization_metrics( model_dims=model_dims, @@ -1996,7 +1991,7 @@ def maybe_evaluate( tokenizer, reward_fn, episode, - eval_pending_queries_map: PendingQueriesMap, + eval_dataset, eval_generation_config, generate_metrics_Q: Queue, num_eval_prompts: int, @@ -2012,12 +2007,12 @@ def maybe_evaluate( # Accumulate evaluation results from all vLLM engines eval_result, eval_batch, eval_reward_metrics, _ = accumulate_inference_batches( evaluation_inference_results_Q, - eval_pending_queries_map, eval_generation_config, num_prompts=num_eval_prompts, model_dims=model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=eval_dataset, actor_manager=actor_manager, timeout=timeout, active_sampling=False, @@ -2292,7 +2287,6 @@ def run_training( 
param_prompt_Q, evaluation_inference_results_Q, packed_sequences_Q, - eval_pending_queries_map, generate_metrics_Q, weight_sync_metrics_Q, actor_manager: ActorManager, @@ -2368,7 +2362,6 @@ def health_check_fn(): eval_index, 0, training_step, - eval_pending_queries_map, param_prompt_Q, generation_configs["eval"], is_eval=True, @@ -2385,10 +2378,6 @@ def health_check_fn(): data_thread_metrics["time/health_check"] = health_check_time - num_step_tokens = 0 - prompt_lengths = [] - response_lengths = [] - one_training_step( args, streaming_config, @@ -2398,15 +2387,12 @@ def health_check_fn(): episode, training_step, num_total_tokens, - num_step_tokens, start_time, train_dataset, training_start_time, wandb_url, tc.chat_template_name, model_dims, - prompt_lengths, - response_lengths, actor_manager, ) @@ -2443,7 +2429,7 @@ def health_check_fn(): tokenizer, reward_fn, episode, - eval_pending_queries_map, + eval_dataset, generation_configs["eval"], generate_metrics_Q, len(eval_dataset) if eval_dataset else 0, @@ -2531,7 +2517,6 @@ def main( # Create additional queues (main queues already created above) packed_sequences_Q = Queue(maxsize=streaming_config.async_steps) - eval_pending_queries_map = PendingQueriesMap() generate_metrics_Q = Queue(maxsize=streaming_config.async_steps) weight_sync_metrics_Q = Queue(maxsize=streaming_config.async_steps) @@ -2559,7 +2544,6 @@ def main( param_prompt_Q, evaluation_inference_results_Q, packed_sequences_Q, - eval_pending_queries_map, generate_metrics_Q, weight_sync_metrics_Q, actor_manager, diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 01498dbf3..03f14919a 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -318,7 +318,6 @@ def __init__( self.reward_fn = reward_fn self.inference_results_Q = inference_results_Q self.param_prompt_Q = param_prompt_Q - self.pending_queries_map = PendingQueriesMap() self.tokenizer = tokenizer self.config = config self.config.max_possible_score = max_possible_score @@ -402,7 +401,6 @@ def _data_preparation_loop(self): dataset_index, self.iter_dataloader.epoch_number, self.training_step, - self.pending_queries_map, self.param_prompt_Q, self.generation_config, is_eval=False, @@ -416,19 +414,18 @@ def _data_preparation_loop(self): with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: result, batch, reward_metrics, batch_stats = accumulate_inference_batches( self.inference_results_Q, - self.pending_queries_map, self.generation_config, num_prompts=self.rank_batch_size, model_dims=self.model_dims, tokenizer=self.tokenizer, reward_fn=self.reward_fn, + dataset=self.dataset, actor_manager=self.actor_manager, active_sampling=self.config.active_sampling, filter_zero_std_samples=self.config.filter_zero_std_samples, replenish_prompts=True, no_resampling_pass_rate=self.config.no_resampling_pass_rate, iter_dataloader=self.iter_dataloader, - prompt_dataset=self.dataset, param_prompt_Q=self.param_prompt_Q, training_step=training_step, verbose=self.verbose, @@ -698,16 +695,11 @@ def add_prompt_to_generator( example_index: int, epoch_number: int, training_step: int, - pending_queries_map: PendingQueriesMap, param_prompt_Q: ray_queue.Queue, generation_config, is_eval: bool, ) -> None: query = example[INPUT_IDS_PROMPT_KEY] - ground_truth = example[GROUND_TRUTHS_KEY] - dataset_name = example[VERIFIER_SOURCE_KEY] - raw_query = example[RAW_PROMPT_KEY] - pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query) 
param_prompt_Q.put( PromptRequest( @@ -723,12 +715,12 @@ def add_prompt_to_generator( def accumulate_inference_batches( inference_results_Q: ray_queue.Queue, - pending_queries_map: PendingQueriesMap, generation_config: vllm.SamplingParams, num_prompts: int, model_dims: utils.ModelDims, tokenizer: PreTrainedTokenizer, reward_fn: Callable, + dataset: Dataset, actor_manager=None, timeout: float | None = None, active_sampling: bool = False, @@ -736,7 +728,6 @@ def accumulate_inference_batches( replenish_prompts: bool = False, no_resampling_pass_rate: float | None = None, iter_dataloader: ShufflingIterator | None = None, - prompt_dataset: Dataset = None, param_prompt_Q: ray_queue.Queue | None = None, training_step: int = None, verbose: bool = False, @@ -748,8 +739,8 @@ def accumulate_inference_batches( assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" if replenish_prompts: - assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( - "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset" + assert param_prompt_Q is not None and iter_dataloader is not None and dataset is not None, ( + "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset" ) results = [] @@ -785,16 +776,19 @@ def accumulate_inference_batches( f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" ) - query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index) + example = dataset[result.dataset_index] + query = example[INPUT_IDS_PROMPT_KEY] + ground_truth = example[GROUND_TRUTHS_KEY] + dataset_name = example[VERIFIER_SOURCE_KEY] + raw_query = example[RAW_PROMPT_KEY] if replenish_prompts: dataset_index = next(iter_dataloader) add_prompt_to_generator( - prompt_dataset[dataset_index], + dataset[dataset_index], dataset_index, iter_dataloader.epoch_number, training_step, - pending_queries_map, param_prompt_Q, generation_config, is_eval=False, diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 149ce51ad..6b05fc2c3 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -23,6 +23,7 @@ VERIFIER_SOURCE_KEY, ) from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics +from open_instruct.streaming_data_loader import PendingQueriesMap from open_instruct.vllm_utils import create_vllm_engines @@ -234,7 +235,6 @@ def setup_and_add_prompts_to_generator( queue_size = max(len(queries), num_engines * 2) param_prompt_Q = ray_queue.Queue(maxsize=queue_size) inference_results_Q = ray_queue.Queue(maxsize=queue_size) - pending_queries_map = grpo_fast.PendingQueriesMap() # Track queues for cleanup self._ray_queues.extend([param_prompt_Q, inference_results_Q]) @@ -247,25 +247,29 @@ def setup_and_add_prompts_to_generator( # Calculate inference_batch_size based on number of queries and engines mock_args.inference_batch_size = max(1, len(queries) // num_engines) - for index in range(len(queries)): + # Create a mock dataset that can be indexed by dataset_index + max_index = max(indices) + 1 + mock_dataset = [{} for _ in range(max_index)] + for i, index in enumerate(indices): + mock_dataset[index] = { + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], + } + + for i in range(len(queries)): example = { - INPUT_IDS_PROMPT_KEY: queries[index], - GROUND_TRUTHS_KEY: ground_truths[index], - 
VERIFIER_SOURCE_KEY: datasets[index], - RAW_PROMPT_KEY: raw_queries[index], + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], } grpo_fast.add_prompt_to_generator( - example, - indices[index], - 0, - training_step, - pending_queries_map, - param_prompt_Q, - mock_generation_config, - False, + example, indices[i], 0, training_step, param_prompt_Q, mock_generation_config, False ) - return param_prompt_Q, inference_results_Q, pending_queries_map + return param_prompt_Q, inference_results_Q, mock_dataset class TestGrpoFastVLLM(TestGrpoFastBase): @@ -350,13 +354,10 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, ) # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines ) - # Verify that we have individual prompts in the map (not batches) - self.assertEqual(len(pending_queries_map), num_unique_prompts_rollout) - # Verify that we have the expected number of items in the queue (one per prompt) self.assertEqual(param_prompt_Q.qsize(), num_unique_prompts_rollout) @@ -383,8 +384,12 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, result = inference_results_Q.get() dataset_index = result.dataset_index - # Get query from pending_queries_map - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) + # Get query from mock_dataset + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] combined_responses.extend(result.responses) combined_queries.append(q) @@ -418,9 +423,6 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, self.assertEqual(len(combined_result.finish_reasons), len(queries_next)) self.assertEqual(len(combined_result.masks), len(queries_next)) - # Verify that the pending_queries_map is empty after accumulation - self.assertEqual(len(pending_queries_map), 0) - # Verify that the inference_results_Q is empty after accumulation self.assertEqual(inference_results_Q.qsize(), 0) @@ -435,7 +437,7 @@ def test_dataset_index_preservation_through_pipeline(self): ) # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines ) @@ -457,7 +459,11 @@ def test_dataset_index_preservation_through_pipeline(self): result = inference_results_Q.get() dataset_index = result.dataset_index - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] combined_queries.append(q) combined_raw_queries.append(raw_q) combined_ground_truths.append(gt) @@ -467,7 +473,6 @@ def test_dataset_index_preservation_through_pipeline(self): self.assertEqual(combined_queries, queries_next) self.assertEqual(combined_ground_truths, ground_truths_next) self.assertEqual(combined_datasets, datasets_next) - 
self.assertEqual(len(pending_queries_map), 0) @parameterized.expand([(1, 16), (2, 8), (4, 4)]) def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_per_prompt: int): @@ -480,18 +485,10 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe ) # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries_next, ground_truths_next, datasets_next, raw_queries_next, dataset_indices, vllm_num_engines ) - # For multiple samples, we need to add additional references to the pending_queries_map - # The first reference is already added by setup_and_add_prompts_to_generator - for _ in range(num_samples_per_prompt - 1): - for idx, query, ground_truth, dataset, raw_query in zip( - dataset_indices, queries_next, ground_truths_next, datasets_next, raw_queries_next - ): - pending_queries_map.insert(idx, query, ground_truth, dataset, raw_query) - # Simulate vLLM processing with multiple samples batch_idx = 0 while not param_prompt_Q.empty(): @@ -511,11 +508,12 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe result = inference_results_Q.get() dataset_index = result.dataset_index - # Pop the query data for this specific result - pop multiple times for multiple samples - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) - # Pop additional times to handle multiple samples per prompt - for _ in range(num_samples_per_prompt - 1): - pending_queries_map.pop(dataset_index) + # Get query data from mock_dataset + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] combined_responses.extend(result.responses) combined_queries.append(q) @@ -542,7 +540,6 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe self.assertEqual(combined_queries, queries_next) self.assertEqual(combined_ground_truths, ground_truths_next) self.assertEqual(combined_datasets, datasets_next) - self.assertEqual(len(pending_queries_map), 0) # Verify correct number of responses expected_responses = num_unique_prompts_rollout * num_samples_per_prompt @@ -600,7 +597,7 @@ def test_out_of_order_processing(self): tokenizer, reward_fn = self.create_mock_tokenizer_and_reward_fn() # Setup and split batch - param_prompt_Q, inference_results_Q, pending_queries_map = self.setup_and_add_prompts_to_generator( + param_prompt_Q, inference_results_Q, mock_dataset = self.setup_and_add_prompts_to_generator( queries, ground_truths, datasets, raw_queries, indices, num_engines ) @@ -615,7 +612,6 @@ def test_out_of_order_processing(self): inference_results_Q.put(mock_result) # Accumulate results - mock_args = self.create_mock_args(num_engines, num_samples_per_prompt) # Create a mock generation config with n mock_generation_config = Mock() mock_generation_config.n = num_samples_per_prompt @@ -623,23 +619,21 @@ def test_out_of_order_processing(self): mock_model_dims = self.create_mock_model_dims() combined_result, batch, reward_metrics, batch_stats = grpo_fast.accumulate_inference_batches( inference_results_Q, - pending_queries_map, - mock_args, - generation_config=mock_generation_config, + mock_generation_config, num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=mock_dataset, ) # Verify results work correctly 
even with out-of-order processing self.assertEqual(len(batch.queries), num_prompts * num_samples_per_prompt) self.assertEqual(len(combined_result.responses), num_prompts * num_samples_per_prompt) - self.assertEqual(len(pending_queries_map), 0) def test_thread_safety_pending_queries_map(self): """Test concurrent access to pending_queries_map.""" - pending_queries_map = grpo_fast.PendingQueriesMap() + pending_queries_map = PendingQueriesMap() errors = [] num_threads = 4 entries_per_thread = 50 @@ -697,11 +691,17 @@ def test_accumulate_waits_for_all_engines(self): # Track queue for cleanup self._ray_queues.append(inference_results_Q) - pending_queries_map = grpo_fast.PendingQueriesMap() - - # Add entries to map + # Create mock dataset for lookup + mock_dataset = [] for i in range(num_prompts): - pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}") + mock_dataset.append( + { + INPUT_IDS_PROMPT_KEY: f"q_{i}", + GROUND_TRUTHS_KEY: f"t_{i}", + VERIFIER_SOURCE_KEY: f"d_{i}", + RAW_PROMPT_KEY: f"q_{i}", + } + ) # Add results from only 3 engines (missing one) # With individual prompts, we add individual results @@ -721,12 +721,12 @@ def run_accumulate(): mock_model_dims = self.create_mock_model_dims() grpo_fast.accumulate_inference_batches( inference_results_Q, - pending_queries_map, - generation_config=mock_generation_config, + mock_generation_config, num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=mock_dataset, ) completed.set() except Exception: @@ -741,8 +741,6 @@ def run_accumulate(): # Queue should be empty after consuming 12 results self.assertEqual(inference_results_Q.qsize(), 0) - # 12 entries should be removed from the map (4 still pending) - self.assertEqual(len(pending_queries_map), 4) class TestStreamingAccumulation(TestGrpoFastBase): @@ -756,7 +754,7 @@ def test_more_engines_than_queries(self): queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries) param_prompt_Q = ray_queue.Queue(maxsize=num_queries) - pending_queries_map = grpo_fast.PendingQueriesMap() + pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(param_prompt_Q) @@ -811,7 +809,7 @@ def test_uneven_distribution_no_empty_batches(self): queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries) param_prompt_Q = ray_queue.Queue(maxsize=num_queries) - pending_queries_map = grpo_fast.PendingQueriesMap() + pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(param_prompt_Q) @@ -864,7 +862,7 @@ def test_streaming_accumulation_basic(self): # Create queues and maps inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = grpo_fast.PendingQueriesMap() + pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(inference_results_Q) @@ -916,7 +914,7 @@ def test_streaming_with_multiple_samples(self): # Create queues and maps inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = grpo_fast.PendingQueriesMap() + pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(inference_results_Q) From 0c464c88552d1b6fa6fa2396522f99d7f3b0b3ce Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 18 Nov 2025 16:53:14 -0700 Subject: [PATCH 28/96] Fixed tests. 
--- open_instruct/test_grpo_fast.py | 99 ++++++++++++++------------------- 1 file changed, 43 insertions(+), 56 deletions(-) diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 6b05fc2c3..b010783c3 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -23,7 +23,7 @@ VERIFIER_SOURCE_KEY, ) from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics -from open_instruct.streaming_data_loader import PendingQueriesMap +from open_instruct.streaming_data_loader import PendingQueriesMap, ShufflingIterator from open_instruct.vllm_utils import create_vllm_engines @@ -749,12 +749,10 @@ class TestStreamingAccumulation(TestGrpoFastBase): def test_more_engines_than_queries(self): """Test that add_prompt_to_generator handles gracefully when engines > queries.""" # More engines than queries - should handle gracefully with single-prompt batches - num_engines = 8 num_queries = 4 queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries) param_prompt_Q = ray_queue.Queue(maxsize=num_queries) - pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(param_prompt_Q) @@ -762,10 +760,6 @@ def test_more_engines_than_queries(self): mock_generation_config = MagicMock() mock_generation_config.n = 1 - # Create mock args with inference_batch_size - mock_args = MagicMock() - mock_args.inference_batch_size = max(1, num_queries // num_engines) - for index in range(len(queries)): example = { INPUT_IDS_PROMPT_KEY: queries[index], @@ -774,14 +768,7 @@ def test_more_engines_than_queries(self): RAW_PROMPT_KEY: raw_queries[index], } grpo_fast.add_prompt_to_generator( - example, - indices[index], - epoch_number=0, - training_step=1, - pending_queries_map=pending_queries_map, - param_prompt_Q=param_prompt_Q, - generation_config=mock_generation_config, - is_eval=False, + example, indices[index], 0, 1, param_prompt_Q, mock_generation_config, False ) # Should have 4 batches (one for each query) @@ -799,17 +786,13 @@ def test_more_engines_than_queries(self): # Should have exactly num_queries PromptRequests self.assertEqual(prompt_count, num_queries, f"Should have {num_queries} PromptRequests") - # All queries should be in the pending map - self.assertEqual(len(pending_queries_map), num_queries) def test_uneven_distribution_no_empty_batches(self): """Test that uneven query distribution doesn't create empty batches.""" - num_engines = 3 - num_queries = 7 # 7/3 = ceil(2.33) = 3, so distribution should be [3, 3, 1] + num_queries = 7 queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_queries) param_prompt_Q = ray_queue.Queue(maxsize=num_queries) - pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(param_prompt_Q) @@ -817,10 +800,6 @@ def test_uneven_distribution_no_empty_batches(self): mock_generation_config = MagicMock() mock_generation_config.n = 1 - # Create mock args with inference_batch_size - mock_args = MagicMock() - mock_args.inference_batch_size = max(1, num_queries // num_engines + (1 if num_queries % num_engines else 0)) - for index in range(len(queries)): example = { INPUT_IDS_PROMPT_KEY: queries[index], @@ -829,14 +808,7 @@ def test_uneven_distribution_no_empty_batches(self): RAW_PROMPT_KEY: raw_queries[index], } grpo_fast.add_prompt_to_generator( - example, - indices[index], - epoch_number=0, - training_step=1, - pending_queries_map=pending_queries_map, - 
param_prompt_Q=param_prompt_Q, - generation_config=mock_generation_config, - is_eval=False, + example, indices[index], 0, 1, param_prompt_Q, mock_generation_config, False ) # With single-prompt architecture, verify we have the right number of individual requests @@ -860,16 +832,23 @@ def test_streaming_accumulation_basic(self): # Create test data queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts) - # Create queues and maps + # Create queues and mock dataset inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(inference_results_Q) - # Insert data into pending_queries_map + # Create mock dataset for lookup + mock_dataset = [] for i in range(num_prompts): - pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i]) + mock_dataset.append( + { + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], + } + ) # Create mock results - one per prompt for i in range(num_prompts): @@ -886,14 +865,17 @@ def test_streaming_accumulation_basic(self): results_list.append(result) - # Get query for this prompt + # Get query for this prompt from dataset dataset_index = result.dataset_index - q, gt, d, raw_q = pending_queries_map.pop(dataset_index) + example = mock_dataset[dataset_index] + q = example[INPUT_IDS_PROMPT_KEY] + gt = example[GROUND_TRUTHS_KEY] + d = example[VERIFIER_SOURCE_KEY] + raw_q = example[RAW_PROMPT_KEY] queries_list.append((q, gt, d, raw_q)) # Verify all results processed self.assertEqual(len(results_list), expected_results) - self.assertEqual(len(pending_queries_map), 0) # Combine in order combined_queries = [] @@ -912,17 +894,23 @@ def test_streaming_with_multiple_samples(self): # Create test data queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts) - # Create queues and maps + # Create queues and mock dataset inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = PendingQueriesMap() # Track queue for cleanup self._ray_queues.append(inference_results_Q) - # Insert data with reference counting for multiple samples + # Create mock dataset for lookup + mock_dataset = [] for i in range(num_prompts): - for _ in range(num_samples): - pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i]) + mock_dataset.append( + { + INPUT_IDS_PROMPT_KEY: queries[i], + GROUND_TRUTHS_KEY: ground_truths[i], + VERIFIER_SOURCE_KEY: datasets[i], + RAW_PROMPT_KEY: raw_queries[i], + } + ) # Create results - one per prompt with multiple samples for i in range(num_prompts): @@ -939,14 +927,13 @@ def test_streaming_with_multiple_samples(self): self.assertEqual(len(result.responses), expected_responses) total_responses += len(result.responses) - # Pop multiple times to match the number of samples (reference counting) + # Get query from dataset (can be looked up multiple times) idx = result.dataset_index - for _ in range(num_samples): - pending_queries_map.pop(idx) + example = mock_dataset[idx] + self.assertIsNotNone(example[INPUT_IDS_PROMPT_KEY]) # Verify total responses self.assertEqual(total_responses, num_prompts * num_samples) - self.assertEqual(len(pending_queries_map), 0) class TestShufflingIterator(unittest.TestCase): @@ -957,7 +944,7 @@ def test_basic_iteration(self): data = np.arange(100) batch_size = 10 - iterator = grpo_fast.ShufflingIterator(data, 
batch_size, seed=42) + iterator = ShufflingIterator(data, batch_size, seed=42) # Get first batch batch1 = next(iterator) @@ -978,7 +965,7 @@ def test_state_preservation_and_restoration(self): seed = 42 # Create original iterator - iter1 = grpo_fast.ShufflingIterator(data, batch_size, seed=seed) + iter1 = ShufflingIterator(data, batch_size, seed=seed) # Get a few batches _ = next(iter1) @@ -999,7 +986,7 @@ def test_state_preservation_and_restoration(self): batch5_original = next(iter1) # Create new iterator with different seed and restore state - iter2 = grpo_fast.ShufflingIterator(data, batch_size, seed=999) + iter2 = ShufflingIterator(data, batch_size, seed=999) iter2.set_state(state) # Get batches from restored iterator @@ -1017,7 +1004,7 @@ def test_epoch_boundary_state(self): batch_size = 5 # Create iterator and complete one epoch - iterator = grpo_fast.ShufflingIterator(data, batch_size, seed=123) + iterator = ShufflingIterator(data, batch_size, seed=123) for _ in range(4): # 20 / 5 = 4 batches per epoch next(iterator) @@ -1027,7 +1014,7 @@ def test_epoch_boundary_state(self): self.assertEqual(state["index"], 20) # Create new iterator and restore state - iter2 = grpo_fast.ShufflingIterator(data, batch_size, seed=456) + iter2 = ShufflingIterator(data, batch_size, seed=456) iter2.set_state(state) # Next batches should match @@ -1042,8 +1029,8 @@ def test_rng_state_preservation(self): batch_size = 50 # Create two iterators with same seed - iter1 = grpo_fast.ShufflingIterator(data, batch_size, seed=42) - _ = grpo_fast.ShufflingIterator(data, batch_size, seed=42) + iter1 = ShufflingIterator(data, batch_size, seed=42) + _ = ShufflingIterator(data, batch_size, seed=42) # Advance first iterator for _ in range(5): @@ -1051,7 +1038,7 @@ def test_rng_state_preservation(self): # Save state and create new iterator with different seed state = iter1.get_state() - iter3 = grpo_fast.ShufflingIterator(data, batch_size, seed=999) + iter3 = ShufflingIterator(data, batch_size, seed=999) # Restore state - this should override the different seed iter3.set_state(state) From f6e420a811026783f4bb2bb5dba4d7015798a200 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 19 Nov 2025 08:05:50 -0700 Subject: [PATCH 29/96] Fixed sharding --- open_instruct/grpo_fast.py | 1 + 1 file changed, 1 insertion(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 6170718d1..e85194798 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -603,6 +603,7 @@ def __init__( self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id self.num_mini_batches = args.num_mini_batches + dataset = dataset.shard(num_shards=world_size, index=rank) self.dataloader = data_loader_config.build( dataset=dataset, reward_fn=reward_fn, From dec1b4bcb9a47bb57afd04a743b51dfc2394a131 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 19 Nov 2025 10:26:42 -0700 Subject: [PATCH 30/96] Added barrier --- open_instruct/grpo_fast.py | 1 + 1 file changed, 1 insertion(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index e85194798..eb38b8c3e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -929,6 +929,7 @@ def update_ref_policy(self): def step(self): batch_data = next(self.dataloader) + torch.distributed.barrier() batch_metrics = batch_data["metrics"] collated_query_responses = batch_data["collated_query_responses"] collated_tool_masks = batch_data["collated_tool_masks"] From 1bdd302f78a170d44ffb7242a04542f4eec09fef Mon Sep 17 00:00:00 2001 
From: Finbarr Timbers Date: Wed, 19 Nov 2025 12:21:35 -0700 Subject: [PATCH 31/96] using same accumulation steps --- open_instruct/grpo_fast.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index eb38b8c3e..5470845c3 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -949,6 +949,10 @@ def step(self): to_device_inplace(collated_vllm_logprobs, self.device) # accumulation steps should always be at least 1 accumulation_steps = max(math.ceil(len(collated_query_responses) / self.num_mini_batches - 0.5), 1) + # Sync accumulation_steps across ranks so all learners call allreduce on the same iterations + accumulation_steps_tensor = torch.tensor([accumulation_steps], device=self.device, dtype=torch.int32) + torch.distributed.all_reduce(accumulation_steps_tensor, op=torch.distributed.ReduceOp.MIN) + accumulation_steps = int(accumulation_steps_tensor.item()) leftover = len(collated_query_responses) % accumulation_steps if leftover > 0: collated_query_responses = collated_query_responses[0:-leftover] From 72a15e700524718ee000cabcff7a86779397d21f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 20 Nov 2025 12:51:49 -0700 Subject: [PATCH 32/96] Updated metrics averaging. --- open_instruct/grpo_fast.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 5470845c3..a1d48e671 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1939,7 +1939,13 @@ def one_training_step( ray.get(actor_manager.report_training_step_time.remote(train_timer.duration)) - average_metrics = {k: sum(m[k] for m in metrics_list) / len(metrics_list) for k in metrics_list[0]} + all_keys = set() + for m in metrics_list: + all_keys.update(m.keys()) + average_metrics = {} + for k in all_keys: + values = [m[k] for m in metrics_list if k in m] + average_metrics[k] = sum(values) / len(values) for key, value in array_metrics_list[0].items(): average_metrics[key] = value step_time = time.perf_counter() - start_time From cfcfb5e01e7f9c3ae073d3e41a0a22b160322ba8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 21 Nov 2025 10:05:01 -0700 Subject: [PATCH 33/96] updated code with debug logs --- open_instruct/grpo_fast.py | 26 ++++++++++++++++++++++++++ open_instruct/vllm_utils.py | 5 +++++ 2 files changed, 31 insertions(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a1d48e671..16e18bf12 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -664,7 +664,11 @@ def load(self, path: str, map_location=None): np.random.seed(worker_seed) random.seed(worker_seed) + logger.info( + f"[DEBUG] Rank {self.rank}: Initializing DeepSpeed distributed (timeout={args.backend_timeout} minutes)..." 
+ ) deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout)) + logger.info(f"[DEBUG] Rank {self.rank}: DeepSpeed distributed initialized successfully") ds_config = get_train_ds_config( offload=args.deepspeed_offload_param, @@ -839,8 +843,11 @@ def forward( return logprob, entropy def setup_model_update_group(self, vllm_engines): + logger = logger_utils.setup_logger(__name__) + logger.info(f"[DEBUG] Rank {self.rank}: Entered setup_model_update_group") self.vllm_engines = vllm_engines if self.rank == 0: + logger.info(f"[DEBUG] Rank 0: Initializing process group for {len(vllm_engines)} vLLM engines") master_address = ray._private.services.get_node_ip_address() with socket.socket() as sock: sock.bind(("", 0)) @@ -851,6 +858,10 @@ def setup_model_update_group(self, vllm_engines): ) world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 backend = self.args.vllm_sync_backend + logger.info( + f"[DEBUG] Rank 0: master_address={master_address}, master_port={master_port}, " + f"world_size={world_size}, backend={backend}" + ) refs = [ engine.init_process_group.remote( master_address, @@ -871,8 +882,15 @@ def setup_model_update_group(self, vllm_engines): group_name="openrlhf", timeout=timedelta(minutes=self.args.backend_timeout), ) + logger.info( + f"[DEBUG] Rank 0: Waiting for {len(refs)} vLLM engines to initialize process groups (timeout=600s)..." + ) ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=600) + logger.info("[DEBUG] Rank 0: All vLLM engines initialized, approaching barrier") + else: + logger.info(f"[DEBUG] Rank {self.rank}: Approaching barrier") torch.distributed.barrier() + logger.info(f"[DEBUG] Rank {self.rank}: Passed barrier successfully") def broadcast_to_vllm(self): # avoid OOM @@ -1736,21 +1754,26 @@ def create_model_and_optimizer( use_fp8_kv_cache=args.use_fp8_kv_cache, inflight_updates=args.inflight_updates, ) + logger.info(f"[DEBUG] Created {len(vllm_engines)} vLLM engines") # Get model dimensions from vLLM engine + logger.info("[DEBUG] Fetching model dimensions from first vLLM engine...") model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) logger.info("======== ✅ vLLM engines and actor_manager initialized =========") # Get and set KV cache max concurrency from the first engine (all engines have the same config) # fp8 kv cache for now forces v0 engine and breaks this. 
+ logger.info("[DEBUG] Setting up KV cache configuration...") if vllm_engines and not args.use_fp8_kv_cache: kv_cache_max_concurrency = ray.get(vllm_engines[0].get_kv_cache_info.remote()) ray.get(actor_manager.set_kv_cache_max_concurrency.remote(kv_cache_max_concurrency)) else: # dummy value ray.get(actor_manager.set_kv_cache_max_concurrency.remote(-1)) + logger.info("[DEBUG] KV cache configuration complete") # Now create policy actors with all dependencies + logger.info("[DEBUG] Creating ModelGroup with policy actors...") wandb_url = wandb.run.get_url() if args.with_tracking else None policy_group = ModelGroup( pg, @@ -1768,7 +1791,9 @@ def create_model_and_optimizer( actor_manager=actor_manager, model_dims=model_dims, ) + logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") + logger.info("[DEBUG] Starting model initialization across all ranks...") inits = [ model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) for model in policy_group.models @@ -1783,6 +1808,7 @@ def create_model_and_optimizer( ) logger.info("======== ✅ all models initialized =========") + logger.info("[DEBUG] Setting up model update group across all ranks...") ray_get_with_progress( [m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models], desc="Setting up model update group", diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index 390f295c5..6affd135b 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -831,7 +831,9 @@ def create_vllm_engines( # ensure we use bundles on the same node where possible if tp>1. bundle_indices_list = get_bundle_indices_list(pg) + logger.info(f"[DEBUG] Creating {num_engines} vLLM engines with tensor_parallel_size={tensor_parallel_size}") for i in range(num_engines): + logger.info(f"[DEBUG] Creating vLLM engine {i + 1}/{num_engines}") bundle_indices = None bundle_indices = bundle_indices_list[i * tensor_parallel_size : (i + 1) * tensor_parallel_size] @@ -880,9 +882,12 @@ def create_vllm_engines( calculate_kv_scales=use_fp8_kv_cache, ) ) + logger.info(f"[DEBUG] vLLM engine {i + 1}/{num_engines} actor created") + logger.info(f"[DEBUG] All {num_engines} vLLM engine actors created, waiting for ready() (timeout=1200s)...") ray_get_with_progress( [engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=1200 ) + logger.info(f"[DEBUG] All {num_engines} vLLM engines ready!") return vllm_engines From cc12a31f9748f71b73fb8a4b5b0c39f3a7074652 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 21 Nov 2025 14:50:24 -0700 Subject: [PATCH 34/96] now trying to run again --- open_instruct/grpo_fast.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 16e18bf12..53231479c 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -947,7 +947,6 @@ def update_ref_policy(self): def step(self): batch_data = next(self.dataloader) - torch.distributed.barrier() batch_metrics = batch_data["metrics"] collated_query_responses = batch_data["collated_query_responses"] collated_tool_masks = batch_data["collated_tool_masks"] From f9771217906b325a988a669c77138c7c93f81e2a Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 21 Nov 2025 16:49:45 -0700 Subject: [PATCH 35/96] uses global dataset index --- open_instruct/grpo_fast.py | 1 + open_instruct/streaming_data_loader.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git 
a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 53231479c..dcb140b39 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1650,6 +1650,7 @@ def setup_datasets( system_prompt_override=system_prompt_override, ) train_dataset = train_dataset.shuffle(seed=args.seed) + train_dataset = train_dataset.map(lambda example, idx: {**example, "index": idx}, with_indices=True) eval_dataset = None if len(args.dataset_mixer_eval_list) > 0: diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 03f14919a..124c48f0b 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -395,9 +395,11 @@ def _start_background_thread(self): def _data_preparation_loop(self): for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): - dataset_index = next(self.iter_dataloader) + local_index = next(self.iter_dataloader) + example = self.dataset[local_index] + dataset_index = example["index"] add_prompt_to_generator( - self.dataset[dataset_index], + example, dataset_index, self.iter_dataloader.epoch_number, self.training_step, @@ -783,9 +785,11 @@ def accumulate_inference_batches( raw_query = example[RAW_PROMPT_KEY] if replenish_prompts: - dataset_index = next(iter_dataloader) + local_index = next(iter_dataloader) + example = dataset[local_index] + dataset_index = example["index"] add_prompt_to_generator( - dataset[dataset_index], + example, dataset_index, iter_dataloader.epoch_number, training_step, From a09b73a4da157071274ac2ff7c459378730a6ca5 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 12:38:01 -0700 Subject: [PATCH 36/96] Remove duplicate add_prompt_to_generator and load_data_from_packing_thread functions - Deleted duplicate functions from grpo_fast.py that were causing NameError - Use canonical versions from streaming_data_loader.py instead - Fixed test to use correct function signature without pending_queries_map - Linter passes --- open_instruct/grpo_fast.py | 70 --------------------------------- open_instruct/test_grpo_fast.py | 2 +- 2 files changed, 1 insertion(+), 71 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 6ecdd392c..8a54f093a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1926,76 +1926,6 @@ def create_generation_configs(args: Args, streaming_config: streaming_data_loade return {"train": generation_config, "eval": eval_generation_config} -def add_prompt_to_generator( - example: dict[str, Any], - example_index: int, - epoch_number: int, - training_step: int, - pending_queries_map: PendingQueriesMap, - param_prompt_Q: ray_queue.Queue, - generation_config, - is_eval: bool, -) -> None: - """Split a batch into multiple inference batches and insert individual prompts into queues and mapping.""" - query = example[INPUT_IDS_PROMPT_KEY] - ground_truth = example[GROUND_TRUTHS_KEY] - dataset_name = example[VERIFIER_SOURCE_KEY] - raw_query = example[RAW_PROMPT_KEY] - pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query) - - param_prompt_Q.put( - PromptRequest( - prompt=query, - generation_config=generation_config, - epoch_number=epoch_number, - training_step=training_step, - dataset_index=example_index, - is_eval=is_eval, - ) - ) - - -def load_data_from_packing_thread( - packed_sequences_Q: Queue, num_total_tokens: int, stop_event: threading.Event, health_check_fn: Callable[[], None] -) -> tuple[list[dict[str, list[torch.Tensor]]] | None, 
dict[str, Any], int, int, list[int] | None, list[int] | None]: - """Get the packed sequences with advantages from the packing thread.""" - with Timer("[Main Thread] 📦 Getting packed sequences from thread") as timer: - while True: - if stop_event.is_set(): - logger.warning("[Main Thread] Stop event detected while waiting for packed sequences") - return None, {}, num_total_tokens, 0, None, None, 0 - try: - # When running at 32k generation length, it typically takes 900s to generate data, - # so you might see this fire a bunch of times. That's normal! - packed_data = packed_sequences_Q.get(timeout=300) - break - except Empty: - health_check_fn() - logger.warning("[Main Thread] Timeout waiting for packed sequences. Retrying...") - data_thread_metrics = packed_data["metrics"] - B = packed_data["B"] - collated_data = packed_data["collated_data"] - num_step_tokens = packed_data["num_new_tokens"] - num_total_tokens += num_step_tokens - prompt_lengths = packed_data["prompt_lengths"] - response_lengths = packed_data["response_lengths"] - num_filtered_prompts = packed_data["num_filtered_prompts"] - - data_thread_metrics["time/trainer_idling"] = timer.duration - if B == 0: - logger.warning("[Main Thread] 🤡 After packing, there is not enough data to train") - return None, data_thread_metrics, num_total_tokens, 0, None, None, 0 - return ( - collated_data, - data_thread_metrics, - num_total_tokens, - num_step_tokens, - prompt_lengths, - response_lengths, - num_filtered_prompts, - ) - - def weight_sync_thread( args: Args, stop_event: threading.Event, diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 95dd39f5b..b010783c3 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -266,7 +266,7 @@ def setup_and_add_prompts_to_generator( RAW_PROMPT_KEY: raw_queries[i], } grpo_fast.add_prompt_to_generator( - example, indices[i], 0, training_step, pending_queries_map, param_prompt_Q, mock_generation_config, False + example, indices[i], 0, training_step, param_prompt_Q, mock_generation_config, False ) return param_prompt_Q, inference_results_Q, mock_dataset From eee872ae3e96d79d18c714c23ea3df38cf186eca Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 14:10:06 -0700 Subject: [PATCH 37/96] Fix duplicate async_steps argument in argparse - Removed async_steps from Args class (duplicate in StreamingDataLoaderConfig) - Moved validation logic for async_steps to StreamingDataLoaderConfig.__post_init__ - Removed validation for active_sampling and filter_zero_std_samples from Args (already in StreamingDataLoaderConfig) - Fixes: argparse.ArgumentError: conflicting option strings: --async_steps --- open_instruct/grpo_fast.py | 20 -------------------- open_instruct/streaming_data_loader.py | 2 ++ 2 files changed, 2 insertions(+), 20 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 8a54f093a..093c375f6 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -212,8 +212,6 @@ class Args: """List of strings that stop the generation when they are generated. The returned output will not contain the stop strings.""" # Algorithm - async_steps: int = 1 - """Number of steps ahead to generate responses. Fully synchronous training is not supported, so async_steps must be greater than 0. 
The trainer learns from a policy up to async_steps old like Cleanba (https://arxiv.org/abs/2310.00036)""" num_epochs: int = 1 """the number of epochs to train""" num_mini_batches: int = 1 @@ -503,24 +501,6 @@ def __post_init__(self): if self.apply_r1_style_format_reward and self.additive_format_reward: self.max_possible_score += self.r1_style_format_reward - if self.active_sampling: - assert self.async_steps > 1, ( - "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " - "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " - "prompt will cause the trainer to stall waiting for more data . " - ) - assert self.filter_zero_std_samples, ( - "filter_zero_std_samples must be True when active_sampling is True. " - "Active sampling requires filtering to work correctly." - ) - if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: - raise ValueError( - "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " - "as the reward standard deviation will always be 0, causing all samples to be filtered." - ) - if self.async_steps < 1: - raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") - def masked_mean( values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 124c48f0b..eb1ba09d5 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -81,6 +81,8 @@ def __post_init__(self): "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " "as the reward standard deviation will always be 0, causing all samples to be filtered." ) + if self.async_steps < 1: + raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") def build( self, From e18f3af51e757f55bf3580418aab202380aad5e5 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 14:17:48 -0700 Subject: [PATCH 38/96] Remove remaining duplicate fields from Args class - Removed duplicate fields that exist in StreamingDataLoaderConfig: - advantage_normalization_type - mask_truncated_completions - active_sampling - filter_zero_std_samples - no_resampling_pass_rate - These fields are now only in StreamingDataLoaderConfig to avoid argparse conflicts - Fixes: argparse.ArgumentError for all duplicate field names --- open_instruct/grpo_fast.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 093c375f6..d4919b458 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -241,19 +241,6 @@ class Args: """How many training steps to take before updating the reference policy.""" load_ref_policy: bool = True """Whether to load and use a reference policy for KL penalty calculation.""" - advantage_normalization_type: Literal["standard", "centered"] = "standard" - """The type of advantage normalization to use. Standard normalization is the default: it subtracts the mean and - divides by the standard deviation. Centered normalization is the same but subtracts the mean only (e.g., used in - DR.GRPO https://arxiv.org/pdf/2503.20783).""" - mask_truncated_completions: bool = False - """Whether to mask out truncated completions. 
Also called overlong filtering, from DAPO (https://arxiv.org/abs/2503.14476).""" - - active_sampling: bool = False - """Whether to continue sampling responses until you get a full batch.""" - filter_zero_std_samples: bool = True - """Whether to filter out prompts with zero reward std (all samples have the same score).""" - no_resampling_pass_rate: float | None = None - """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" From 911bf44a2bd16b208dbba904b5ab65198d284960 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 14:19:33 -0700 Subject: [PATCH 39/96] Remove pack_length validation from Args - pack_length, max_prompt_token_length, and response_length are in StreamingDataLoaderConfig - Validation already exists in StreamingDataLoaderConfig.__post_init__ - Fixes: AttributeError: 'Args' object has no attribute 'pack_length' --- open_instruct/grpo_fast.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index d4919b458..37d7d7eac 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -439,9 +439,6 @@ def __post_init__(self): # Initialize stop_strings if None if self.stop_strings is None: self.stop_strings = [] - assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( - "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" - ) if self.checkpoint_state_freq > 0 and self.checkpoint_state_dir is None: raise ValueError("`checkpoint_state_dir` must be provided if `checkpoint_state_freq` is greater than 0!") if self.checkpoint_state_dir is not None and self.checkpoint_state_freq == -1: From fc1c8635ec53e1cefcdb929c84d0222e5c9a7d24 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 14:39:09 -0700 Subject: [PATCH 40/96] Add missing inference_batch_size field to Args class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This field was present in origin/main but was missing after the merge, causing an AttributeError at runtime. Added the field with its default value and documentation. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 37d7d7eac..ffdaf1bbc 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -304,6 +304,8 @@ class Args: on the first node and 4 learner processes on the second node; each process will have 1 GPU)""" vllm_num_engines: int = 1 """number of vLLM Engines, set to 0 to disable vLLM""" + inference_batch_size: int | None = None + """inference batch size per vLLM engine. 
If None, calculated as ceil(num_unique_prompts_rollout / vllm_num_engines) * num_samples_per_prompt_rollout""" vllm_tensor_parallel_size: int = 1 """tensor parallel size of vLLM Engine for multi-GPU inference""" vllm_enforce_eager: bool = False From 96b32f74a1422cfbb441c1078a894224857ca0b9 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 14:57:18 -0700 Subject: [PATCH 41/96] Remove reference to deprecated use_fp8_kv_cache field MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The use_fp8_kv_cache field was intentionally removed from the Args class in earlier commits, but a reference remained at line 1804. Since this field was always False, the condition `not args.use_fp8_kv_cache` always evaluated to True, so we can safely simplify the check to just `if vllm_engines:`. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index ffdaf1bbc..1a3b6ad38 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1801,7 +1801,7 @@ def create_model_and_optimizer( # Get and set KV cache max concurrency from the first engine (all engines have the same config) # fp8 kv cache for now forces v0 engine and breaks this. logger.info("[DEBUG] Setting up KV cache configuration...") - if vllm_engines and not args.use_fp8_kv_cache: + if vllm_engines: kv_cache_max_concurrency = ray.get(vllm_engines[0].get_kv_cache_info.remote()) ray.get(actor_manager.set_kv_cache_max_concurrency.remote(kv_cache_max_concurrency)) expected_batch_size = ( From e02704c437a33a52fe4a4632eab0931b1ff1b692 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Mon, 24 Nov 2025 15:14:43 -0700 Subject: [PATCH 42/96] Fix num_samples_per_prompt_rollout field access MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changed references from args.num_samples_per_prompt_rollout to data_loader_config.num_samples_per_prompt_rollout at lines 1808 and 1812. This field was moved to StreamingDataLoaderConfig but the newly added KV cache concurrency check code was still trying to access it via args. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1a3b6ad38..ff1802ebc 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1805,11 +1805,15 @@ def create_model_and_optimizer( kv_cache_max_concurrency = ray.get(vllm_engines[0].get_kv_cache_info.remote()) ray.get(actor_manager.set_kv_cache_max_concurrency.remote(kv_cache_max_concurrency)) expected_batch_size = ( - args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout // args.vllm_num_engines + args.num_unique_prompts_rollout + * data_loader_config.num_samples_per_prompt_rollout + // args.vllm_num_engines ) if kv_cache_max_concurrency < expected_batch_size: nodes_needed = ( - args.num_unique_prompts_rollout * args.num_samples_per_prompt_rollout // kv_cache_max_concurrency + args.num_unique_prompts_rollout + * data_loader_config.num_samples_per_prompt_rollout + // kv_cache_max_concurrency ) logger.warning( f"kv_cache_max_concurrency ({kv_cache_max_concurrency}) is lower than " From 6271708ebff050d56f549889a57e21e6f80dfe3b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 25 Nov 2025 12:09:29 -0700 Subject: [PATCH 43/96] dynamically calculates require metrics --- open_instruct/grpo_fast.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index ff1802ebc..ec5c3e7a5 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -80,6 +80,7 @@ from open_instruct.actor_manager import ActorManager from open_instruct.dataset_transformation import ( INPUT_IDS_PROMPT_KEY, + VERIFIER_SOURCE_KEY, TokenizerConfig, get_cached_dataset_tulu, visualize_token, @@ -639,6 +640,7 @@ def from_pretrained( beaker_config: BeakerRuntimeConfig, wandb_url: str, tokenizer: PreTrainedTokenizer, + num_verifiers: int, ) -> int: # ------------------------------------------------------------ # Monkey patch to load checkpoints with `weights_only=False` @@ -798,7 +800,11 @@ def load(self, path: str, map_location=None): if hasattr(self, "ref_policy_checkpoint_path") else None, ) - self.local_metrics = utils.MetricsTracker(device=self.device) + # 49 base metrics: 16 from step() (KL, loss, ratio, etc.), 22 from streaming_data_loader + # (scores, sequence_lengths, etc.), 7 from BatchStatistics, 4 from reward_metrics. + # Each verifier adds 2 metrics: objective/{key}_reward and objective/{key}_correct_rate. 
+ max_metrics = 49 + 2 * num_verifiers + self.local_metrics = utils.MetricsTracker(max_metrics=max_metrics, device=self.device) return optimization_steps_done def forward( @@ -1848,8 +1854,9 @@ def create_model_and_optimizer( logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") logger.info("[DEBUG] Starting model initialization across all ranks...") + num_verifiers = len(set(train_dataset[VERIFIER_SOURCE_KEY])) inits = [ - model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer, num_verifiers) for model in policy_group.models ] From 1b1ffbce88fa2a0bd3097d5cfabf8334f6d525df Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 25 Nov 2025 12:44:45 -0700 Subject: [PATCH 44/96] updated to get num verifiers --- open_instruct/grpo_fast.py | 19 ++++++++++++++++++- open_instruct/test_grpo_fast.py | 16 ++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index ec5c3e7a5..dfc526520 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -53,6 +53,7 @@ from collections.abc import Callable from dataclasses import asdict, dataclass, field from datetime import timedelta +from itertools import chain from queue import Empty, Full, Queue from typing import Any, Literal @@ -498,6 +499,22 @@ def masked_mean( return (numerator / denom).mean() +def get_num_verifiers(dataset: Dataset) -> int: + """Count unique verifiers in the dataset. + + Each example can have one verifier (string) or multiple verifiers (list of strings). + This function flattens all verifier sources and returns the count of unique ones. + + Args: + dataset: A HuggingFace Dataset containing a VERIFIER_SOURCE_KEY column. + + Returns: + The number of unique verifiers in the dataset. 
+ """ + verifier_sources = dataset[VERIFIER_SOURCE_KEY] + return len(set(chain.from_iterable(v if isinstance(v, list) else [v] for v in verifier_sources))) + + @Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker") def prepare_collated_data_for_workers( packed_sequences: PackedSequences, @@ -1854,7 +1871,7 @@ def create_model_and_optimizer( logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") logger.info("[DEBUG] Starting model initialization across all ranks...") - num_verifiers = len(set(train_dataset[VERIFIER_SOURCE_KEY])) + num_verifiers = get_num_verifiers(train_dataset) inits = [ model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer, num_verifiers) for model in policy_group.models diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index b010783c3..7c09c8081 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -10,6 +10,7 @@ import numpy as np import ray import torch +from datasets import Dataset from parameterized import parameterized from ray.util import queue as ray_queue from transformers import AutoTokenizer @@ -22,11 +23,26 @@ RAW_PROMPT_KEY, VERIFIER_SOURCE_KEY, ) +from open_instruct.grpo_fast import get_num_verifiers from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics from open_instruct.streaming_data_loader import PendingQueriesMap, ShufflingIterator from open_instruct.vllm_utils import create_vllm_engines +class TestGetNumVerifiers(unittest.TestCase): + def test_single_verifier_per_example(self): + dataset = Dataset.from_dict({VERIFIER_SOURCE_KEY: ["gsm8k", "math", "gsm8k"]}) + self.assertEqual(get_num_verifiers(dataset), 2) + + def test_multiple_verifiers_per_example(self): + dataset = Dataset.from_dict({VERIFIER_SOURCE_KEY: [["gsm8k", "math"], ["code"]]}) + self.assertEqual(get_num_verifiers(dataset), 3) + + def test_empty_dataset(self): + dataset = Dataset.from_dict({VERIFIER_SOURCE_KEY: []}) + self.assertEqual(get_num_verifiers(dataset), 0) + + class TestGrpoFastBase(unittest.TestCase): """Base class with common test utilities.""" From b4f751f4a0063ffca32859aaddc38c84627ebbfa Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 25 Nov 2025 14:27:06 -0700 Subject: [PATCH 45/96] Fixed code --- open_instruct/grpo_fast.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 65df6c80c..c6b8f19b9 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -37,10 +37,16 @@ import deepspeed from open_instruct import streaming_data_loader, utils -from open_instruct.streaming_data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn +from open_instruct.streaming_data_loader import ( + PendingQueriesMap, + ShufflingIterator, + add_prompt_to_generator, + collate_fn, +) # isort: on import asyncio +import json import logging import math import random @@ -74,6 +80,7 @@ from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from rich.pretty import pprint +from tqdm import tqdm from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler from transformers.integrations import HfDeepSpeedConfig @@ -92,6 +99,7 @@ soft_format_reward_func, ) from open_instruct.model_utils import ( + Batch, ModelConfig, apply_verifiable_reward, 
disable_dropout_in_model, @@ -103,8 +111,8 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import ShutdownSentinel -from open_instruct.rl_utils import PackedSequences, Timer +from open_instruct.queue_types import GenerationResult, RequestInfo, ShutdownSentinel, TokenStatistics +from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, @@ -112,6 +120,7 @@ _z3_params_to_fetch, calibrate_checkpoint_state_dir, clean_last_n_checkpoints_deepspeed, + combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, get_eval_ds_config, @@ -125,6 +134,7 @@ maybe_use_ai2_hf_entity, maybe_use_ai2_wandb_entity, ray_get_with_progress, + repeat_each, sync_gs_bucket, ) From b32acc1d521296844a5b109ab21e303b41ecebc1 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 25 Nov 2025 15:04:51 -0700 Subject: [PATCH 46/96] tests pass --- open_instruct/grpo_fast.py | 318 +------------------------ open_instruct/streaming_data_loader.py | 8 + open_instruct/test_grpo_fast.py | 24 +- 3 files changed, 27 insertions(+), 323 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index c6b8f19b9..2831e29df 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -38,8 +38,8 @@ from open_instruct import streaming_data_loader, utils from open_instruct.streaming_data_loader import ( - PendingQueriesMap, ShufflingIterator, + accumulate_inference_batches, add_prompt_to_generator, collate_fn, ) @@ -80,7 +80,6 @@ from ray.util.placement_group import PlacementGroup, placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from rich.pretty import pprint -from tqdm import tqdm from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler from transformers.integrations import HfDeepSpeedConfig @@ -99,7 +98,6 @@ soft_format_reward_func, ) from open_instruct.model_utils import ( - Batch, ModelConfig, apply_verifiable_reward, disable_dropout_in_model, @@ -111,7 +109,7 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import GenerationResult, RequestInfo, ShutdownSentinel, TokenStatistics +from open_instruct.queue_types import ShutdownSentinel from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences from open_instruct.utils import ( ArgumentParserPlus, @@ -120,7 +118,6 @@ _z3_params_to_fetch, calibrate_checkpoint_state_dir, clean_last_n_checkpoints_deepspeed, - combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, get_eval_ds_config, @@ -134,7 +131,6 @@ maybe_use_ai2_hf_entity, maybe_use_ai2_wandb_entity, ray_get_with_progress, - repeat_each, sync_gs_bucket, ) @@ -1633,315 +1629,11 @@ def calculate_utilization_metrics( return utilization_metrics -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - no_resampled_prompts: int - total_prompts: int - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - pending_queries_map: PendingQueriesMap, - args: Args, - generation_config: vllm.SamplingParams, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - reward_fn: Callable, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: 
bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, - iter_dataloader: ShufflingIterator | None = None, - prompt_dataset: Dataset = None, - param_prompt_Q: ray_queue.Queue | None = None, - training_step: int = None, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - """Accumulate multiple inference results into a single training batch. - - Args: - inference_results_Q: Queue containing individual GenerationResult objects (one per prompt) - pending_queries_map: PendingQueriesMap instance for thread-safe query tracking - args: Arguments containing vllm_num_engines and batch size info - generation_config: Generation config containing n (number of samples per prompt) - num_prompts: Number of prompts to accumulate - timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. - active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std - filter_zero_std_samples: Whether to filter samples with zero reward std - replenish_prompts: Add a prompt back onto the prompt_Q after receiving a finished result - no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate - and exclude them from further sampling - iter_dataloader: Optional, used for no_resampling_pass_rate - param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts - - Raises: - queue.Empty: If timeout is specified and no data is available within timeout. - - Returns: - Tuple of (combined_result, Batch with queries, ground_truths, datasets, prompt_lengths, response_lengths) - or (ShutdownSentinel, None, None, None) if shutdown signal received - """ - if no_resampling_pass_rate is not None: - assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" - - if replenish_prompts: - assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( - "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset" - ) - - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] - all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not args.verbose, - ) - num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - # Validate that each individual result has the expected number of responses - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. 
" - f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" - ) - - query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index) - - # Replenish generation queue with new prompt - if replenish_prompts: - dataset_index = next(iter_dataloader) - add_prompt_to_generator( - prompt_dataset[dataset_index], - dataset_index, - iter_dataloader.epoch_number, - training_step, - pending_queries_map, - param_prompt_Q, - generation_config, - is_eval=False, - ) - - # TODO(finbarrtimbers): Move this to LLMRayActor. - for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - # TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method. - k_queries = repeat_each([query], generation_config.n) - k_ground_truths = repeat_each([ground_truth], generation_config.n) - k_datasets = repeat_each([dataset_name], generation_config.n) - k_raw_queries = repeat_each([raw_query], generation_config.n) - - scores, reward_metrics = asyncio.run( - reward_fn( - result.responses, - decoded_responses, - k_ground_truths, - k_datasets, - result.finish_reasons, - result.request_info, - k_raw_queries, - ) - ) - - percent_solved = np.mean(scores).item() / args.max_possible_score - # Don't resample prompt that was solved at more than no_resample_positive_rate - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - iter_dataloader.exclude_index(result.dataset_index) - total_no_resampled += 1 - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" - ) - - # Filter out zero std prompts - if filter_zero_std_samples and np.std(scores) == 0: - # If we're not active sampling, still count this as a sample - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if scores[0] == 0: - filtered_prompt_zero += 1 - elif scores[0] == args.max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - all_queries.extend(k_queries) - all_ground_truths.extend(k_ground_truths) - all_datasets.extend(k_datasets) - all_raw_queries.extend(k_raw_queries) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(scores) - all_reward_metrics.append(reward_metrics) - all_percent_solved.append(percent_solved) - - if len(results) == 0: - logger.warning( - "[Data Preparation Thread] All prompts were filtered during accumulation. 
" - f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " - f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" - ) - return None, None, None, None - - # Combine all results into a single GenerationResult - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - # Use the maximum generation time across engines since they work in parallel - # This avoids including queue overhead and accumulation time in MFU/MBU calculations - total_generation_time = max_generation_time - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=total_generation_time, - earliest_start_time=earliest_start_time, - ) - - # Create combined RequestInfo - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - # Create combined GenerationResult - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - epoch_number=results[0].epoch_number, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - # Note: We don't have dataset_indices here, but they're not needed for the returned batch - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, # Not meaningful for combined results - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - 
filtered_prompts=total_filtered_prompts, - filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - logging.info( - f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" - ) - - return combined_result, batch, combined_reward_metrics, batch_stats - - def data_preparation_thread( reward_fn: Callable, inference_results_Q: ray_queue.Queue, # Ray queue param_prompt_Q: ray_queue.Queue, packed_sequences_Q: Queue, - pending_queries_map: dict, args: Args, tokenizer: PreTrainedTokenizer, num_training_steps: int, @@ -1957,22 +1649,22 @@ def data_preparation_thread( with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: result, batch, reward_metrics, batch_stats = accumulate_inference_batches( inference_results_Q, - pending_queries_map, - args, generation_config, num_prompts=args.num_unique_prompts_rollout, model_dims=model_dims, tokenizer=tokenizer, reward_fn=reward_fn, + dataset=train_dataset, actor_manager=actor_manager, active_sampling=args.active_sampling, filter_zero_std_samples=args.filter_zero_std_samples, replenish_prompts=True, no_resampling_pass_rate=args.no_resampling_pass_rate, iter_dataloader=iter_dataloader, - prompt_dataset=train_dataset, param_prompt_Q=param_prompt_Q, training_step=training_step, + verbose=args.verbose, + max_possible_score=args.max_possible_score, ) if isinstance(result, ShutdownSentinel): logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index eb1ba09d5..d4f66ab32 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -863,6 +863,14 @@ def accumulate_inference_batches( all_reward_metrics.append(reward_metrics) all_percent_solved.append(percent_solved) + if len(results) == 0: + logging.warning( + "[Data Preparation Thread] All prompts were filtered during accumulation. 
" + f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " + f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" + ) + return None, None, None, None + combined_responses = [] combined_finish_reasons = [] combined_masks = [] diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 4be2e7b77..284d3dd16 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -1074,22 +1074,25 @@ def test_all_prompts_filtered_returns_none(self): num_prompts = 8 num_samples_per_prompt = 4 - queries, ground_truths, datasets, raw_queries, indices = self.create_test_data(num_prompts) + queries, ground_truths, datasets_list, raw_queries, indices = self.create_test_data(num_prompts) + + test_dataset = Dataset.from_dict( + { + INPUT_IDS_PROMPT_KEY: queries, + GROUND_TRUTHS_KEY: ground_truths, + VERIFIER_SOURCE_KEY: datasets_list, + RAW_PROMPT_KEY: raw_queries, + } + ) inference_results_Q = ray_queue.Queue(maxsize=num_prompts) - pending_queries_map = grpo_fast.PendingQueriesMap() self._ray_queues.append(inference_results_Q) - for i in range(num_prompts): - for _ in range(num_samples_per_prompt): - pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i]) - for i in range(num_prompts): mock_result = self.create_mock_result(i, epoch_number=1, num_samples_per_prompt=num_samples_per_prompt) inference_results_Q.put(mock_result) - mock_args = self.create_mock_args(num_engines=4, num_samples=num_samples_per_prompt) mock_generation_config = Mock() mock_generation_config.n = num_samples_per_prompt mock_model_dims = self.create_mock_model_dims() @@ -1110,14 +1113,15 @@ async def reward_fn_zero_std( result, batch, reward_metrics, batch_stats = grpo_fast.accumulate_inference_batches( inference_results_Q, - pending_queries_map, - mock_args, - generation_config=mock_generation_config, + mock_generation_config, num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, reward_fn=reward_fn_zero_std, + dataset=test_dataset, filter_zero_std_samples=True, + verbose=False, + max_possible_score=1.0, ) self.assertIsNone(result) From d6f8309108d737f495c251193b3e96b2aceb7c02 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Tue, 25 Nov 2025 15:27:52 -0700 Subject: [PATCH 47/96] fixed sharding --- open_instruct/grpo_fast.py | 1 - 1 file changed, 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 2831e29df..cc64bd2ca 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -634,7 +634,6 @@ def __init__( self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id self.num_mini_batches = args.num_mini_batches - dataset = dataset.shard(num_shards=world_size, index=rank) self.dataloader = data_loader_config.build( dataset=dataset, reward_fn=reward_fn, From 661b96ff75095b56162a29e5c16f8be40e0f9707 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 26 Nov 2025 08:24:34 -0700 Subject: [PATCH 48/96] Fixed sharding --- open_instruct/streaming_data_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index d4f66ab32..27eff9dde 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -335,7 +335,8 @@ def __init__( self.current_epoch = 0 dataset_indices = np.arange(len(dataset)) - self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=seed + dp_rank) + dataset_indices = 
dataset_indices[dp_rank::dp_world_size] + self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=seed) self.local_queue = StdQueue(maxsize=config.async_steps) self.background_thread = None From 98f3d5d97636b58293321a329d0013cf5c2a867f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 26 Nov 2025 08:54:17 -0700 Subject: [PATCH 49/96] Committed fix. --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index cc64bd2ca..eea69839d 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -641,7 +641,7 @@ def __init__( param_prompt_Q=param_prompt_Q, tokenizer=tokenizer, generation_config=generation_config, - dp_rank=self.local_rank, + dp_rank=rank, fs_local_rank=self.local_rank, num_training_steps=args.num_training_steps, seed=args.seed, From 87d5d0228555b74fed0895bbfb683a087fc9bba8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 09:04:07 -0700 Subject: [PATCH 50/96] Made a bunch of changes. --- open_instruct/grpo_fast.py | 603 ++----------------------- open_instruct/queue_types.py | 12 +- open_instruct/streaming_data_loader.py | 126 ++---- open_instruct/test_grpo_fast.py | 31 +- open_instruct/vllm_utils.py | 2 + 5 files changed, 89 insertions(+), 685 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1ca140c1d..60e2a0da6 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -37,16 +37,10 @@ import deepspeed from open_instruct import streaming_data_loader, utils -from open_instruct.streaming_data_loader import ( - ShufflingIterator, - accumulate_inference_batches, - add_prompt_to_generator, - collate_fn, -) +from open_instruct.streaming_data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn # isort: on import asyncio -import json import logging import math import random @@ -55,11 +49,9 @@ import threading import time from argparse import Namespace -from collections import defaultdict from collections.abc import Callable from dataclasses import asdict, dataclass, field from datetime import timedelta -from itertools import chain from queue import Empty, Full, Queue from typing import Any, Literal @@ -106,8 +98,8 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import ShutdownSentinel, GenerationResult, PromptRequest, RequestInfo, TokenStatistics -from open_instruct.rl_utils import PackedSequences, Timer, masked_mean, pack_sequences +from open_instruct.queue_types import ShutdownSentinel +from open_instruct.rl_utils import PackedSequences, Timer, masked_mean from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, @@ -513,11 +505,16 @@ def __post_init__(self): self.max_possible_score += self.r1_style_format_reward -def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: - padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) - if pin_memory: - padded_tensor = padded_tensor.pin_memory() - return padded_tensor +def get_num_verifiers(dataset: Dataset) -> int: + if VERIFIER_SOURCE_KEY not in dataset.column_names: + return 0 + verifiers = set() + for item in dataset[VERIFIER_SOURCE_KEY]: + if isinstance(item, list): + verifiers.update(item) + else: + verifiers.add(item) + return len(verifiers) @Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker") @@ -615,7 +612,6 @@ def __init__( args: 
Args, data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, dataset: Dataset, - reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, tokenizer: PreTrainedTokenizer, @@ -629,7 +625,6 @@ def __init__( self.num_mini_batches = args.num_mini_batches self.dataloader = data_loader_config.build( dataset=dataset, - reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, tokenizer=tokenizer, @@ -1035,7 +1030,6 @@ def step(self): batch_data = next(self.dataloader) batch_metrics = batch_data["metrics"] collated_query_responses = batch_data["collated_query_responses"] - collated_tool_masks = batch_data["collated_tool_masks"] collated_attention_masks = batch_data["collated_attention_masks"] collated_position_ids = batch_data["collated_position_ids"] collated_advantages = batch_data["collated_advantages"] @@ -1362,6 +1356,12 @@ def save_checkpoint_state(self, checkpoint_state_dir: str, client_state: dict[st checkpoint_state_dir, args.gs_checkpoint_state_dir ) + def get_dataloader_state(self) -> dict[str, Any]: + return self.dataloader.state_dict() + + def set_dataloader_state(self, state: dict[str, Any]) -> None: + self.dataloader.load_state_dict(state) + def save_model(self, output_dir: str, chat_template_name: str, tokenizer: PreTrainedTokenizer) -> None: model_to_save = self.model if chat_template_name is not None and "olmo" in chat_template_name: @@ -1450,7 +1450,6 @@ def __init__( args: Args, data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, dataset: Dataset, - reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, tokenizer: PreTrainedTokenizer, @@ -1480,7 +1479,6 @@ def __init__( args, data_loader_config, dataset, - reward_fn, inference_results_Q, param_prompt_Q, tokenizer, @@ -1529,7 +1527,6 @@ def get_bundle_index(rank, num_gpus_per_node): args, data_loader_config, dataset, - reward_fn, inference_results_Q, param_prompt_Q, tokenizer, @@ -1597,520 +1594,6 @@ def calculate_utilization_metrics( return utilization_metrics -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - no_resampled_prompts: int - total_prompts: int - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - args: Args, - generation_config: vllm.SamplingParams, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - prompt_dataset: Dataset, - data_loader: data_loader_lib.HFDataLoader | None = None, - param_prompt_Q: ray_queue.Queue | None = None, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - """Accumulate multiple inference results into a single training batch. - - Args: - inference_results_Q: Queue containing individual GenerationResult objects (one per prompt) - args: Arguments containing vllm_num_engines and batch size info - generation_config: Generation config containing n (number of samples per prompt) - num_prompts: Number of prompts to accumulate - data_loader: Iterator over the dataloader for replenishing prompts. Required when - replenish_prompts=True or no_resampling_pass_rate is set. 
Can be None for - evaluation where all prompts are pre-queued. - prompt_dataset: Dataset containing prompts - param_prompt_Q: Queue containing prompts to send to generator. Required when - replenish_prompts=True. Can be None for evaluation where no replenishment is needed. - timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. - active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std - filter_zero_std_samples: Whether to filter samples with zero reward std - replenish_prompts: Add a prompt back onto the prompt_Q after receiving a finished result - no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate - and exclude them from further sampling - - Raises: - queue.Empty: If timeout is specified and no data is available within timeout. - - Returns: - Tuple of (combined_result, Batch with queries, ground_truths, datasets, prompt_lengths, response_lengths) - or (ShutdownSentinel, None, None, None) if shutdown signal received - """ - if no_resampling_pass_rate is not None: - assert data_loader is not None, "no_resampling requires data_loader" - - if replenish_prompts: - assert param_prompt_Q is not None and data_loader is not None and prompt_dataset is not None, ( - "replenish_prompts requires param_prompt_Q, data_loader, and prompt_dataset" - ) - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] - all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not args.verbose, - ) - num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - # Validate that each individual result has the expected number of responses - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. " - f"Prompt ID: {result.prompt_id}" - ) - - # Replenish generation queue with new prompt - if replenish_prompts: - add_prompt_to_generator(next(data_loader), param_prompt_Q, generation_config, is_eval=False) - - # TODO(finbarrtimbers): Move this to LLMRayActor. 
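# --- Editor's aside (not part of the patch): a minimal sketch of the replenish-on-completion
# pattern the docstring above describes. For every finished result consumed from the results
# queue, one fresh prompt is pushed back, so the number of in-flight prompts stays constant.
# `results_queue`, `prompt_queue`, and `data_loader` are stand-ins for inference_results_Q,
# param_prompt_Q, and the HFDataLoader iterator -- an illustration, not the exact implementation.
def consume_and_replenish(results_queue, prompt_queue, data_loader, num_prompts):
    accumulated = []
    for _ in range(num_prompts):
        result = results_queue.get()         # one finished prompt, carrying its n sampled responses
        prompt_queue.put(next(data_loader))  # immediately enqueue a replacement so the generator stays busy
        accumulated.append(result)
    return accumulated
# --- end editor's aside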
- for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - percent_solved = np.mean(result.reward_scores).item() / args.max_possible_score - # Don't resample prompt that was solved at more than no_resample_positive_rate - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - total_no_resampled += 1 - data_loader.exclude_index(result.dataset_index) - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, total no resampled: {total_no_resampled}" - ) - - # Filter out zero std prompts - if filter_zero_std_samples and np.std(result.reward_scores) == 0: - # If we're not active sampling, still count this as a sample - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if result.reward_scores[0] == 0: - filtered_prompt_zero += 1 - elif result.reward_scores[0] == args.max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - prompt_data = prompt_dataset[result.dataset_index] - all_queries.extend(repeat_each([prompt_data[INPUT_IDS_PROMPT_KEY]], generation_config.n)) - all_ground_truths.extend(repeat_each([prompt_data[GROUND_TRUTHS_KEY]], generation_config.n)) - all_datasets.extend(repeat_each([prompt_data[VERIFIER_SOURCE_KEY]], generation_config.n)) - all_raw_queries.extend(repeat_each([prompt_data[RAW_PROMPT_KEY]], generation_config.n)) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(result.reward_scores) - all_reward_metrics.append(result.reward_metrics) - all_percent_solved.append(percent_solved) - - if len(results) == 0: - logger.warning( - "[Data Preparation Thread] All prompts were filtered during accumulation. 
" - f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " - f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" - ) - return None, None, None, None - - # Combine all results into a single GenerationResult - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - # Use the maximum generation time across engines since they work in parallel - # This avoids including queue overhead and accumulation time in MFU/MBU calculations - total_generation_time = max_generation_time - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=total_generation_time, - earliest_start_time=earliest_start_time, - ) - - # Create combined RequestInfo - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - # Create combined GenerationResult - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - prompt_id=None, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - # Note: We don't have dataset_indices here, but they're not needed for the returned batch - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, # Not meaningful for combined results - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - filtered_prompts=total_filtered_prompts, - 
filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - return combined_result, batch, combined_reward_metrics, batch_stats - - -def data_preparation_thread( - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, - packed_sequences_Q: Queue, - args: Args, - tokenizer: PreTrainedTokenizer, - num_training_steps: int, - generation_config, - resume_training_step: int, - data_loader: data_loader_lib.HFDataLoader, - train_dataset: Dataset, - actor_manager=None, - model_dims: utils.ModelDims = None, -): - for training_step in range(resume_training_step, num_training_steps + 1): - # Streaming accumulation: collect results as they arrive - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - inference_results_Q, - generation_config, - num_prompts=args.num_unique_prompts_rollout, - model_dims=model_dims, - tokenizer=tokenizer, - reward_fn=reward_fn, - dataset=train_dataset, - data_loader=data_loader, - prompt_dataset=train_dataset, - param_prompt_Q=param_prompt_Q, - actor_manager=actor_manager, - active_sampling=args.active_sampling, - filter_zero_std_samples=args.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=args.no_resampling_pass_rate, - ) - if isinstance(result, ShutdownSentinel): - logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") - return - if result is None: - logger.info("[Data Preparation Thread] All prompts filtered, putting empty batch into queue") - packed_sequences = PackedSequences( - query_responses=[], - attention_masks=[], - response_masks=[], - original_responses=[], - advantages=[], - position_ids=[], - vllm_logprobs=[], - ) - collated_data = [] - packed_sequences_Q.put( - { - "packed_sequences": packed_sequences, - "collated_data": collated_data, - "metrics": {}, - "responses_count": 0, - "num_new_tokens": 0, - "B": 0, - "prompt_lengths": [], - "response_lengths": [], - "num_filtered_prompts": 0, - } - ) - continue - - getting_response_time = timer.duration - scores = np.array(batch.scores) - - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - if args.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif args.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}") - - if args.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 
0: - logger.info( - f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - with Timer("📦 [Data Preparation Thread] Packing sequences"): - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=args.pack_length, - pad_token_id=tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - mask_tool_use=args.mask_tool_use, - ) - num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses) - # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value - # and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - # if we have less batches than world size, we need to pad out so each world is fine - # ideally, you should avoid this since its wasting computation. - if args.allow_world_padding: - with Timer("🤺 [Data Preparation Thread] Padding sequences for world size"): - shortfall = args.world_size - len(packed_sequences.query_responses) - if shortfall > 0: - logger.warning( - f"Padding {shortfall} sequences for world size. In future, you should adjust your compute this." - ) - # construct "dummy" sequences for padding out the world size - dummy_qr = torch.tensor([tokenizer.pad_token_id, tokenizer.eos_token_id], dtype=torch.long) - dummy_attention = torch.tensor([1, 1], dtype=torch.long) - dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) - dummy_response_mask = torch.zeros_like(dummy_qr) - dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) - # pad out the world size - for _ in range(shortfall): - packed_sequences.query_responses.append(dummy_qr) - packed_sequences.attention_masks.append(dummy_attention) - packed_sequences.position_ids.append(dummy_position_ids) - packed_sequences.response_masks.append(dummy_response_mask) - packed_sequences.advantages.append(dummy_advantage) - - collated_data = prepare_collated_data_for_workers( - packed_sequences, args.world_size, args.per_device_train_batch_size, tokenizer.pad_token_id - ) - B = len(packed_sequences.query_responses) // args.world_size - - # Create a result package with metrics and data - if len(result.responses) == 0: - # Handle empty responses case - # in this case, we won't log metrics, so it should be fine. 
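# --- Editor's aside (not part of the patch): a toy, self-contained illustration of the two
# advantage steps above -- group normalization over the samples of each prompt, then the
# 1-indexed lookup that broadcasts each response's scalar advantage onto its packed tokens.
# The numbers are made up; the real code uses args.num_samples_per_prompt_rollout and
# packed_sequences.response_masks.
import numpy as np
import torch

num_samples_per_prompt = 2
scores = np.array([1.0, 0.0, 0.5, 0.5], dtype=np.float32)          # 2 prompts x 2 samples, flattened

grouped = scores.reshape(-1, num_samples_per_prompt)
mean_g = np.repeat(grouped.mean(axis=-1), num_samples_per_prompt)  # [0.5, 0.5, 0.5, 0.5]
std_g = np.repeat(grouped.std(axis=-1), num_samples_per_prompt)    # [0.5, 0.5, 0.0, 0.0]
advantages = (scores - mean_g) / (std_g + 1e-8)                    # "standard"; "centered" skips the division
# prompt 1 -> [+1, -1]; prompt 2 has zero std -> [0, 0] (normally filtered out upstream)

# Response masks are 1-indexed per response, with 0 marking prompt/padding tokens, so index 0
# of the lookup table stays 0 and the broadcast is a single fancy-indexed gather.
packed_mask = np.array([0, 1, 1, 0, 2, 2, 3, 4])
lookup = np.zeros(len(advantages) + 1, dtype=np.float32)
lookup[1:] = advantages
per_token_advantages = torch.tensor(lookup[packed_mask], dtype=torch.float32)
# tensor([ 0.,  1.,  1.,  0., -1., -1.,  0.,  0.])
# --- end editor's aside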
- metrics = {} - logger.warning(f"No responses in batch {training_step}.") - else: - real_num_responses = len(result.responses) - expected_num_responses = args.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout - - unsolved_num_responses = (scores < args.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) if np.all(scores == 0) else np.array(sequence_lengths[scores == args.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) if np.all(scores == args.max_possible_score) else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) - - batch_metrics = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} - - metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": None, - "val/total_reward_groups": real_num_responses / args.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - "val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array([len(item) > 0 for item in result.request_info.tool_errors]).mean(), - "val/good_outputs_rate": np.array(good_outputs).mean(), - "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - "time/getting_response": getting_response_time, - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - if args.save_traces: - traces = { - "scores": scores.tolist(), - "finish_reasons": result.finish_reasons, - "responses": result.responses, - "training_step": training_step, - **asdict(batch), # Unpack all batch fields - **reward_metrics, - } - os.makedirs(args.output_dir, exist_ok=True) - with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f: - json.dump(traces, f) - f.write("\n") - - # Put the packed sequences and metrics into the output queue - packed_sequences_Q.put( - { - "packed_sequences": packed_sequences, # for debugging purposes - "collated_data": collated_data, - "metrics": metrics, - "responses_count": len(result.responses), - "num_new_tokens": num_new_tokens, - "B": B, - "prompt_lengths": batch_stats.prompt_lengths, - "response_lengths": 
batch_stats.response_lengths, - "num_filtered_prompts": batch_stats.filtered_prompts, - } - ) - - def setup_runtime_variables(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig) -> Args: """Set up runtime variables for the experiment.""" args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" @@ -2239,7 +1722,7 @@ def create_model_and_optimizer( data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, train_dataset: Dataset, eval_dataset, - reward_fn: Callable, + reward_config: RewardConfig, generation_config, ) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]: """Create the model, optimizer, and vLLM engines.""" @@ -2354,7 +1837,6 @@ def create_model_and_optimizer( args=args, data_loader_config=data_loader_config, dataset=train_dataset, - reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, tokenizer=tokenizer, @@ -2414,7 +1896,6 @@ def create_generation_configs(args: Args, streaming_config: streaming_data_loade return {"train": generation_config, "eval": eval_generation_config} - def weight_sync_thread( args: Args, stop_event: threading.Event, @@ -2604,6 +2085,7 @@ def maybe_evaluate( generate_metrics_Q: Queue, num_eval_prompts: int, model_dims: utils.ModelDims, + reward_fn: Callable, actor_manager=None, ): """Optionally evaluate the model.""" @@ -2620,7 +2102,7 @@ def maybe_evaluate( model_dims=model_dims, tokenizer=tokenizer, reward_fn=reward_fn, - prompt_dataset=eval_dataset, + dataset=eval_dataset, actor_manager=actor_manager, timeout=timeout, active_sampling=False, @@ -2788,7 +2270,6 @@ def run_training( vllm_engines, generation_configs, reward_fn, - data_loader, resume_training_step, episode, wandb_url, @@ -2798,7 +2279,6 @@ def run_training( inference_results_Q, param_prompt_Q, evaluation_inference_results_Q, - packed_sequences_Q, generate_metrics_Q, weight_sync_metrics_Q, actor_manager: ActorManager, @@ -2808,6 +2288,17 @@ def run_training( if resume_training_step > 1: logger.info(f"[Main Thread] Resuming training from step {resume_training_step}") + # Restore dataloader state if available in checkpoint + if checkpoint_state and "dataloader_state" in checkpoint_state: + ray_get_with_progress( + [ + policy_group.models[i].set_dataloader_state.remote(checkpoint_state["dataloader_state"]) + for i in range(args.world_size) + ], + desc="Restoring dataloader state", + ) + logger.info("Restored dataloader state from checkpoint") + logger.info("======== ✅ weight sync thread starts =========") weight_sync_trigger_event = threading.Event() weight_sync_thread_future = executor.submit( @@ -2882,8 +2373,6 @@ def health_check_fn(): ): for eval_example in iter(eval_data_loader): add_prompt_to_generator(eval_example, param_prompt_Q, generation_configs["eval"], is_eval=True) - if collated_data is None: - continue episode += args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout @@ -2929,9 +2418,8 @@ def health_check_fn(): "num_total_tokens": num_total_tokens, } - # Save dataloader state - if data_loader is not None: - client_state["dataloader_state"] = data_loader.state_dict() + # Save dataloader state from Ray actor + client_state["dataloader_state"] = ray.get(policy_group.models[0].get_dataloader_state.remote()) ray_get_with_progress( [ @@ -2956,6 +2444,7 @@ def health_check_fn(): generate_metrics_Q, len(eval_dataset) if eval_dataset else 0, model_dims, + reward_fn, actor_manager, ) @@ -3018,7 +2507,7 @@ def main( ) generation_configs = 
create_generation_configs(args, streaming_config) - (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( + (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims, reward_fn) = ( create_model_and_optimizer( args, tc, @@ -3048,22 +2537,7 @@ def main( episode = checkpoint_state["episode"] logger.info(f"Restored episode count: {episode}") - data_loader = data_loader_lib.HFDataLoader( - dataset=train_dataset, - batch_size=1, - seed=args.seed, - rank=0, - world_size=1, - work_dir=args.output_dir, - automatic_reshuffle=True, - ) - - if checkpoint_state and "dataloader_state" in checkpoint_state: - data_loader.load_state_dict(checkpoint_state["dataloader_state"]) - logger.info("Restored dataloader state from checkpoint") - # Create additional queues (main queues already created above) - packed_sequences_Q = Queue(maxsize=args.async_steps) generate_metrics_Q = Queue(maxsize=args.async_steps) weight_sync_metrics_Q = Queue(maxsize=args.async_steps) @@ -3080,7 +2554,7 @@ def main( policy_group, vllm_engines, generation_configs, - data_loader, + reward_fn, resume_training_step, episode, wandb_url, @@ -3090,7 +2564,6 @@ def main( inference_results_Q, param_prompt_Q, evaluation_inference_results_Q, - packed_sequences_Q, generate_metrics_Q, weight_sync_metrics_Q, actor_manager, diff --git a/open_instruct/queue_types.py b/open_instruct/queue_types.py index 0861d9e97..e3b9014ba 100644 --- a/open_instruct/queue_types.py +++ b/open_instruct/queue_types.py @@ -36,8 +36,9 @@ class GenerationResult: finish_reasons: list[str] masks: list[list[int]] request_info: RequestInfo - dataset_index: int | None - prompt_id: str | None + dataset_index: int | None = None + prompt_id: str | None = None + epoch_number: int = 0 token_statistics: TokenStatistics | None = None start_time: float | None = None logprobs: list[list[float]] | None = None @@ -57,5 +58,10 @@ class PromptRequest: prompt: list[int] generation_config: Any dataset_index: int - prompt_id: str + epoch_number: int = 0 + training_step: int = 0 is_eval: bool = False + + @property + def prompt_id(self) -> str: + return f"{self.epoch_number}_{self.dataset_index}" diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 27eff9dde..77c2bd36b 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
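# --- Editor's aside (not part of the patch): how the new PromptRequest.prompt_id property from
# the queue_types.py hunk above behaves, assuming the dataclass constructor implied by its field
# list. Only fields visible in that hunk are used; generation_config is stubbed with None.
from open_instruct.queue_types import PromptRequest

request = PromptRequest(prompt=[101, 2023, 102], generation_config=None, dataset_index=42, epoch_number=3)
assert request.prompt_id == "3_42"  # f"{epoch_number}_{dataset_index}"
# A second epoch over the same dataset row yields a distinct id ("4_42"); the vllm_utils change
# later in this series passes the same value through to GenerationResult.prompt_id, which also
# gained an epoch_number field in the hunk above.
# --- end editor's aside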
-import asyncio import logging import threading from abc import abstractmethod -from collections.abc import Callable, Iterable +from collections.abc import Iterable from dataclasses import asdict, dataclass from pathlib import Path from queue import Queue as StdQueue @@ -30,6 +29,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer +from open_instruct import data_loader as data_loader_lib from open_instruct import utils from open_instruct.dataset_transformation import ( GROUND_TRUTHS_KEY, @@ -87,7 +87,6 @@ def __post_init__(self): def build( self, dataset: Dataset, - reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, tokenizer: PreTrainedTokenizer, @@ -107,7 +106,6 @@ def build( ) -> "StreamingDataLoader": return StreamingDataLoader( dataset=dataset, - reward_fn=reward_fn, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, tokenizer=tokenizer, @@ -225,71 +223,11 @@ def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int | None: return self.global_batch_size -class ShufflingIterator: - def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None): - self.data = data.copy() - self.batch_size = batch_size - self.index = 0 - self.epoch_number = 0 - self.rng = np.random.default_rng(seed) - self.rng.shuffle(self.data) - self.exclude_list = [] - - self._update_effective_size() - - def __iter__(self): - return self - - def __next__(self) -> list[int] | int: - if self.index >= self.effective_size: - self.index = 0 - self._update_effective_size() - self.epoch_number += 1 - self.rng.shuffle(self.data) - - end_index = self.index + self.batch_size - batch = self.data[self.index : end_index].tolist() - if self.batch_size == 1: - batch = batch[0] - self.index = end_index - - return batch - - def get_state(self) -> dict[str, Any]: - return { - "index": self.index, - "epoch_number": self.epoch_number, - "data": self.data.copy(), - "rng_state": self.rng.bit_generator.state, - "exclude_list": self.exclude_list.copy(), - } - - def set_state(self, state: dict[str, Any]) -> None: - self.index = state["index"] - self.epoch_number = state.get("epoch_number", 0) - self.data = state["data"].copy() - self.rng.bit_generator.state = state["rng_state"] - self.exclude_list = state.get("exclude_list", []) - self._update_effective_size() - - def exclude_index(self, index: int) -> None: - self.exclude_list.append(index) - - def _update_effective_size(self) -> None: - if self.exclude_list: - mask = ~np.isin(self.data, self.exclude_list) - self.data = self.data[mask] - self.exclude_list = [] - - self.effective_size = len(self.data) - (len(self.data) % self.batch_size) - - class StreamingDataLoader(TextDataLoaderBase): def __init__( self, *, dataset: Dataset, - reward_fn: Callable, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, tokenizer: PreTrainedTokenizer, @@ -317,7 +255,6 @@ def __init__( ) self.dataset = dataset - self.reward_fn = reward_fn self.inference_results_Q = inference_results_Q self.param_prompt_Q = param_prompt_Q self.tokenizer = tokenizer @@ -333,10 +270,17 @@ def __init__( self.training_step = 0 self.current_epoch = 0 + self.seed = seed - dataset_indices = np.arange(len(dataset)) - dataset_indices = dataset_indices[dp_rank::dp_world_size] - self.iter_dataloader = ShufflingIterator(dataset_indices, 1, seed=seed) + self.iter_dataloader = data_loader_lib.HFDataLoader( + dataset=dataset, + batch_size=1, + seed=seed, + rank=dp_rank, + world_size=dp_world_size, + work_dir=work_dir, + 
automatic_reshuffle=True, + ) self.local_queue = StdQueue(maxsize=config.async_steps) self.background_thread = None @@ -350,13 +294,13 @@ def state_dict(self) -> dict[str, Any]: return { "training_step": self.training_step, "current_epoch": self.current_epoch, - "iter_dataloader_state": self.iter_dataloader.get_state(), + "iter_dataloader_state": self.iter_dataloader.state_dict(), } def load_state_dict(self, state_dict: dict[str, Any]): self.training_step = state_dict["training_step"] self.current_epoch = state_dict.get("current_epoch", 0) - self.iter_dataloader.set_state(state_dict["iter_dataloader_state"]) + self.iter_dataloader.load_state_dict(state_dict["iter_dataloader_state"]) def reshuffle(self, epoch: int | None = None, **kwargs): if epoch is not None: @@ -398,13 +342,12 @@ def _start_background_thread(self): def _data_preparation_loop(self): for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): - local_index = next(self.iter_dataloader) - example = self.dataset[local_index] - dataset_index = example["index"] + example = next(self.iter_dataloader) + dataset_index = example["dataset_index"] add_prompt_to_generator( example, dataset_index, - self.iter_dataloader.epoch_number, + self.iter_dataloader._epoch, self.training_step, self.param_prompt_Q, self.generation_config, @@ -423,7 +366,6 @@ def _data_preparation_loop(self): num_prompts=self.rank_batch_size, model_dims=self.model_dims, tokenizer=self.tokenizer, - reward_fn=self.reward_fn, dataset=self.dataset, actor_manager=self.actor_manager, active_sampling=self.config.active_sampling, @@ -724,7 +666,6 @@ def accumulate_inference_batches( num_prompts: int, model_dims: utils.ModelDims, tokenizer: PreTrainedTokenizer, - reward_fn: Callable, dataset: Dataset, actor_manager=None, timeout: float | None = None, @@ -732,7 +673,7 @@ def accumulate_inference_batches( filter_zero_std_samples: bool = False, replenish_prompts: bool = False, no_resampling_pass_rate: float | None = None, - iter_dataloader: ShufflingIterator | None = None, + iter_dataloader: data_loader_lib.HFDataLoader | None = None, param_prompt_Q: ray_queue.Queue | None = None, training_step: int = None, verbose: bool = False, @@ -788,13 +729,12 @@ def accumulate_inference_batches( raw_query = example[RAW_PROMPT_KEY] if replenish_prompts: - local_index = next(iter_dataloader) - example = dataset[local_index] - dataset_index = example["index"] + example = next(iter_dataloader) + dataset_index = example["dataset_index"] add_prompt_to_generator( example, dataset_index, - iter_dataloader.epoch_number, + iter_dataloader._epoch, training_step, param_prompt_Q, generation_config, @@ -814,19 +754,7 @@ def accumulate_inference_batches( k_datasets = repeat_each([dataset_name], generation_config.n) k_raw_queries = repeat_each([raw_query], generation_config.n) - scores, reward_metrics = asyncio.run( - reward_fn( - result.responses, - decoded_responses, - k_ground_truths, - k_datasets, - result.finish_reasons, - result.request_info, - k_raw_queries, - ) - ) - - percent_solved = np.mean(scores).item() / max_possible_score + percent_solved = np.mean(result.reward_scores).item() / max_possible_score if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: iter_dataloader.exclude_index(result.dataset_index) total_no_resampled += 1 @@ -834,15 +762,15 @@ def accumulate_inference_batches( f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" ) - if 
filter_zero_std_samples and np.std(scores) == 0: + if filter_zero_std_samples and np.std(result.reward_scores) == 0: if not active_sampling: num_prompts_sampled += 1 progress_bar.update(1) total_filtered_prompts += 1 - if scores[0] == 0: + if result.reward_scores[0] == 0: filtered_prompt_zero += 1 - elif scores[0] == max_possible_score: + elif result.reward_scores[0] == max_possible_score: filtered_prompt_solved += 1 else: filtered_prompt_nonzero += 1 @@ -860,8 +788,8 @@ def accumulate_inference_batches( all_datasets.extend(k_datasets) all_raw_queries.extend(k_raw_queries) all_decoded_responses.extend(decoded_responses) - all_scores.extend(scores) - all_reward_metrics.append(reward_metrics) + all_scores.extend(result.reward_scores) + all_reward_metrics.append(result.reward_metrics) all_percent_solved.append(percent_solved) if len(results) == 0: diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 297994f6e..0df038467 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -25,7 +25,6 @@ ) from open_instruct.grpo_fast import get_num_verifiers from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics -from open_instruct.streaming_data_loader import PendingQueriesMap, ShufflingIterator from open_instruct.vllm_utils import create_vllm_engines @@ -279,7 +278,9 @@ def setup_and_add_prompts_to_generator(self, queries, ground_truths, datasets, r ) for example in data_loader: - grpo_fast.add_prompt_to_generator(example, param_prompt_Q, mock_generation_config, False) + grpo_fast.add_prompt_to_generator( + example, example["dataset_index"], 0, 0, param_prompt_Q, mock_generation_config, False + ) return param_prompt_Q, inference_results_Q, mock_dataset @@ -620,7 +621,6 @@ def test_out_of_order_processing(self): mock_result = self.create_mock_result_from_request(request, num_samples_per_prompt) inference_results_Q.put(mock_result) - mock_args = self.create_mock_args(num_engines, num_samples_per_prompt) mock_generation_config = Mock() mock_generation_config.n = num_samples_per_prompt @@ -631,7 +631,7 @@ def test_out_of_order_processing(self): num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, - prompt_dataset=mock_dataset, + dataset=mock_dataset, ) self.assertEqual(len(batch.queries), num_prompts * num_samples_per_prompt) @@ -674,7 +674,7 @@ def run_accumulate(): num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, - prompt_dataset=mock_dataset, + dataset=mock_dataset, ) completed.set() except Exception: @@ -711,7 +711,9 @@ def test_more_engines_than_queries(self): ) for example in data_loader: - grpo_fast.add_prompt_to_generator(example, param_prompt_Q, mock_generation_config, False) + grpo_fast.add_prompt_to_generator( + example, example["dataset_index"], 0, 0, param_prompt_Q, mock_generation_config, False + ) self.assertEqual( param_prompt_Q.qsize(), num_queries, f"Should have {num_queries} batches for {num_queries} queries" @@ -744,7 +746,9 @@ def test_uneven_distribution_no_empty_batches(self): ) for example in data_loader: - grpo_fast.add_prompt_to_generator(example, param_prompt_Q, mock_generation_config, False) + grpo_fast.add_prompt_to_generator( + example, example["dataset_index"], 0, 0, param_prompt_Q, mock_generation_config, False + ) request_count = 0 while not param_prompt_Q.empty(): @@ -840,20 +844,11 @@ def test_all_prompts_filtered_returns_none(self): queries, ground_truths, datasets_list, raw_queries, indices = 
self.create_test_data(num_prompts) - test_dataset = Dataset.from_dict( - { - INPUT_IDS_PROMPT_KEY: queries, - GROUND_TRUTHS_KEY: ground_truths, - VERIFIER_SOURCE_KEY: datasets_list, - RAW_PROMPT_KEY: raw_queries, - } - ) - inference_results_Q = ray_queue.Queue(maxsize=num_prompts) self._ray_queues.append(inference_results_Q) - mock_dataset = self.create_mock_dataset(queries, ground_truths, datasets, raw_queries) + mock_dataset = self.create_mock_dataset(queries, ground_truths, datasets_list, raw_queries) for i in range(num_prompts): constant_scores = [0.5] * num_samples_per_prompt @@ -875,7 +870,7 @@ def test_all_prompts_filtered_returns_none(self): num_prompts=num_prompts, model_dims=mock_model_dims, tokenizer=tokenizer, - prompt_dataset=mock_dataset, + dataset=mock_dataset, filter_zero_std_samples=True, verbose=False, max_possible_score=1.0, diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index 94d006131..5c7777054 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -379,6 +379,7 @@ def process_completed_request(request_id, outs, current_time, tools, request_met ), dataset_index=metadata["dataset_index"], prompt_id=metadata["prompt_id"], + epoch_number=metadata.get("epoch_number", 0), token_statistics=TokenStatistics( num_prompt_tokens=len(metadata["prompt_token_ids"]), num_response_tokens=total_generation_tokens, @@ -486,6 +487,7 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: "is_eval": request.is_eval, "dataset_index": request.dataset_index, "prompt_id": request.prompt_id, + "epoch_number": request.epoch_number, "sampling_params": sampling_params, "original_sampling_params": request.generation_config, "prompt_token_ids": list(request.prompt), From adbaf83c999d8f891199d93b239e67490c91bb5e Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 09:31:54 -0700 Subject: [PATCH 51/96] Remove duplicate args from Args class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These parameters now live in StreamingDataLoaderConfig: - active_sampling - filter_zero_std_samples - no_resampling_pass_rate - mask_truncated_completions - advantage_normalization_type 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 60e2a0da6..3fa87b495 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -225,8 +225,6 @@ class Args: """Enable immediate stopping of request processing when should_stop is set, allowing for quick pausing and resumption""" kl_estimator: Literal[0, 1, 2, 3] = 2 """the KL estimator to use""" - pack_length: int = 512 - """the length of the pack (you should prob set to the max length of the model)""" loss_denominator: str = "token" """Optional constant denominator for masked_mean; can be "token" or a float value. when "token", the loss is divided by the total number of tokens in the batch (standard LM training). @@ -241,19 +239,6 @@ class Args: """How many training steps to take before updating the reference policy.""" load_ref_policy: bool = True """Whether to load and use a reference policy for KL penalty calculation.""" - advantage_normalization_type: Literal["standard", "centered"] = "standard" - """The type of advantage normalization to use. Standard normalization is the default: it subtracts the mean and - divides by the standard deviation. 
Centered normalization is the same but subtracts the mean only (e.g., used in - DR.GRPO https://arxiv.org/pdf/2503.20783).""" - mask_truncated_completions: bool = False - """Whether to mask out truncated completions. Also called overlong filtering, from DAPO (https://arxiv.org/abs/2503.14476).""" - - active_sampling: bool = False - """Whether to continue sampling responses until you get a full batch.""" - filter_zero_std_samples: bool = True - """Whether to filter out prompts with zero reward std (all samples have the same score).""" - no_resampling_pass_rate: float | None = None - """If the response to a prompt is solved at a rate higher than this, do not resample this prompt again""" record_entropy: bool = False """whether to record the entropy of the policy during training. Uses extra memory.""" use_vllm_logprobs: bool = False From 0d1252a9f0f18e6ac5d05607b04f02b35f813167 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 09:39:02 -0700 Subject: [PATCH 52/96] Remove num_samples_per_prompt_rollout validation from Args MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This field now lives in StreamingDataLoaderConfig, which has its own validation in __post_init__. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3fa87b495..4cdbc4749 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -422,9 +422,6 @@ def __post_init__(self): "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." ) self.loss_denominator = utils.get_denominator(self.loss_denominator) - assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" - if self.num_samples_per_prompt_rollout == 1: - logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") assert self.apply_verifiable_reward or self.apply_r1_style_format_reward or self.non_stop_penalty, ( "At least one reward must be applied!" ) From db7dd2b695ab0f3a9caf0a8bb7c2f1895a58fd79 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 10:01:22 -0700 Subject: [PATCH 53/96] Remove stale reward_fn references from grpo_fast.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After PR #1225 moved reward_fn to live inside LLMRayActor, these references were left behind during the merge. 
This removes: - reward_fn parameter from maybe_evaluate() and run_training() - reward_fn from accumulate_inference_batches() calls - reward_fn from create_model_and_optimizer return value unpacking - Unused Callable import 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4cdbc4749..72d0b7f46 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -49,7 +49,6 @@ import threading import time from argparse import Namespace -from collections.abc import Callable from dataclasses import asdict, dataclass, field from datetime import timedelta from queue import Empty, Full, Queue @@ -2067,7 +2066,6 @@ def maybe_evaluate( generate_metrics_Q: Queue, num_eval_prompts: int, model_dims: utils.ModelDims, - reward_fn: Callable, actor_manager=None, ): """Optionally evaluate the model.""" @@ -2083,7 +2081,6 @@ def maybe_evaluate( num_prompts=num_eval_prompts, model_dims=model_dims, tokenizer=tokenizer, - reward_fn=reward_fn, dataset=eval_dataset, actor_manager=actor_manager, timeout=timeout, @@ -2251,7 +2248,6 @@ def run_training( policy_group, vllm_engines, generation_configs, - reward_fn, resume_training_step, episode, wandb_url, @@ -2426,7 +2422,6 @@ def health_check_fn(): generate_metrics_Q, len(eval_dataset) if eval_dataset else 0, model_dims, - reward_fn, actor_manager, ) @@ -2489,7 +2484,7 @@ def main( ) generation_configs = create_generation_configs(args, streaming_config) - (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims, reward_fn) = ( + (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( create_model_and_optimizer( args, tc, @@ -2536,7 +2531,6 @@ def main( policy_group, vllm_engines, generation_configs, - reward_fn, resume_training_step, episode, wandb_url, From b0de54a87e4e23045b9275c38d71232adbe626c4 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 10:12:30 -0700 Subject: [PATCH 54/96] Fix args.async_steps -> streaming_config.async_steps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit async_steps was moved to StreamingDataLoaderConfig but these references in main() were not updated. 
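For example (mirroring the hunk below), the metrics queues are now sized from the
dataloader config rather than from Args:

    generate_metrics_Q = Queue(maxsize=streaming_config.async_steps)
    weight_sync_metrics_Q = Queue(maxsize=streaming_config.async_steps)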
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 72d0b7f46..925a436ee 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2515,8 +2515,8 @@ def main( logger.info(f"Restored episode count: {episode}") # Create additional queues (main queues already created above) - generate_metrics_Q = Queue(maxsize=args.async_steps) - weight_sync_metrics_Q = Queue(maxsize=args.async_steps) + generate_metrics_Q = Queue(maxsize=streaming_config.async_steps) + weight_sync_metrics_Q = Queue(maxsize=streaming_config.async_steps) stop_event = threading.Event() executor = futures.ThreadPoolExecutor(max_workers=3, thread_name_prefix="grpo") From 4005ea5f6183f5b65456f2b216185a6548bd28cd Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 10:16:26 -0700 Subject: [PATCH 55/96] Fix KeyError for time/reward in combined_reward_metrics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The reward timing is now included in the GenerationResult from LLMRayActor, but combined_reward_metrics may not have this key. Use conditional check before logging. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/streaming_data_loader.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 77c2bd36b..d5dd153f9 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -900,8 +900,9 @@ def accumulate_inference_batches( no_resampled_prompts=total_no_resampled, total_prompts=len(results), ) - logging.info( - f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" - ) + if "time/reward" in combined_reward_metrics: + logging.info( + f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" + ) return combined_result, batch, combined_reward_metrics, batch_stats From 04e3eeb126d0fe9676d9eb90ac79d048f894106a Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 10:46:25 -0700 Subject: [PATCH 56/96] Remove stale time/reward logging line entirely MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User requested to just remove the logging line rather than making it conditional. 
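If the timing is still wanted ad hoc, it can be read from the returned metrics instead
(a sketch, assuming combined_reward_metrics is a plain dict):

    reward_time = combined_reward_metrics.get("time/reward")  # None when the key is absent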
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/streaming_data_loader.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index d5dd153f9..88e1eafcb 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -900,9 +900,4 @@ def accumulate_inference_batches( no_resampled_prompts=total_no_resampled, total_prompts=len(results), ) - if "time/reward" in combined_reward_metrics: - logging.info( - f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" - ) - return combined_result, batch, combined_reward_metrics, batch_stats From d3ea24ecc41ce57e7eb34b4d0a2bacd5c0ae8c15 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 10:50:58 -0700 Subject: [PATCH 57/96] Make tool_masks optional in _prepare_collated_data_for_self MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PackedSequences doesn't have tool_masks attribute for non-tool-use experiments. Use getattr with default None and conditionally include in output. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/streaming_data_loader.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py index 88e1eafcb..05c32febe 100644 --- a/open_instruct/streaming_data_loader.py +++ b/open_instruct/streaming_data_loader.py @@ -511,7 +511,7 @@ def _data_preparation_loop(self): def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: per_device_packed_query_responses = packed_sequences.query_responses - per_device_packed_tool_masks = packed_sequences.tool_masks + per_device_packed_tool_masks = getattr(packed_sequences, "tool_masks", None) per_device_packed_attention_masks = packed_sequences.attention_masks per_device_packed_position_ids = packed_sequences.position_ids per_device_packed_advantages = packed_sequences.advantages @@ -520,7 +520,7 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> b_inds = np.random.permutation(len(per_device_packed_query_responses)) collated_query_responses = [] - collated_tool_masks = [] + collated_tool_masks = [] if per_device_packed_tool_masks is not None else None collated_attention_masks = [] collated_position_ids = [] collated_response_masks = [] @@ -533,7 +533,10 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True ) ) - collated_tool_masks.append(collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True)) + if per_device_packed_tool_masks is not None: + collated_tool_masks.append( + collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True) + ) collated_attention_masks.append( collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) ) @@ -548,15 +551,17 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) ) - return { + result = { "collated_query_responses": collated_query_responses, - "collated_tool_masks": collated_tool_masks, "collated_attention_masks": collated_attention_masks, 
"collated_position_ids": collated_position_ids, "collated_advantages": collated_advantages, "collated_response_masks": collated_response_masks, "collated_vllm_logprobs": collated_vllm_logprobs, } + if collated_tool_masks is not None: + result["collated_tool_masks"] = collated_tool_masks + return result def shutdown(self): self.shutdown_requested = True From 15b8bc3ec85584add21deb209f2352b5944291c8 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 11:26:43 -0700 Subject: [PATCH 58/96] Fix add_prompt_to_generator call signature for eval data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The function requires example_index, epoch_number, and training_step parameters that were missing from the call. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 925a436ee..d62e48bc1 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2349,8 +2349,10 @@ def health_check_fn(): and eval_data_loader is not None and (args.eval_on_step_0 or training_step > 1) ): - for eval_example in iter(eval_data_loader): - add_prompt_to_generator(eval_example, param_prompt_Q, generation_configs["eval"], is_eval=True) + for eval_idx, eval_example in enumerate(iter(eval_data_loader)): + add_prompt_to_generator( + eval_example, eval_idx, 0, training_step, param_prompt_Q, generation_configs["eval"], is_eval=True + ) episode += args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout From 7b2e029dbf96ee76581e976326460832c8c43c5c Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 13:51:13 -0700 Subject: [PATCH 59/96] Moved to data_loader.py --- open_instruct/data_loader.py | 795 ++++++++++++++++++++++ open_instruct/grpo_fast.py | 32 +- open_instruct/streaming_data_loader.py | 908 ------------------------- open_instruct/test_grpo_fast.py | 12 +- 4 files changed, 814 insertions(+), 933 deletions(-) delete mode 100644 open_instruct/streaming_data_loader.py diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index be2e901ea..1ee65ded8 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -1,8 +1,47 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import threading from collections.abc import Iterable +from dataclasses import asdict, dataclass +from pathlib import Path +from queue import Queue as StdQueue from typing import Any +import numpy as np +import torch +import vllm from datasets import Dataset from olmo_core.data import data_loader +from ray.util import queue as ray_queue +from tqdm import tqdm +from transformers import PreTrainedTokenizer + +from open_instruct import utils +from open_instruct.dataset_transformation import ( + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + RAW_PROMPT_KEY, + VERIFIER_SOURCE_KEY, +) +from open_instruct.model_utils import Batch +from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics +from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences +from open_instruct.utils import combine_reward_metrics, repeat_each + +logger = logging.getLogger(__name__) class HFDataLoader(data_loader.DataLoaderBase): @@ -126,3 +165,759 @@ def get_mock_batch(self) -> dict[str, Any]: The first item from the dataset. """ return self.dataset[0] + + +@dataclass +class StreamingDataLoaderConfig: + max_prompt_token_length: int = 256 + response_length: int = 256 + async_steps: int = 1 + num_samples_per_prompt_rollout: int = 4 + active_sampling: bool = False + filter_zero_std_samples: bool = True + no_resampling_pass_rate: float | None = None + advantage_normalization_type: str = "standard" + mask_truncated_completions: bool = False + pack_length: int = 512 + + def __post_init__(self): + assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( + "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" + ) + assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" + if self.num_samples_per_prompt_rollout == 1: + logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") + + if self.active_sampling: + assert self.async_steps > 1, ( + "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " + "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " + "prompt will cause the trainer to stall waiting for more data . " + ) + assert self.filter_zero_std_samples, ( + "filter_zero_std_samples must be True when active_sampling is True. " + "Active sampling requires filtering to work correctly." + ) + if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: + raise ValueError( + "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " + "as the reward standard deviation will always be 0, causing all samples to be filtered." + ) + if self.async_steps < 1: + raise ValueError("`async_steps` must be greater than 0. 
Fully synchronous training is not supported.") + + def build( + self, + dataset: Dataset, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + generation_config: Any, + dp_rank: int, + fs_local_rank: int, + num_training_steps: int, + seed: int, + per_device_train_batch_size: int, + verbose: bool, + work_dir: Path | str, + global_batch_size: int, + dp_world_size: int, + max_possible_score: float, + actor_manager=None, + model_dims: utils.ModelDims | None = None, + ) -> "StreamingDataLoader": + return StreamingDataLoader( + dataset=dataset, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + tokenizer=tokenizer, + config=self, + generation_config=generation_config, + work_dir=work_dir, + global_batch_size=global_batch_size, + num_training_steps=num_training_steps, + seed=seed, + per_device_train_batch_size=per_device_train_batch_size, + verbose=verbose, + max_possible_score=max_possible_score, + actor_manager=actor_manager, + model_dims=model_dims, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + +class StreamingDataLoader(data_loader.DataLoaderBase): + def __init__( + self, + *, + dataset: Dataset, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + config: StreamingDataLoaderConfig, + generation_config: Any, + work_dir: Path | str, + global_batch_size: int, + num_training_steps: int = 0, + seed: int, + per_device_train_batch_size: int, + verbose: bool, + max_possible_score: float, + actor_manager=None, + model_dims: utils.ModelDims = None, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + super().__init__( + work_dir=work_dir, + global_batch_size=global_batch_size, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + self.dataset = dataset + self.inference_results_Q = inference_results_Q + self.param_prompt_Q = param_prompt_Q + self.tokenizer = tokenizer + self.config = config + self.config.max_possible_score = max_possible_score + self.generation_config = generation_config + self.num_training_steps = num_training_steps + self.actor_manager = actor_manager + self.model_dims = model_dims + + self.per_device_train_batch_size = per_device_train_batch_size + self.verbose = verbose + + self.training_step = 0 + self.current_epoch = 0 + self.seed = seed + + self.iter_dataloader = HFDataLoader( + dataset=dataset, + batch_size=1, + seed=seed, + rank=dp_rank, + world_size=dp_world_size, + work_dir=work_dir, + automatic_reshuffle=True, + ) + + self.local_queue = StdQueue(maxsize=config.async_steps) + self.background_thread = None + self.shutdown_requested = False + + @property + def total_batches(self) -> int | None: + return self.num_training_steps + + def state_dict(self) -> dict[str, Any]: + return { + "training_step": self.training_step, + "current_epoch": self.current_epoch, + "iter_dataloader_state": self.iter_dataloader.state_dict(), + } + + def load_state_dict(self, state_dict: dict[str, Any]): + self.training_step = state_dict["training_step"] + self.current_epoch = state_dict.get("current_epoch", 0) + self.iter_dataloader.load_state_dict(state_dict["iter_dataloader_state"]) + + def reshuffle(self, epoch: int | None = None, **kwargs): + if epoch is not None: + self.current_epoch = epoch + + def get_mock_batch(self) -> dict[str, Any]: + dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) + 
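+        # Minimal two-token (pad, eos) placeholder; all of the collated fields below match its shape.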
dummy_tool_mask = torch.zeros_like(dummy_qr) + dummy_attention = torch.tensor([1, 1], dtype=torch.long) + dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) + dummy_response_mask = torch.zeros_like(dummy_qr) + dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) + + return { + "collated_query_responses": [dummy_qr], + "collated_tool_masks": [dummy_tool_mask], + "collated_attention_masks": [dummy_attention], + "collated_position_ids": [dummy_position_ids], + "collated_advantages": [dummy_advantage], + "collated_response_masks": [dummy_response_mask], + "collated_vllm_logprobs": [torch.zeros_like(dummy_qr, dtype=torch.float)], + } + + def _iter_batches(self) -> Iterable[dict[str, Any]]: + if self.background_thread is None: + self._start_background_thread() + + while self.training_step < self.num_training_steps: + batch_data = self.local_queue.get() + self.training_step += 1 + yield batch_data + + def _start_background_thread(self): + self.shutdown_requested = False + self.background_thread = threading.Thread( + target=self._data_preparation_loop, daemon=True, name=f"DataLoader-Worker-Rank{self.dp_rank}" + ) + self.background_thread.start() + + def _data_preparation_loop(self): + for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): + example = next(self.iter_dataloader) + add_prompt_to_generator( + example, + self.iter_dataloader._epoch, + self.training_step, + self.param_prompt_Q, + self.generation_config, + is_eval=False, + ) + + for training_step in range(self.training_step, self.num_training_steps): + if self.shutdown_requested: + logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") + return + + with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: + result, batch, reward_metrics, batch_stats = accumulate_inference_batches( + self.inference_results_Q, + self.generation_config, + num_prompts=self.rank_batch_size, + model_dims=self.model_dims, + tokenizer=self.tokenizer, + dataset=self.dataset, + actor_manager=self.actor_manager, + active_sampling=self.config.active_sampling, + filter_zero_std_samples=self.config.filter_zero_std_samples, + replenish_prompts=True, + no_resampling_pass_rate=self.config.no_resampling_pass_rate, + iter_dataloader=self.iter_dataloader, + param_prompt_Q=self.param_prompt_Q, + training_step=training_step, + verbose=self.verbose, + max_possible_score=self.config.max_possible_score, + ) + if isinstance(result, ShutdownSentinel): + logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") + return + + getting_response_time = timer.duration + scores = np.array(batch.scores) + + good_outputs = [ + len(result.request_info.tool_outputs[i]) > 0 + and result.request_info.tool_calleds[i] + and not result.request_info.timeouts[i] + and not result.request_info.tool_errors[i] + for i in range(len(result.request_info.tool_outputs)) + ] + scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) + mean_grouped_rewards = scores_per_prompt.mean(axis=-1) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + std_grouped_rewards = scores_per_prompt.std(axis=-1) + std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + if self.config.advantage_normalization_type == "standard": + advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + elif self.config.advantage_normalization_type == "centered": + 
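+                # "centered" subtracts the per-group mean only, without dividing by the std (as in DR.GRPO, arXiv:2503.20783).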
advantages = scores - mean_grouped_rewards + else: + raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") + + if self.config.mask_truncated_completions: + stop_idxes = torch.tensor( + [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] + ) + num_truncated = len(result.finish_reasons) - len(stop_idxes) + if num_truncated > 0: + logger.info( + f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " + f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" + ) + scores = scores[stop_idxes] + advantages = advantages[stop_idxes] + batch = batch[stop_idxes.tolist()] + result.responses = [result.responses[i] for i in stop_idxes] + result.masks = [result.masks[i] for i in stop_idxes] + result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] + result.logprobs = [result.logprobs[i] for i in stop_idxes] + + with Timer("📦 [Data Preparation Thread] Packing sequences"): + packed_sequences = pack_sequences( + queries=batch.queries, + responses=result.responses, + masks=result.masks, + pack_length=self.config.pack_length, + pad_token_id=self.tokenizer.pad_token_id, + vllm_logprobs=result.logprobs, + ) + lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) + lookup_advantages[1:] = advantages + packed_advantages = [ + torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) + for packed_mask in packed_sequences.response_masks + ] + packed_sequences.advantages = packed_advantages + + collated_data = self._prepare_collated_data_for_self(packed_sequences) + + if len(result.responses) == 0: + metrics = {} + logger.warning(f"No responses in batch {training_step}.") + else: + real_num_responses = len(result.responses) + expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size + + unsolved_num_responses = (scores < self.config.max_possible_score).sum() + sequence_lengths = np.array([len(response) for response in result.responses]) + sequence_length_solved = ( + np.array([]) + if np.all(scores == 0) + else np.array(sequence_lengths[scores == self.config.max_possible_score]) + ) + sequence_length_unsolved = ( + np.array([]) + if np.all(scores == self.config.max_possible_score) + else np.array(sequence_lengths[scores == 0]) + ) + stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( + result.finish_reasons + ) + + batch_metrics = asdict(batch_stats) + batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} + + metrics = { + "scores": scores.mean(), + "real_batch_size_ratio": real_num_responses / expected_num_responses, + "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, + "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, + "val/solve_rate_hist": batch_stats.percent_solved_hist, + "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, + "val/sequence_lengths": sequence_lengths.mean(), + "val/sequence_lengths_min": sequence_lengths.min(), + "val/sequence_lengths_max": sequence_lengths.max(), + "val/sequence_lengths_unsolved": ( + 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() + ), + "val/sequence_lengths_solved": ( + 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() + ), + "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, + "val/sequence_lengths_solved_hist": sequence_length_solved, + 
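+                        # By convention, the *_hist entries hold raw arrays; the remaining entries are scalar summaries.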
"val/stop_rate": stop_rate, + "val/advantages_mean": advantages.mean(), + "val/advantages_min": advantages.min(), + "val/advantages_max": advantages.max(), + "val/advantages_hist": advantages, + "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), + "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), + "val/tool_errors_rate": np.array( + [len(item) > 0 for item in result.request_info.tool_errors] + ).mean(), + "val/good_outputs_rate": np.array(good_outputs).mean(), + "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), + "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), + "time/getting_response": getting_response_time, + **reward_metrics, + **batch_metrics_prefixed, + } + + total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens + metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time + + collated_data["metrics"] = metrics + self.local_queue.put(collated_data) + + def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: + per_device_packed_query_responses = packed_sequences.query_responses + per_device_packed_tool_masks = getattr(packed_sequences, "tool_masks", None) + per_device_packed_attention_masks = packed_sequences.attention_masks + per_device_packed_position_ids = packed_sequences.position_ids + per_device_packed_advantages = packed_sequences.advantages + per_device_packed_response_masks = packed_sequences.response_masks + per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs + + b_inds = np.random.permutation(len(per_device_packed_query_responses)) + collated_query_responses = [] + collated_tool_masks = [] if per_device_packed_tool_masks is not None else None + collated_attention_masks = [] + collated_position_ids = [] + collated_response_masks = [] + collated_advantages = [] + collated_vllm_logprobs = [] + for j in range(0, len(per_device_packed_query_responses), self.per_device_train_batch_size): + micro_range = b_inds[j : j + self.per_device_train_batch_size] + collated_query_responses.append( + collate_fn( + [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True + ) + ) + if per_device_packed_tool_masks is not None: + collated_tool_masks.append( + collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True) + ) + collated_attention_masks.append( + collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) + ) + collated_position_ids.append( + collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, True) + ) + collated_response_masks.append( + collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, True) + ) + collated_advantages.append(collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, True)) + collated_vllm_logprobs.append( + collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) + ) + + result = { + "collated_query_responses": collated_query_responses, + "collated_attention_masks": collated_attention_masks, + "collated_position_ids": collated_position_ids, + "collated_advantages": collated_advantages, + "collated_response_masks": collated_response_masks, + "collated_vllm_logprobs": collated_vllm_logprobs, + } + if collated_tool_masks is not None: + result["collated_tool_masks"] = collated_tool_masks + return result + + def shutdown(self): + self.shutdown_requested 
= True + if self.background_thread is not None: + self.background_thread.join(timeout=5.0) + + +def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: + padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) + if pin_memory: + padded_tensor = padded_tensor.pin_memory() + return padded_tensor + + +@dataclass +class BatchStatistics: + prompt_lengths: list[int] + response_lengths: list[int] + filtered_prompts: int + filtered_prompts_zero: int + filtered_prompts_solved: int + filtered_prompts_nonzero: int + percent_solved_mean: float + percent_solved_hist: np.ndarray + no_resampled_prompts: int + total_prompts: int + + +class PendingQueriesMap: + def __init__(self): + self._map = {} + self._lock = threading.Lock() + + def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): + with self._lock: + if dataset_idx in self._map: + existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ + dataset_idx + ] + self._map[dataset_idx] = ( + existing_query, + existing_ground_truth, + existing_dataset, + existing_raw_query, + count + 1, + ) + else: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) + + def pop(self, dataset_idx): + with self._lock: + if dataset_idx not in self._map: + raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") + + query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] + + if count > 1: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) + else: + del self._map[dataset_idx] + + return query, ground_truth, dataset, raw_query + + def __len__(self): + with self._lock: + return len(self._map) + + def __contains__(self, dataset_idx): + with self._lock: + return dataset_idx in self._map + + def __getitem__(self, dataset_idx): + with self._lock: + return self._map[dataset_idx] + + def keys(self): + with self._lock: + return list(self._map.keys()) + + +def add_prompt_to_generator( + example: dict[str, Any], + epoch_number: int, + training_step: int, + param_prompt_Q: ray_queue.Queue, + generation_config, + is_eval: bool, +) -> None: + query = example[INPUT_IDS_PROMPT_KEY] + + param_prompt_Q.put( + PromptRequest( + prompt=query, + generation_config=generation_config, + epoch_number=epoch_number, + training_step=training_step, + dataset_index=example["dataset_index"], + is_eval=is_eval, + ) + ) + + +def accumulate_inference_batches( + inference_results_Q: ray_queue.Queue, + generation_config: vllm.SamplingParams, + num_prompts: int, + model_dims: utils.ModelDims, + tokenizer: PreTrainedTokenizer, + dataset: Dataset, + actor_manager=None, + timeout: float | None = None, + active_sampling: bool = False, + filter_zero_std_samples: bool = False, + replenish_prompts: bool = False, + no_resampling_pass_rate: float | None = None, + iter_dataloader: HFDataLoader | None = None, + param_prompt_Q: ray_queue.Queue | None = None, + training_step: int = None, + verbose: bool = False, + max_possible_score: float = 1.0, +) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: + import ray + + if no_resampling_pass_rate is not None: + assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" + + if replenish_prompts: + assert param_prompt_Q is not None and iter_dataloader is not None and dataset is not None, ( + "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset" + ) + + results = [] + all_queries 
= [] + all_ground_truths = [] + all_datasets = [] + all_raw_queries = [] + all_decoded_responses = [] + all_reward_metrics = [] + all_scores = [] + all_percent_solved = [] + total_filtered_prompts = 0 + filtered_prompt_zero = 0 + filtered_prompt_solved = 0 + filtered_prompt_nonzero = 0 + total_no_resampled = 0 + progress_bar = tqdm( + total=num_prompts, + desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", + bar_format="{l_bar}{bar}{r_bar}\n", + disable=not verbose, + ) + num_prompts_sampled = 0 + while num_prompts_sampled < num_prompts: + result = inference_results_Q.get(timeout=timeout) + + if isinstance(result, ShutdownSentinel): + return result, None, None, None + + assert len(result.responses) == generation_config.n, ( + f"Mismatch: individual prompt result has {len(result.responses)} responses " + f"but expected {generation_config.n} samples per prompt. " + f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" + ) + + example = dataset[result.dataset_index] + query = example[INPUT_IDS_PROMPT_KEY] + ground_truth = example[GROUND_TRUTHS_KEY] + dataset_name = example[VERIFIER_SOURCE_KEY] + raw_query = example[RAW_PROMPT_KEY] + + if replenish_prompts: + example = next(iter_dataloader) + add_prompt_to_generator( + example, iter_dataloader._epoch, training_step, param_prompt_Q, generation_config, is_eval=False + ) + + for i in range(len(result.finish_reasons)): + if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: + result.responses[i].append(tokenizer.eos_token_id) + result.masks[i].append(1) + result.logprobs[i].append(float("nan")) + + decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) + + k_queries = repeat_each([query], generation_config.n) + k_ground_truths = repeat_each([ground_truth], generation_config.n) + k_datasets = repeat_each([dataset_name], generation_config.n) + k_raw_queries = repeat_each([raw_query], generation_config.n) + + percent_solved = np.mean(result.reward_scores).item() / max_possible_score + if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: + iter_dataloader.exclude_index(result.dataset_index) + total_no_resampled += 1 + logging.debug( + f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" + ) + + if filter_zero_std_samples and np.std(result.reward_scores) == 0: + if not active_sampling: + num_prompts_sampled += 1 + progress_bar.update(1) + + total_filtered_prompts += 1 + if result.reward_scores[0] == 0: + filtered_prompt_zero += 1 + elif result.reward_scores[0] == max_possible_score: + filtered_prompt_solved += 1 + else: + filtered_prompt_nonzero += 1 + logging.debug( + f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" + ) + continue + else: + num_prompts_sampled += 1 + progress_bar.update(1) + + results.append(result) + all_queries.extend(k_queries) + all_ground_truths.extend(k_ground_truths) + all_datasets.extend(k_datasets) + all_raw_queries.extend(k_raw_queries) + all_decoded_responses.extend(decoded_responses) + all_scores.extend(result.reward_scores) + all_reward_metrics.append(result.reward_metrics) + all_percent_solved.append(percent_solved) + + if len(results) == 0: + logging.warning( + "[Data Preparation Thread] All prompts were filtered during accumulation. 
" + f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " + f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" + ) + return None, None, None, None + + combined_responses = [] + combined_finish_reasons = [] + combined_masks = [] + combined_num_calls = [] + combined_timeouts = [] + combined_tool_errors = [] + combined_tool_outputs = [] + combined_tool_runtimes = [] + combined_tool_calleds = [] + combined_logprobs = [] + + earliest_start_time = float("inf") + prompt_lengths = [] + response_lengths = [] + + total_prompt_tokens = 0 + total_response_tokens = 0 + max_generation_time = 0 + + for i, result in enumerate(results): + combined_responses.extend(result.responses) + combined_finish_reasons.extend(result.finish_reasons) + combined_masks.extend(result.masks) + combined_num_calls.extend(result.request_info.num_calls) + combined_timeouts.extend(result.request_info.timeouts) + combined_tool_errors.extend(result.request_info.tool_errors) + combined_tool_outputs.extend(result.request_info.tool_outputs) + combined_tool_runtimes.extend(result.request_info.tool_runtimes) + combined_tool_calleds.extend(result.request_info.tool_calleds) + + combined_logprobs.extend(result.logprobs) + + earliest_start_time = min(earliest_start_time, result.start_time) + + prompt_lengths.append(len(all_queries[i * generation_config.n])) + + for response in result.responses: + response_lengths.append(len(response)) + + total_prompt_tokens += result.token_statistics.num_prompt_tokens + total_response_tokens += result.token_statistics.num_response_tokens + max_generation_time = max(max_generation_time, result.token_statistics.generation_time) + + total_generation_time = max_generation_time + + accumulated_stats = TokenStatistics( + num_prompt_tokens=total_prompt_tokens, + num_response_tokens=total_response_tokens, + generation_time=total_generation_time, + earliest_start_time=earliest_start_time, + ) + + combined_request_info = RequestInfo( + num_calls=combined_num_calls, + timeouts=combined_timeouts, + tool_errors=combined_tool_errors, + tool_outputs=combined_tool_outputs, + tool_runtimes=combined_tool_runtimes, + tool_calleds=combined_tool_calleds, + ) + + combined_result = GenerationResult( + responses=combined_responses, + finish_reasons=combined_finish_reasons, + masks=combined_masks, + request_info=combined_request_info, + dataset_index=None, + epoch_number=results[0].epoch_number, + token_statistics=accumulated_stats, + logprobs=combined_logprobs, + ) + + if actor_manager is not None: + ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) + + batch = Batch( + queries=all_queries, + ground_truths=all_ground_truths, + datasets=all_datasets, + raw_queries=all_raw_queries, + decoded_responses=all_decoded_responses, + indices=None, + scores=all_scores, + ) + + combined_reward_metrics = combine_reward_metrics(all_reward_metrics) + percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 + + batch_stats = BatchStatistics( + prompt_lengths=prompt_lengths, + response_lengths=response_lengths, + filtered_prompts=total_filtered_prompts, + filtered_prompts_zero=filtered_prompt_zero, + filtered_prompts_solved=filtered_prompt_solved, + filtered_prompts_nonzero=filtered_prompt_nonzero, + percent_solved_mean=percent_solved_mean, + percent_solved_hist=np.array(all_percent_solved), + no_resampled_prompts=total_no_resampled, + total_prompts=len(results), + ) + return combined_result, batch, combined_reward_metrics, batch_stats diff --git 
a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index d62e48bc1..4edf18340 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -36,8 +36,9 @@ with contextlib.suppress(Exception): import deepspeed -from open_instruct import streaming_data_loader, utils -from open_instruct.streaming_data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn +from open_instruct import data_loader as data_loader_lib +from open_instruct import utils +from open_instruct.data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn # isort: on import asyncio @@ -74,7 +75,6 @@ from transformers import AutoModelForCausalLM, PreTrainedModel, PreTrainedTokenizer, get_scheduler from transformers.integrations import HfDeepSpeedConfig -from open_instruct import data_loader as data_loader_lib from open_instruct import logger_utils, vllm_utils from open_instruct.actor_manager import ActorManager from open_instruct.dataset_transformation import ( @@ -591,7 +591,7 @@ def __init__( master_addr: str | None, master_port: int | None, args: Args, - data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + data_loader_config: data_loader_lib.StreamingDataLoaderConfig, dataset: Dataset, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, @@ -791,7 +791,7 @@ def load(self, path: str, map_location=None): if hasattr(self, "ref_policy_checkpoint_path") else None, ) - # 49 base metrics: 16 from step() (KL, loss, ratio, etc.), 22 from streaming_data_loader + # 49 base metrics: 16 from step() (KL, loss, ratio, etc.), 22 from data_loader_lib # (scores, sequence_lengths, etc.), 7 from BatchStatistics, 4 from reward_metrics. # Each verifier adds 2 metrics: objective/{key}_reward and objective/{key}_correct_rate. 
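        # Sanity check: 16 + 22 + 7 + 4 = 49; e.g. with 3 verifiers, max_metrics = 49 + 2 * 3 = 55.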
max_metrics = 49 + 2 * num_verifiers @@ -1429,7 +1429,7 @@ def __init__( num_gpus_per_node: list[int], single_gpu_mode: bool, args: Args, - data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + data_loader_config: data_loader_lib.StreamingDataLoaderConfig, dataset: Dataset, inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, @@ -1575,7 +1575,7 @@ def calculate_utilization_metrics( return utilization_metrics -def setup_runtime_variables(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig) -> Args: +def setup_runtime_variables(args: Args, streaming_config: data_loader_lib.StreamingDataLoaderConfig) -> Args: """Set up runtime variables for the experiment.""" args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" args.output_dir = os.path.join(args.output_dir, args.run_name) @@ -1636,7 +1636,7 @@ def setup_datasets( args: Args, tc: TokenizerConfig, tokenizer: PreTrainedTokenizer, - streaming_config: streaming_data_loader.StreamingDataLoaderConfig, + streaming_config: data_loader_lib.StreamingDataLoaderConfig, ): """Set up training and evaluation datasets.""" system_prompt_override = None @@ -1700,7 +1700,7 @@ def create_model_and_optimizer( inference_results_Q: ray_queue.Queue, param_prompt_Q: ray_queue.Queue, evaluation_inference_results_Q: ray_queue.Queue, - data_loader_config: streaming_data_loader.StreamingDataLoaderConfig, + data_loader_config: data_loader_lib.StreamingDataLoaderConfig, train_dataset: Dataset, eval_dataset, reward_config: RewardConfig, @@ -1853,7 +1853,7 @@ def create_model_and_optimizer( return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims -def create_generation_configs(args: Args, streaming_config: streaming_data_loader.StreamingDataLoaderConfig): +def create_generation_configs(args: Args, streaming_config: data_loader_lib.StreamingDataLoaderConfig): """Create generation configs for training and evaluation.""" generation_config = vllm.SamplingParams( temperature=args.temperature, @@ -1940,7 +1940,7 @@ def weight_sync_thread( def one_training_step( args: Args, - streaming_config: streaming_data_loader.StreamingDataLoaderConfig, + streaming_config: data_loader_lib.StreamingDataLoaderConfig, policy_group: ModelGroup, tokenizer: PreTrainedTokenizer, data_thread_metrics: dict[str, Any], @@ -2349,9 +2349,9 @@ def health_check_fn(): and eval_data_loader is not None and (args.eval_on_step_0 or training_step > 1) ): - for eval_idx, eval_example in enumerate(iter(eval_data_loader)): + for eval_example in iter(eval_data_loader): add_prompt_to_generator( - eval_example, eval_idx, 0, training_step, param_prompt_Q, generation_configs["eval"], is_eval=True + eval_example, 0, training_step, param_prompt_Q, generation_configs["eval"], is_eval=True ) episode += args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout @@ -2437,7 +2437,7 @@ def main( args: Args, tc: TokenizerConfig, model_config: ModelConfig, - streaming_config: streaming_data_loader.StreamingDataLoaderConfig, + streaming_config: data_loader_lib.StreamingDataLoaderConfig, ): tokenizer = make_tokenizer(tc, model_config) args = setup_runtime_variables(args, streaming_config) @@ -2584,11 +2584,11 @@ def main( if __name__ == "__main__": utils.check_oe_eval_internal() - parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig, streaming_data_loader.StreamingDataLoaderConfig)) + parser = ArgumentParserPlus((Args, TokenizerConfig, ModelConfig, 
data_loader_lib.StreamingDataLoaderConfig)) args, tokenizer_config, model_config, streaming_config = parser.parse_args_into_dataclasses() assert isinstance(args, Args) assert isinstance(tokenizer_config, TokenizerConfig) assert isinstance(model_config, ModelConfig) - assert isinstance(streaming_config, streaming_data_loader.StreamingDataLoaderConfig) + assert isinstance(streaming_config, data_loader_lib.StreamingDataLoaderConfig) main(args, tokenizer_config, model_config, streaming_config) diff --git a/open_instruct/streaming_data_loader.py b/open_instruct/streaming_data_loader.py deleted file mode 100644 index 05c32febe..000000000 --- a/open_instruct/streaming_data_loader.py +++ /dev/null @@ -1,908 +0,0 @@ -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import threading -from abc import abstractmethod -from collections.abc import Iterable -from dataclasses import asdict, dataclass -from pathlib import Path -from queue import Queue as StdQueue -from typing import Any - -import numpy as np -import torch -import vllm -from datasets import Dataset -from ray.util import queue as ray_queue -from tqdm import tqdm -from transformers import PreTrainedTokenizer - -from open_instruct import data_loader as data_loader_lib -from open_instruct import utils -from open_instruct.dataset_transformation import ( - GROUND_TRUTHS_KEY, - INPUT_IDS_PROMPT_KEY, - RAW_PROMPT_KEY, - VERIFIER_SOURCE_KEY, -) -from open_instruct.model_utils import Batch -from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics -from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences -from open_instruct.utils import combine_reward_metrics, repeat_each - -logger = logging.getLogger(__name__) - - -@dataclass -class StreamingDataLoaderConfig: - max_prompt_token_length: int = 256 - response_length: int = 256 - async_steps: int = 1 - num_samples_per_prompt_rollout: int = 4 - active_sampling: bool = False - filter_zero_std_samples: bool = True - no_resampling_pass_rate: float | None = None - advantage_normalization_type: str = "standard" - mask_truncated_completions: bool = False - pack_length: int = 512 - - def __post_init__(self): - assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( - "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" - ) - assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" - if self.num_samples_per_prompt_rollout == 1: - logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") - - if self.active_sampling: - assert self.async_steps > 1, ( - "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " - "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " - "prompt will cause the trainer to stall waiting for more data . 
" - ) - assert self.filter_zero_std_samples, ( - "filter_zero_std_samples must be True when active_sampling is True. " - "Active sampling requires filtering to work correctly." - ) - if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: - raise ValueError( - "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " - "as the reward standard deviation will always be 0, causing all samples to be filtered." - ) - if self.async_steps < 1: - raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") - - def build( - self, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, - tokenizer: PreTrainedTokenizer, - generation_config: Any, - dp_rank: int, - fs_local_rank: int, - num_training_steps: int, - seed: int, - per_device_train_batch_size: int, - verbose: bool, - work_dir: Path | str, - global_batch_size: int, - dp_world_size: int, - max_possible_score: float, - actor_manager=None, - model_dims: utils.ModelDims | None = None, - ) -> "StreamingDataLoader": - return StreamingDataLoader( - dataset=dataset, - inference_results_Q=inference_results_Q, - param_prompt_Q=param_prompt_Q, - tokenizer=tokenizer, - config=self, - generation_config=generation_config, - work_dir=work_dir, - global_batch_size=global_batch_size, - num_training_steps=num_training_steps, - seed=seed, - per_device_train_batch_size=per_device_train_batch_size, - verbose=verbose, - max_possible_score=max_possible_score, - actor_manager=actor_manager, - model_dims=model_dims, - dp_world_size=dp_world_size, - dp_rank=dp_rank, - fs_local_rank=fs_local_rank, - ) - - -class DataLoaderBase: - def __init__( - self, - *, - work_dir: Path | str, - global_batch_size: int, - dp_world_size: int = 1, - dp_rank: int = 0, - fs_local_rank: int = 0, - ): - self.work_dir = Path(work_dir) - self._global_batch_size = global_batch_size - self.dp_world_size = dp_world_size - self.dp_rank = dp_rank - self.fs_local_rank = fs_local_rank - self.batches_processed = 0 - self.epoch: int | None = None - - @property - def global_batch_size(self) -> int: - return self._global_batch_size - - @global_batch_size.setter - def global_batch_size(self, value: int): - self._global_batch_size = value - - @property - def rank_batch_size(self) -> int: - return self.global_batch_size // self.dp_world_size - - @property - @abstractmethod - def total_batches(self) -> int | None: - pass - - @abstractmethod - def state_dict(self) -> dict[str, Any]: - pass - - @abstractmethod - def load_state_dict(self, state_dict: dict[str, Any]): - pass - - @abstractmethod - def reshuffle(self, epoch: int | None = None, **kwargs): - pass - - @abstractmethod - def _iter_batches(self) -> Iterable[dict[str, Any]]: - pass - - @abstractmethod - def get_mock_batch(self) -> dict[str, Any]: - pass - - def __iter__(self): - return self._iter_batches() - - def __next__(self): - if not hasattr(self, "_iterator"): - self._iterator = self._iter_batches() - return next(self._iterator) - - def reset(self): - if hasattr(self, "_iterator"): - del self._iterator - self.batches_processed = 0 - - -class TextDataLoaderBase(DataLoaderBase): - def __init__( - self, - *, - work_dir: Path | str, - global_batch_size: int, - dp_world_size: int = 1, - dp_rank: int = 0, - fs_local_rank: int = 0, - ): - super().__init__( - work_dir=work_dir, - global_batch_size=global_batch_size, - dp_world_size=dp_world_size, - dp_rank=dp_rank, - fs_local_rank=fs_local_rank, - ) - self.tokens_processed: 
int = 0 - - def reset(self): - super().reset() - self.tokens_processed = 0 - - def global_num_tokens_in_batch(self, batch: dict[str, Any]) -> int | None: - del batch - return self.global_batch_size - - -class StreamingDataLoader(TextDataLoaderBase): - def __init__( - self, - *, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, - tokenizer: PreTrainedTokenizer, - config: StreamingDataLoaderConfig, - generation_config: Any, - work_dir: Path | str, - global_batch_size: int, - num_training_steps: int = 0, - seed: int, - per_device_train_batch_size: int, - verbose: bool, - max_possible_score: float, - actor_manager=None, - model_dims: utils.ModelDims = None, - dp_world_size: int = 1, - dp_rank: int = 0, - fs_local_rank: int = 0, - ): - super().__init__( - work_dir=work_dir, - global_batch_size=global_batch_size, - dp_world_size=dp_world_size, - dp_rank=dp_rank, - fs_local_rank=fs_local_rank, - ) - - self.dataset = dataset - self.inference_results_Q = inference_results_Q - self.param_prompt_Q = param_prompt_Q - self.tokenizer = tokenizer - self.config = config - self.config.max_possible_score = max_possible_score - self.generation_config = generation_config - self.num_training_steps = num_training_steps - self.actor_manager = actor_manager - self.model_dims = model_dims - - self.per_device_train_batch_size = per_device_train_batch_size - self.verbose = verbose - - self.training_step = 0 - self.current_epoch = 0 - self.seed = seed - - self.iter_dataloader = data_loader_lib.HFDataLoader( - dataset=dataset, - batch_size=1, - seed=seed, - rank=dp_rank, - world_size=dp_world_size, - work_dir=work_dir, - automatic_reshuffle=True, - ) - - self.local_queue = StdQueue(maxsize=config.async_steps) - self.background_thread = None - self.shutdown_requested = False - - @property - def total_batches(self) -> int | None: - return self.num_training_steps - - def state_dict(self) -> dict[str, Any]: - return { - "training_step": self.training_step, - "current_epoch": self.current_epoch, - "iter_dataloader_state": self.iter_dataloader.state_dict(), - } - - def load_state_dict(self, state_dict: dict[str, Any]): - self.training_step = state_dict["training_step"] - self.current_epoch = state_dict.get("current_epoch", 0) - self.iter_dataloader.load_state_dict(state_dict["iter_dataloader_state"]) - - def reshuffle(self, epoch: int | None = None, **kwargs): - if epoch is not None: - self.current_epoch = epoch - - def get_mock_batch(self) -> dict[str, Any]: - dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) - dummy_tool_mask = torch.zeros_like(dummy_qr) - dummy_attention = torch.tensor([1, 1], dtype=torch.long) - dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) - dummy_response_mask = torch.zeros_like(dummy_qr) - dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) - - return { - "collated_query_responses": [dummy_qr], - "collated_tool_masks": [dummy_tool_mask], - "collated_attention_masks": [dummy_attention], - "collated_position_ids": [dummy_position_ids], - "collated_advantages": [dummy_advantage], - "collated_response_masks": [dummy_response_mask], - "collated_vllm_logprobs": [torch.zeros_like(dummy_qr, dtype=torch.float)], - } - - def _iter_batches(self) -> Iterable[dict[str, Any]]: - if self.background_thread is None: - self._start_background_thread() - - while self.training_step < self.num_training_steps: - batch_data = self.local_queue.get() - self.training_step += 1 - yield batch_data - - 
def _start_background_thread(self): - self.shutdown_requested = False - self.background_thread = threading.Thread( - target=self._data_preparation_loop, daemon=True, name=f"DataLoader-Worker-Rank{self.dp_rank}" - ) - self.background_thread.start() - - def _data_preparation_loop(self): - for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): - example = next(self.iter_dataloader) - dataset_index = example["dataset_index"] - add_prompt_to_generator( - example, - dataset_index, - self.iter_dataloader._epoch, - self.training_step, - self.param_prompt_Q, - self.generation_config, - is_eval=False, - ) - - for training_step in range(self.training_step, self.num_training_steps): - if self.shutdown_requested: - logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") - return - - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - self.inference_results_Q, - self.generation_config, - num_prompts=self.rank_batch_size, - model_dims=self.model_dims, - tokenizer=self.tokenizer, - dataset=self.dataset, - actor_manager=self.actor_manager, - active_sampling=self.config.active_sampling, - filter_zero_std_samples=self.config.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=self.config.no_resampling_pass_rate, - iter_dataloader=self.iter_dataloader, - param_prompt_Q=self.param_prompt_Q, - training_step=training_step, - verbose=self.verbose, - max_possible_score=self.config.max_possible_score, - ) - if isinstance(result, ShutdownSentinel): - logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") - return - - getting_response_time = timer.duration - scores = np.array(batch.scores) - - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) - if self.config.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif self.config.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") - - if self.config.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 0: - logger.info( - f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. 
" - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - with Timer("📦 [Data Preparation Thread] Packing sequences"): - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=self.config.pack_length, - pad_token_id=self.tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - ) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - collated_data = self._prepare_collated_data_for_self(packed_sequences) - - if len(result.responses) == 0: - metrics = {} - logger.warning(f"No responses in batch {training_step}.") - else: - real_num_responses = len(result.responses) - expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size - - unsolved_num_responses = (scores < self.config.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) - if np.all(scores == 0) - else np.array(sequence_lengths[scores == self.config.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) - if np.all(scores == self.config.max_possible_score) - else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) - - batch_metrics = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} - - metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": batch_stats.percent_solved_hist, - "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - "val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array( - [len(item) > 0 for item in result.request_info.tool_errors] - ).mean(), - "val/good_outputs_rate": np.array(good_outputs).mean(), - "val/tool_runtimes_rate": 
np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - "time/getting_response": getting_response_time, - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - collated_data["metrics"] = metrics - self.local_queue.put(collated_data) - - def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: - per_device_packed_query_responses = packed_sequences.query_responses - per_device_packed_tool_masks = getattr(packed_sequences, "tool_masks", None) - per_device_packed_attention_masks = packed_sequences.attention_masks - per_device_packed_position_ids = packed_sequences.position_ids - per_device_packed_advantages = packed_sequences.advantages - per_device_packed_response_masks = packed_sequences.response_masks - per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs - - b_inds = np.random.permutation(len(per_device_packed_query_responses)) - collated_query_responses = [] - collated_tool_masks = [] if per_device_packed_tool_masks is not None else None - collated_attention_masks = [] - collated_position_ids = [] - collated_response_masks = [] - collated_advantages = [] - collated_vllm_logprobs = [] - for j in range(0, len(per_device_packed_query_responses), self.per_device_train_batch_size): - micro_range = b_inds[j : j + self.per_device_train_batch_size] - collated_query_responses.append( - collate_fn( - [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True - ) - ) - if per_device_packed_tool_masks is not None: - collated_tool_masks.append( - collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True) - ) - collated_attention_masks.append( - collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) - ) - collated_position_ids.append( - collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, True) - ) - collated_response_masks.append( - collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, True) - ) - collated_advantages.append(collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, True)) - collated_vllm_logprobs.append( - collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) - ) - - result = { - "collated_query_responses": collated_query_responses, - "collated_attention_masks": collated_attention_masks, - "collated_position_ids": collated_position_ids, - "collated_advantages": collated_advantages, - "collated_response_masks": collated_response_masks, - "collated_vllm_logprobs": collated_vllm_logprobs, - } - if collated_tool_masks is not None: - result["collated_tool_masks"] = collated_tool_masks - return result - - def shutdown(self): - self.shutdown_requested = True - if self.background_thread is not None: - self.background_thread.join(timeout=5.0) - - -def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: - padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) - if pin_memory: - padded_tensor = padded_tensor.pin_memory() - return padded_tensor - - -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - 
filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - percent_solved_hist: np.ndarray - no_resampled_prompts: int - total_prompts: int - - -class PendingQueriesMap: - def __init__(self): - self._map = {} - self._lock = threading.Lock() - - def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): - with self._lock: - if dataset_idx in self._map: - existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ - dataset_idx - ] - self._map[dataset_idx] = ( - existing_query, - existing_ground_truth, - existing_dataset, - existing_raw_query, - count + 1, - ) - else: - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) - - def pop(self, dataset_idx): - with self._lock: - if dataset_idx not in self._map: - raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") - - query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] - - if count > 1: - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) - else: - del self._map[dataset_idx] - - return query, ground_truth, dataset, raw_query - - def __len__(self): - with self._lock: - return len(self._map) - - def __contains__(self, dataset_idx): - with self._lock: - return dataset_idx in self._map - - def __getitem__(self, dataset_idx): - with self._lock: - return self._map[dataset_idx] - - def keys(self): - with self._lock: - return list(self._map.keys()) - - -def add_prompt_to_generator( - example: dict[str, Any], - example_index: int, - epoch_number: int, - training_step: int, - param_prompt_Q: ray_queue.Queue, - generation_config, - is_eval: bool, -) -> None: - query = example[INPUT_IDS_PROMPT_KEY] - - param_prompt_Q.put( - PromptRequest( - prompt=query, - generation_config=generation_config, - epoch_number=epoch_number, - training_step=training_step, - dataset_index=example_index, - is_eval=is_eval, - ) - ) - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - generation_config: vllm.SamplingParams, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - dataset: Dataset, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, - iter_dataloader: data_loader_lib.HFDataLoader | None = None, - param_prompt_Q: ray_queue.Queue | None = None, - training_step: int = None, - verbose: bool = False, - max_possible_score: float = 1.0, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - import ray - - if no_resampling_pass_rate is not None: - assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" - - if replenish_prompts: - assert param_prompt_Q is not None and iter_dataloader is not None and dataset is not None, ( - "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset" - ) - - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] - all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not verbose, - ) - 
num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. " - f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" - ) - - example = dataset[result.dataset_index] - query = example[INPUT_IDS_PROMPT_KEY] - ground_truth = example[GROUND_TRUTHS_KEY] - dataset_name = example[VERIFIER_SOURCE_KEY] - raw_query = example[RAW_PROMPT_KEY] - - if replenish_prompts: - example = next(iter_dataloader) - dataset_index = example["dataset_index"] - add_prompt_to_generator( - example, - dataset_index, - iter_dataloader._epoch, - training_step, - param_prompt_Q, - generation_config, - is_eval=False, - ) - - for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - k_queries = repeat_each([query], generation_config.n) - k_ground_truths = repeat_each([ground_truth], generation_config.n) - k_datasets = repeat_each([dataset_name], generation_config.n) - k_raw_queries = repeat_each([raw_query], generation_config.n) - - percent_solved = np.mean(result.reward_scores).item() / max_possible_score - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - iter_dataloader.exclude_index(result.dataset_index) - total_no_resampled += 1 - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" - ) - - if filter_zero_std_samples and np.std(result.reward_scores) == 0: - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if result.reward_scores[0] == 0: - filtered_prompt_zero += 1 - elif result.reward_scores[0] == max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - all_queries.extend(k_queries) - all_ground_truths.extend(k_ground_truths) - all_datasets.extend(k_datasets) - all_raw_queries.extend(k_raw_queries) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(result.reward_scores) - all_reward_metrics.append(result.reward_metrics) - all_percent_solved.append(percent_solved) - - if len(results) == 0: - logging.warning( - "[Data Preparation Thread] All prompts were filtered during accumulation. 
" - f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " - f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" - ) - return None, None, None, None - - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - total_generation_time = max_generation_time - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=total_generation_time, - earliest_start_time=earliest_start_time, - ) - - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - epoch_number=results[0].epoch_number, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - filtered_prompts=total_filtered_prompts, - filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - percent_solved_hist=np.array(all_percent_solved), - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - return combined_result, batch, combined_reward_metrics, batch_stats diff --git 
a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 0df038467..8e0d44fa2 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -278,9 +278,7 @@ def setup_and_add_prompts_to_generator(self, queries, ground_truths, datasets, r ) for example in data_loader: - grpo_fast.add_prompt_to_generator( - example, example["dataset_index"], 0, 0, param_prompt_Q, mock_generation_config, False - ) + grpo_fast.add_prompt_to_generator(example, 0, 0, param_prompt_Q, mock_generation_config, False) return param_prompt_Q, inference_results_Q, mock_dataset @@ -711,9 +709,7 @@ def test_more_engines_than_queries(self): ) for example in data_loader: - grpo_fast.add_prompt_to_generator( - example, example["dataset_index"], 0, 0, param_prompt_Q, mock_generation_config, False - ) + grpo_fast.add_prompt_to_generator(example, 0, 0, param_prompt_Q, mock_generation_config, False) self.assertEqual( param_prompt_Q.qsize(), num_queries, f"Should have {num_queries} batches for {num_queries} queries" @@ -746,9 +742,7 @@ def test_uneven_distribution_no_empty_batches(self): ) for example in data_loader: - grpo_fast.add_prompt_to_generator( - example, example["dataset_index"], 0, 0, param_prompt_Q, mock_generation_config, False - ) + grpo_fast.add_prompt_to_generator(example, 0, 0, param_prompt_Q, mock_generation_config, False) request_count = 0 while not param_prompt_Q.empty(): From eaa7e7737a983d33efcb6963dd747460249edc7d Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 14:10:47 -0700 Subject: [PATCH 60/96] Cleaned up PR. --- open_instruct/grpo_fast.py | 45 +++++---------------------------- open_instruct/test_grpo_fast.py | 15 ----------- 2 files changed, 7 insertions(+), 53 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4edf18340..cf5fed970 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -79,7 +79,6 @@ from open_instruct.actor_manager import ActorManager from open_instruct.dataset_transformation import ( INPUT_IDS_PROMPT_KEY, - VERIFIER_SOURCE_KEY, TokenizerConfig, get_cached_dataset_tulu, visualize_token, @@ -300,8 +299,6 @@ class Args: on the first node and 4 learner processes on the second node; each process will have 1 GPU)""" vllm_num_engines: int = 1 """number of vLLM Engines, set to 0 to disable vLLM""" - inference_batch_size: int | None = None - """inference batch size per vLLM engine. 
If None, calculated as ceil(num_unique_prompts_rollout / vllm_num_engines) * num_samples_per_prompt_rollout""" vllm_tensor_parallel_size: int = 1 """tensor parallel size of vLLM Engine for multi-GPU inference""" vllm_enforce_eager: bool = False @@ -486,18 +483,6 @@ def __post_init__(self): self.max_possible_score += self.r1_style_format_reward -def get_num_verifiers(dataset: Dataset) -> int: - if VERIFIER_SOURCE_KEY not in dataset.column_names: - return 0 - verifiers = set() - for item in dataset[VERIFIER_SOURCE_KEY]: - if isinstance(item, list): - verifiers.update(item) - else: - verifiers.add(item) - return len(verifiers) - - @Timer("🔄 [Data Preparation Thread] Prepare collated data for each worker") def prepare_collated_data_for_workers( packed_sequences: PackedSequences, @@ -631,7 +616,6 @@ def from_pretrained( beaker_config: BeakerRuntimeConfig, wandb_url: str, tokenizer: PreTrainedTokenizer, - num_verifiers: int, ) -> int: # ------------------------------------------------------------ # Monkey patch to load checkpoints with `weights_only=False` @@ -791,11 +775,7 @@ def load(self, path: str, map_location=None): if hasattr(self, "ref_policy_checkpoint_path") else None, ) - # 49 base metrics: 16 from step() (KL, loss, ratio, etc.), 22 from data_loader_lib - # (scores, sequence_lengths, etc.), 7 from BatchStatistics, 4 from reward_metrics. - # Each verifier adds 2 metrics: objective/{key}_reward and objective/{key}_correct_rate. - max_metrics = 49 + 2 * num_verifiers - self.local_metrics = utils.MetricsTracker(max_metrics=max_metrics, device=self.device) + self.local_metrics = utils.MetricsTracker(max_metrics=64, device=self.device) return optimization_steps_done def forward( @@ -1586,9 +1566,6 @@ def setup_runtime_variables(args: Args, streaming_config: data_loader_lib.Stream args.num_training_steps = args.total_episodes // ( args.num_unique_prompts_rollout * streaming_config.num_samples_per_prompt_rollout ) - if args.inference_batch_size is None: - total_prompts = streaming_config.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout - args.inference_batch_size = max(1, math.ceil(total_prompts / args.vllm_num_engines)) args.try_launch_beaker_eval_jobs_on_weka = args.try_launch_beaker_eval_jobs_on_weka and is_beaker_job() if args.push_to_hub: if args.hf_repo_id is None: # auto-generate one @@ -1828,9 +1805,8 @@ def create_model_and_optimizer( logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") logger.info("[DEBUG] Starting model initialization across all ranks...") - num_verifiers = get_num_verifiers(train_dataset) inits = [ - model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer, num_verifiers) + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) for model in policy_group.models ] @@ -1962,8 +1938,7 @@ def one_training_step( [policy_group.models[i].step.remote() for i in range(args.world_size)], desc=f"Running training step {training_step}", ) - metrics_list = [r[0] for r in results] - array_metrics_list = [r[1] for r in results] + metrics, array_metrics = zip(*results) if ( args.load_ref_policy and args.ref_policy_update_freq is not None @@ -1998,21 +1973,15 @@ def one_training_step( ray.get(actor_manager.report_training_step_time.remote(train_timer.duration)) - all_keys = set() - for m in metrics_list: - all_keys.update(m.keys()) - average_metrics = {} - for k in all_keys: - values = [m[k] for m in metrics_list if k in m] - average_metrics[k] = sum(values) / len(values) - 
for key, value in array_metrics_list[0].items(): + average_metrics = {k: np.mean([m[k] for m in metrics if k in m]) for k in set().union(*metrics)} + for key, value in array_metrics[0].items(): average_metrics[key] = value step_time = time.perf_counter() - start_time total_training_time = time.perf_counter() - training_start_time total_generation_time = average_metrics["time/getting_response"] - prompt_lengths = array_metrics_list[0]["batch/prompt_lengths"] - response_lengths = array_metrics_list[0]["batch/response_lengths"] + prompt_lengths = array_metrics[0]["batch/prompt_lengths"] + response_lengths = array_metrics[0]["batch/response_lengths"] num_step_tokens = sum(prompt_lengths) + sum(response_lengths) utilization_metrics = calculate_utilization_metrics( diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 8e0d44fa2..4d7c232c8 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -23,25 +23,10 @@ RAW_PROMPT_KEY, VERIFIER_SOURCE_KEY, ) -from open_instruct.grpo_fast import get_num_verifiers from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics from open_instruct.vllm_utils import create_vllm_engines -class TestGetNumVerifiers(unittest.TestCase): - def test_single_verifier_per_example(self): - dataset = Dataset.from_dict({VERIFIER_SOURCE_KEY: ["gsm8k", "math", "gsm8k"]}) - self.assertEqual(get_num_verifiers(dataset), 2) - - def test_multiple_verifiers_per_example(self): - dataset = Dataset.from_dict({VERIFIER_SOURCE_KEY: [["gsm8k", "math"], ["code"]]}) - self.assertEqual(get_num_verifiers(dataset), 3) - - def test_empty_dataset(self): - dataset = Dataset.from_dict({VERIFIER_SOURCE_KEY: []}) - self.assertEqual(get_num_verifiers(dataset), 0) - - class TestGrpoFastBase(unittest.TestCase): """Base class with common test utilities.""" From 95d9125699fbd62cf688bd1e9a168b04c40e7603 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 14:25:51 -0700 Subject: [PATCH 61/96] Cleaned up PR. 
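(Aside: the data_loader.py changes in this patch replace the hand-rolled background `threading.Thread` with a single-worker `ThreadPoolExecutor`; keeping the returned future around lets the consumer re-raise any exception from the background loop instead of hanging silently. A minimal sketch of that pattern, with illustrative names only:

from concurrent.futures import ThreadPoolExecutor, wait

class BackgroundPreparer:
    def __init__(self):
        self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="DataPrep")
        self._future = self._executor.submit(self._loop)

    def _loop(self):
        # Stand-in for the real data-preparation loop.
        raise RuntimeError("boom")

    def health_check(self):
        # If the background loop died, re-raise its exception on the caller's thread.
        if self._future.done():
            self._future.result()

    def shutdown(self):
        self._executor.shutdown(wait=True)

preparer = BackgroundPreparer()
wait([preparer._future])  # for the demo only: block until the loop has finished or failed
try:
    preparer.health_check()
except RuntimeError as err:
    print("background failure surfaced:", err)
finally:
    preparer.shutdown()
)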
--- open_instruct/data_loader.py | 38 +++++++++++------------------------- 1 file changed, 11 insertions(+), 27 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 1ee65ded8..c9a036eb4 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -15,6 +15,7 @@ import logging import threading from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass from pathlib import Path from queue import Queue as StdQueue @@ -308,8 +309,9 @@ def __init__( ) self.local_queue = StdQueue(maxsize=config.async_steps) - self.background_thread = None self.shutdown_requested = False + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"DataLoader-Worker-Rank{dp_rank}") + self._data_prep_future = self._executor.submit(self._data_preparation_loop) @property def total_batches(self) -> int | None: @@ -333,7 +335,6 @@ def reshuffle(self, epoch: int | None = None, **kwargs): def get_mock_batch(self) -> dict[str, Any]: dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) - dummy_tool_mask = torch.zeros_like(dummy_qr) dummy_attention = torch.tensor([1, 1], dtype=torch.long) dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) dummy_response_mask = torch.zeros_like(dummy_qr) @@ -341,7 +342,6 @@ def get_mock_batch(self) -> dict[str, Any]: return { "collated_query_responses": [dummy_qr], - "collated_tool_masks": [dummy_tool_mask], "collated_attention_masks": [dummy_attention], "collated_position_ids": [dummy_position_ids], "collated_advantages": [dummy_advantage], @@ -350,26 +350,20 @@ def get_mock_batch(self) -> dict[str, Any]: } def _iter_batches(self) -> Iterable[dict[str, Any]]: - if self.background_thread is None: - self._start_background_thread() - - while self.training_step < self.num_training_steps: + for _ in range(self.training_step, self.num_training_steps): + self._health_check() batch_data = self.local_queue.get() self.training_step += 1 yield batch_data - def _start_background_thread(self): - self.shutdown_requested = False - self.background_thread = threading.Thread( - target=self._data_preparation_loop, daemon=True, name=f"DataLoader-Worker-Rank{self.dp_rank}" - ) - self.background_thread.start() + def _health_check(self): + if self._data_prep_future.done(): + self._data_prep_future.result() def _data_preparation_loop(self): for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): - example = next(self.iter_dataloader) add_prompt_to_generator( - example, + next(self.iter_dataloader), self.iter_dataloader._epoch, self.training_step, self.param_prompt_Q, @@ -534,7 +528,6 @@ def _data_preparation_loop(self): def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: per_device_packed_query_responses = packed_sequences.query_responses - per_device_packed_tool_masks = getattr(packed_sequences, "tool_masks", None) per_device_packed_attention_masks = packed_sequences.attention_masks per_device_packed_position_ids = packed_sequences.position_ids per_device_packed_advantages = packed_sequences.advantages @@ -543,7 +536,6 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> b_inds = np.random.permutation(len(per_device_packed_query_responses)) collated_query_responses = [] - collated_tool_masks = [] if per_device_packed_tool_masks is not None else None collated_attention_masks = [] 
collated_position_ids = [] collated_response_masks = [] @@ -556,10 +548,6 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True ) ) - if per_device_packed_tool_masks is not None: - collated_tool_masks.append( - collate_fn([per_device_packed_tool_masks[idx] for idx in micro_range], 0, True) - ) collated_attention_masks.append( collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) ) @@ -574,7 +562,7 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) ) - result = { + return { "collated_query_responses": collated_query_responses, "collated_attention_masks": collated_attention_masks, "collated_position_ids": collated_position_ids, @@ -582,14 +570,10 @@ def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> "collated_response_masks": collated_response_masks, "collated_vllm_logprobs": collated_vllm_logprobs, } - if collated_tool_masks is not None: - result["collated_tool_masks"] = collated_tool_masks - return result def shutdown(self): self.shutdown_requested = True - if self.background_thread is not None: - self.background_thread.join(timeout=5.0) + self._executor.shutdown(wait=True) def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: From 3ed71471976f5090087209130087c1bc2e484cdd Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 14:46:42 -0700 Subject: [PATCH 62/96] Now, tests pass. --- open_instruct/data_loader.py | 16 ++++++++++++++++ open_instruct/grpo_fast.py | 7 +++++-- open_instruct/test_vllm_utils.py | 8 ++++---- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index c9a036eb4..d697f8587 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -398,6 +398,22 @@ def _data_preparation_loop(self): if isinstance(result, ShutdownSentinel): logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") return + if result is None: + logger.info( + f"[DataLoader Worker {self.dp_rank}] All prompts filtered, " + "yielding empty batch" + ) + empty_batch = { + "collated_query_responses": [], + "collated_attention_masks": [], + "collated_position_ids": [], + "collated_advantages": [], + "collated_response_masks": [], + "collated_vllm_logprobs": [], + "metrics": {}, + } + self.local_queue.put(empty_batch) + continue getting_response_time = timer.duration scores = np.array(batch.scores) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index cf5fed970..544b27d44 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -589,7 +589,7 @@ def __init__( self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id self.num_mini_batches = args.num_mini_batches - self.dataloader = data_loader_config.build( + self.dataloader = iter(data_loader_config.build( dataset=dataset, inference_results_Q=inference_results_Q, param_prompt_Q=param_prompt_Q, @@ -607,7 +607,7 @@ def __init__( max_possible_score=args.max_possible_score, actor_manager=actor_manager, model_dims=model_dims, - ) + )) def from_pretrained( self, @@ -996,6 +996,9 @@ def step(self): collated_advantages = batch_data["collated_advantages"] collated_response_masks = batch_data["collated_response_masks"] 
collated_vllm_logprobs = batch_data["collated_vllm_logprobs"] + if len(collated_query_responses) == 0: + logger.warning("[Training] Empty batch received, skipping training step") + return [], {} args = self.args to_device_inplace(collated_query_responses, self.device) to_device_inplace(collated_attention_masks, self.device) diff --git a/open_instruct/test_vllm_utils.py b/open_instruct/test_vllm_utils.py index b9418e996..9ff16e4b6 100644 --- a/open_instruct/test_vllm_utils.py +++ b/open_instruct/test_vllm_utils.py @@ -25,7 +25,7 @@ def create_mock_logprobs(token_ids): return [{tid: MagicMock(logprob=-0.1 * tid)} for tid in token_ids] mock_request = PromptRequest( - prompt=[1, 2, 3], generation_config=None, is_eval=False, dataset_index=43039, prompt_id="test_prompt_1" + prompt=[1, 2, 3], generation_config=None, is_eval=False, dataset_index=43039 ) request_id = make_request_id(mock_request) @@ -64,7 +64,7 @@ def create_mock_logprobs(token_ids): request_id: { "is_eval": False, "dataset_index": 43039, - "prompt_id": "test_prompt_1", + "prompt_id": "0_43039", "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], "start_time": 1000.0, } @@ -108,7 +108,7 @@ def create_mock_logprobs(token_ids): return [{tid: MagicMock(logprob=-0.1 * tid)} for tid in token_ids] mock_request = PromptRequest( - prompt=[1, 2, 3], generation_config=None, is_eval=True, dataset_index=200, prompt_id="test_prompt_2" + prompt=[1, 2, 3], generation_config=None, is_eval=True, dataset_index=200 ) request_id = make_request_id(mock_request) @@ -133,7 +133,7 @@ def create_mock_logprobs(token_ids): request_id: { "is_eval": True, "dataset_index": 200, - "prompt_id": "test_prompt_2", + "prompt_id": "0_200", "prompt_token_ids": [1, 2, 3, 4, 5], "start_time": 2000.0, } From eb1081774ce93769a503c40258d0476913560a8d Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 16:32:36 -0700 Subject: [PATCH 63/96] Uses a singleton --- open_instruct/data_loader.py | 663 +++++++++++++++++-------------- open_instruct/grpo_fast.py | 105 +++-- open_instruct/test_vllm_utils.py | 8 +- 3 files changed, 404 insertions(+), 372 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index d697f8587..e90ef6e60 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -14,14 +14,15 @@ import logging import threading +import time from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor from dataclasses import asdict, dataclass from pathlib import Path -from queue import Queue as StdQueue from typing import Any import numpy as np +import ray import torch import vllm from datasets import Dataset @@ -39,7 +40,7 @@ ) from open_instruct.model_utils import Batch from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics -from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences +from open_instruct.rl_utils import PackedSequences, pack_sequences from open_instruct.utils import combine_reward_metrics, repeat_each logger = logging.getLogger(__name__) @@ -207,42 +208,24 @@ def __post_init__(self): if self.async_steps < 1: raise ValueError("`async_steps` must be greater than 0. 
Fully synchronous training is not supported.") - def build( + def build_dataloader( self, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, + data_prep_actor_name: str, tokenizer: PreTrainedTokenizer, - generation_config: Any, dp_rank: int, fs_local_rank: int, num_training_steps: int, - seed: int, - per_device_train_batch_size: int, - verbose: bool, work_dir: Path | str, global_batch_size: int, dp_world_size: int, - max_possible_score: float, - actor_manager=None, - model_dims: utils.ModelDims | None = None, ) -> "StreamingDataLoader": + """Build a thin wrapper dataloader that pulls from the DataPreparationActor singleton.""" return StreamingDataLoader( - dataset=dataset, - inference_results_Q=inference_results_Q, - param_prompt_Q=param_prompt_Q, + data_prep_actor_name=data_prep_actor_name, tokenizer=tokenizer, - config=self, - generation_config=generation_config, work_dir=work_dir, global_batch_size=global_batch_size, num_training_steps=num_training_steps, - seed=seed, - per_device_train_batch_size=per_device_train_batch_size, - verbose=verbose, - max_possible_score=max_possible_score, - actor_manager=actor_manager, - model_dims=model_dims, dp_world_size=dp_world_size, dp_rank=dp_rank, fs_local_rank=fs_local_rank, @@ -250,24 +233,16 @@ def build( class StreamingDataLoader(data_loader.DataLoaderBase): + """Thin wrapper dataloader that pulls pre-prepared data from the DataPreparationActor singleton.""" + def __init__( self, *, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, + data_prep_actor_name: str, tokenizer: PreTrainedTokenizer, - config: StreamingDataLoaderConfig, - generation_config: Any, work_dir: Path | str, global_batch_size: int, num_training_steps: int = 0, - seed: int, - per_device_train_batch_size: int, - verbose: bool, - max_possible_score: float, - actor_manager=None, - model_dims: utils.ModelDims = None, dp_world_size: int = 1, dp_rank: int = 0, fs_local_rank: int = 0, @@ -280,54 +255,22 @@ def __init__( fs_local_rank=fs_local_rank, ) - self.dataset = dataset - self.inference_results_Q = inference_results_Q - self.param_prompt_Q = param_prompt_Q + self.data_prep_actor = ray.get_actor(data_prep_actor_name) self.tokenizer = tokenizer - self.config = config - self.config.max_possible_score = max_possible_score - self.generation_config = generation_config self.num_training_steps = num_training_steps - self.actor_manager = actor_manager - self.model_dims = model_dims - - self.per_device_train_batch_size = per_device_train_batch_size - self.verbose = verbose - self.training_step = 0 self.current_epoch = 0 - self.seed = seed - - self.iter_dataloader = HFDataLoader( - dataset=dataset, - batch_size=1, - seed=seed, - rank=dp_rank, - world_size=dp_world_size, - work_dir=work_dir, - automatic_reshuffle=True, - ) - - self.local_queue = StdQueue(maxsize=config.async_steps) - self.shutdown_requested = False - self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix=f"DataLoader-Worker-Rank{dp_rank}") - self._data_prep_future = self._executor.submit(self._data_preparation_loop) @property def total_batches(self) -> int | None: return self.num_training_steps def state_dict(self) -> dict[str, Any]: - return { - "training_step": self.training_step, - "current_epoch": self.current_epoch, - "iter_dataloader_state": self.iter_dataloader.state_dict(), - } + return {"training_step": self.training_step, "current_epoch": self.current_epoch} def load_state_dict(self, state_dict: dict[str, Any]): 
self.training_step = state_dict["training_step"] self.current_epoch = state_dict.get("current_epoch", 0) - self.iter_dataloader.load_state_dict(state_dict["iter_dataloader_state"]) def reshuffle(self, epoch: int | None = None, **kwargs): if epoch is not None: @@ -350,246 +293,13 @@ def get_mock_batch(self) -> dict[str, Any]: } def _iter_batches(self) -> Iterable[dict[str, Any]]: - for _ in range(self.training_step, self.num_training_steps): - self._health_check() - batch_data = self.local_queue.get() - self.training_step += 1 + for step in range(self.training_step, self.num_training_steps): + batch_data = ray.get(self.data_prep_actor.get_data.remote(rank=self.dp_rank, step=step)) + self.training_step = step + 1 yield batch_data - def _health_check(self): - if self._data_prep_future.done(): - self._data_prep_future.result() - - def _data_preparation_loop(self): - for _ in range(self.config.async_steps * self.global_batch_size // self.dp_world_size): - add_prompt_to_generator( - next(self.iter_dataloader), - self.iter_dataloader._epoch, - self.training_step, - self.param_prompt_Q, - self.generation_config, - is_eval=False, - ) - - for training_step in range(self.training_step, self.num_training_steps): - if self.shutdown_requested: - logger.info(f"[DataLoader Worker {self.dp_rank}] Shutdown requested, exiting") - return - - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - self.inference_results_Q, - self.generation_config, - num_prompts=self.rank_batch_size, - model_dims=self.model_dims, - tokenizer=self.tokenizer, - dataset=self.dataset, - actor_manager=self.actor_manager, - active_sampling=self.config.active_sampling, - filter_zero_std_samples=self.config.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=self.config.no_resampling_pass_rate, - iter_dataloader=self.iter_dataloader, - param_prompt_Q=self.param_prompt_Q, - training_step=training_step, - verbose=self.verbose, - max_possible_score=self.config.max_possible_score, - ) - if isinstance(result, ShutdownSentinel): - logger.info(f"[DataLoader Worker {self.dp_rank}] Received shutdown sentinel, exiting") - return - if result is None: - logger.info( - f"[DataLoader Worker {self.dp_rank}] All prompts filtered, " - "yielding empty batch" - ) - empty_batch = { - "collated_query_responses": [], - "collated_attention_masks": [], - "collated_position_ids": [], - "collated_advantages": [], - "collated_response_masks": [], - "collated_vllm_logprobs": [], - "metrics": {}, - } - self.local_queue.put(empty_batch) - continue - - getting_response_time = timer.duration - scores = np.array(batch.scores) - - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) - if self.config.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif 
self.config.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") - - if self.config.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 0: - logger.info( - f"[Truncated completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - with Timer("📦 [Data Preparation Thread] Packing sequences"): - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=self.config.pack_length, - pad_token_id=self.tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - ) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - collated_data = self._prepare_collated_data_for_self(packed_sequences) - - if len(result.responses) == 0: - metrics = {} - logger.warning(f"No responses in batch {training_step}.") - else: - real_num_responses = len(result.responses) - expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size - - unsolved_num_responses = (scores < self.config.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) - if np.all(scores == 0) - else np.array(sequence_lengths[scores == self.config.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) - if np.all(scores == self.config.max_possible_score) - else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) - - batch_metrics = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} - - metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": batch_stats.percent_solved_hist, - "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - 
"val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array( - [len(item) > 0 for item in result.request_info.tool_errors] - ).mean(), - "val/good_outputs_rate": np.array(good_outputs).mean(), - "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - "time/getting_response": getting_response_time, - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - collated_data["metrics"] = metrics - self.local_queue.put(collated_data) - - def _prepare_collated_data_for_self(self, packed_sequences: PackedSequences) -> dict[str, list[torch.Tensor]]: - per_device_packed_query_responses = packed_sequences.query_responses - per_device_packed_attention_masks = packed_sequences.attention_masks - per_device_packed_position_ids = packed_sequences.position_ids - per_device_packed_advantages = packed_sequences.advantages - per_device_packed_response_masks = packed_sequences.response_masks - per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs - - b_inds = np.random.permutation(len(per_device_packed_query_responses)) - collated_query_responses = [] - collated_attention_masks = [] - collated_position_ids = [] - collated_response_masks = [] - collated_advantages = [] - collated_vllm_logprobs = [] - for j in range(0, len(per_device_packed_query_responses), self.per_device_train_batch_size): - micro_range = b_inds[j : j + self.per_device_train_batch_size] - collated_query_responses.append( - collate_fn( - [per_device_packed_query_responses[idx] for idx in micro_range], self.tokenizer.pad_token_id, True - ) - ) - collated_attention_masks.append( - collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, True) - ) - collated_position_ids.append( - collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, True) - ) - collated_response_masks.append( - collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, True) - ) - collated_advantages.append(collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, True)) - collated_vllm_logprobs.append( - collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, True) - ) - - return { - "collated_query_responses": collated_query_responses, - "collated_attention_masks": collated_attention_masks, - "collated_position_ids": collated_position_ids, - "collated_advantages": collated_advantages, - "collated_response_masks": collated_response_masks, - "collated_vllm_logprobs": collated_vllm_logprobs, - } - def shutdown(self): - self.shutdown_requested = True - self._executor.shutdown(wait=True) + pass def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: @@ -921,3 +631,344 @@ def accumulate_inference_batches( total_prompts=len(results), ) return combined_result, batch, combined_reward_metrics, batch_stats + + +def prepare_collated_data_for_workers( + 
packed_sequences: PackedSequences, + world_size: int, + per_device_train_batch_size: int, + pad_token_id: int, + pin_memory: bool = True, +) -> list[dict[str, list[torch.Tensor]]]: + """Distributes and collates packed sequences for distributed training. + + Splits packed sequences across workers, randomly shuffles each worker's data, + and collates into micro-batches for training. + + Args: + packed_sequences: Packed training sequences containing query responses, + attention masks, position IDs, advantages, response masks, + and vllm logprobs. + world_size: Number of distributed workers. + per_device_train_batch_size: Batch size for each device's micro-batch. + pad_token_id: Token ID used for padding sequences. + pin_memory: Whether to pin memory for faster data transfer to GPU. + + Returns: + List of dictionaries, one per worker, each containing collated tensors + for query_responses, attention_masks, position_ids, + advantages, response_masks, and vllm_logprobs. + """ + B = len(packed_sequences.query_responses) // world_size + collated_data = [] + for i in range(world_size): + per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)] + per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)] + per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)] + per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)] + per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)] + per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs[B * i : B * (i + 1)] + + b_inds = np.random.permutation(len(per_device_packed_query_responses)) + collated_query_responses = [] + collated_attention_masks = [] + collated_position_ids = [] + collated_response_masks = [] + collated_advantages = [] + collated_vllm_logprobs = [] + for j in range(0, len(per_device_packed_query_responses), per_device_train_batch_size): + micro_range = b_inds[j : j + per_device_train_batch_size] + collated_query_responses.append( + collate_fn([per_device_packed_query_responses[idx] for idx in micro_range], pad_token_id, pin_memory) + ) + collated_attention_masks.append( + collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, pin_memory) + ) + collated_position_ids.append( + collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, pin_memory) + ) + collated_response_masks.append( + collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, pin_memory) + ) + collated_advantages.append( + collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, pin_memory) + ) + collated_vllm_logprobs.append( + collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, pin_memory) + ) + collated_data.append( + { + "collated_query_responses": collated_query_responses, + "collated_attention_masks": collated_attention_masks, + "collated_position_ids": collated_position_ids, + "collated_advantages": collated_advantages, + "collated_response_masks": collated_response_masks, + "collated_vllm_logprobs": collated_vllm_logprobs, + } + ) + return collated_data + + +@ray.remote +class DataPreparationActor: + """Ray actor singleton that handles centralized data preparation for all ranks. + + This actor runs a background thread that continuously prepares training data, + ensuring all ranks receive the same number of micro-batches (preventing deadlock + from uneven filtering). 
+ """ + + def __init__( + self, + dataset: Dataset, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + config: StreamingDataLoaderConfig, + generation_config, + num_training_steps: int, + seed: int, + per_device_train_batch_size: int, + global_batch_size: int, + dp_world_size: int, + max_possible_score: float, + actor_manager, + model_dims: utils.ModelDims, + verbose: bool, + work_dir: str, + ): + self.inference_results_Q = inference_results_Q + self.param_prompt_Q = param_prompt_Q + self.tokenizer = tokenizer + self.config = config + self.config.max_possible_score = max_possible_score + self.generation_config = generation_config + self.num_training_steps = num_training_steps + self.per_device_train_batch_size = per_device_train_batch_size + self.global_batch_size = global_batch_size + self.dp_world_size = dp_world_size + self.actor_manager = actor_manager + self.model_dims = model_dims + self.verbose = verbose + self.dataset = dataset + + self.iter_dataloader = HFDataLoader( + dataset=dataset, batch_size=1, seed=seed, rank=0, world_size=1, work_dir=work_dir, automatic_reshuffle=True + ) + + self.prepared_data: dict[int, list[dict]] = {} + self.metrics: dict[int, dict] = {} + self.current_prepared_step = -1 + self.lock = threading.Lock() + self.shutdown_requested = False + self.training_step = 0 + + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="DataPrepActor") + self._prep_future = self._executor.submit(self._data_preparation_loop) + + def _data_preparation_loop(self): + for _ in range(self.config.async_steps * self.global_batch_size): + add_prompt_to_generator( + next(self.iter_dataloader), + self.iter_dataloader._epoch, + self.training_step, + self.param_prompt_Q, + self.generation_config, + is_eval=False, + ) + + for step in range(self.training_step, self.num_training_steps): + if self.shutdown_requested: + logger.info("[DataPreparationActor] Shutdown requested, exiting") + return + + result, batch, reward_metrics, batch_stats = accumulate_inference_batches( + self.inference_results_Q, + self.generation_config, + num_prompts=self.global_batch_size, + model_dims=self.model_dims, + tokenizer=self.tokenizer, + dataset=self.dataset, + actor_manager=self.actor_manager, + active_sampling=self.config.active_sampling, + filter_zero_std_samples=self.config.filter_zero_std_samples, + replenish_prompts=True, + no_resampling_pass_rate=self.config.no_resampling_pass_rate, + iter_dataloader=self.iter_dataloader, + param_prompt_Q=self.param_prompt_Q, + training_step=step, + verbose=self.verbose, + max_possible_score=self.config.max_possible_score, + ) + + if isinstance(result, ShutdownSentinel): + logger.info("[DataPreparationActor] Received shutdown sentinel, exiting") + return + + if result is None: + logger.info("[DataPreparationActor] All prompts filtered, yielding empty batch") + empty_data = [ + { + "collated_query_responses": [], + "collated_attention_masks": [], + "collated_position_ids": [], + "collated_advantages": [], + "collated_response_masks": [], + "collated_vllm_logprobs": [], + } + for _ in range(self.dp_world_size) + ] + with self.lock: + self.prepared_data[step] = empty_data + self.metrics[step] = {} + self.current_prepared_step = step + continue + + scores = np.array(batch.scores) + scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) + mean_grouped_rewards = scores_per_prompt.mean(axis=-1) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, 
self.config.num_samples_per_prompt_rollout, axis=0) + std_grouped_rewards = scores_per_prompt.std(axis=-1) + std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + + if self.config.advantage_normalization_type == "standard": + advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + elif self.config.advantage_normalization_type == "centered": + advantages = scores - mean_grouped_rewards + else: + raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") + + if self.config.mask_truncated_completions: + stop_idxes = torch.tensor( + [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] + ) + num_truncated = len(result.finish_reasons) - len(stop_idxes) + if num_truncated > 0: + logger.info( + f"[DataPreparationActor] Filtered {num_truncated} responses that didn't finish with 'stop'. " + f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" + ) + scores = scores[stop_idxes] + advantages = advantages[stop_idxes] + batch = batch[stop_idxes.tolist()] + result.responses = [result.responses[i] for i in stop_idxes] + result.masks = [result.masks[i] for i in stop_idxes] + result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] + result.logprobs = [result.logprobs[i] for i in stop_idxes] + + packed_sequences = pack_sequences( + queries=batch.queries, + responses=result.responses, + masks=result.masks, + pack_length=self.config.pack_length, + pad_token_id=self.tokenizer.pad_token_id, + vllm_logprobs=result.logprobs, + ) + lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) + lookup_advantages[1:] = advantages + packed_advantages = [ + torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) + for packed_mask in packed_sequences.response_masks + ] + packed_sequences.advantages = packed_advantages + + collated_data = prepare_collated_data_for_workers( + packed_sequences, self.dp_world_size, self.per_device_train_batch_size, self.tokenizer.pad_token_id + ) + + if len(result.responses) == 0: + step_metrics = {} + else: + real_num_responses = len(result.responses) + expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size + unsolved_num_responses = (scores < self.config.max_possible_score).sum() + sequence_lengths = np.array([len(response) for response in result.responses]) + sequence_length_solved = ( + np.array([]) + if np.all(scores == 0) + else np.array(sequence_lengths[scores == self.config.max_possible_score]) + ) + sequence_length_unsolved = ( + np.array([]) + if np.all(scores == self.config.max_possible_score) + else np.array(sequence_lengths[scores == 0]) + ) + stop_rate = sum(int(fr == "stop") for fr in result.finish_reasons) / len(result.finish_reasons) + + batch_metrics_dict = asdict(batch_stats) + batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics_dict.items()} + + step_metrics = { + "scores": scores.mean(), + "real_batch_size_ratio": real_num_responses / expected_num_responses, + "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, + "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, + "val/solve_rate_hist": batch_stats.percent_solved_hist, + "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, + "val/sequence_lengths": sequence_lengths.mean(), + "val/sequence_lengths_min": sequence_lengths.min(), + "val/sequence_lengths_max": sequence_lengths.max(), + 
"val/sequence_lengths_unsolved": ( + 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() + ), + "val/sequence_lengths_solved": ( + 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() + ), + "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, + "val/sequence_lengths_solved_hist": sequence_length_solved, + "val/stop_rate": stop_rate, + "val/advantages_mean": advantages.mean(), + "val/advantages_min": advantages.min(), + "val/advantages_max": advantages.max(), + "val/advantages_hist": advantages, + "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), + "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), + "val/tool_errors_rate": np.array( + [len(item) > 0 for item in result.request_info.tool_errors] + ).mean(), + "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), + "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), + **reward_metrics, + **batch_metrics_prefixed, + } + + total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens + step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time + + with self.lock: + self.prepared_data[step] = collated_data + self.metrics[step] = step_metrics + self.current_prepared_step = step + + def get_data(self, rank: int, step: int) -> dict: + """Called by each rank's StreamingDataLoader. Blocks until data ready.""" + while True: + with self.lock: + if step <= self.current_prepared_step: + data = self.prepared_data[step][rank].copy() + data["metrics"] = self.metrics[step] + self._cleanup_old_steps(step) + return data + time.sleep(0.01) + + def _cleanup_old_steps(self, current_step: int): + """Remove old step data to prevent memory leak.""" + steps_to_remove = [s for s in self.prepared_data if s < current_step - 1] + for s in steps_to_remove: + del self.prepared_data[s] + if s in self.metrics: + del self.metrics[s] + + def shutdown(self): + self.shutdown_requested = True + self._executor.shutdown(wait=True) + + def get_state(self) -> dict: + return { + "training_step": self.current_prepared_step + 1, + "iter_dataloader_state": self.iter_dataloader.state_dict(), + } + + def set_state(self, state: dict): + self.training_step = state["training_step"] + self.iter_dataloader.load_state_dict(state["iter_dataloader_state"]) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 544b27d44..e66888670 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -38,7 +38,12 @@ from open_instruct import data_loader as data_loader_lib from open_instruct import utils -from open_instruct.data_loader import accumulate_inference_batches, add_prompt_to_generator, collate_fn +from open_instruct.data_loader import ( + DataPreparationActor, + accumulate_inference_batches, + add_prompt_to_generator, + collate_fn, +) # isort: on import asyncio @@ -577,37 +582,25 @@ def __init__( master_port: int | None, args: Args, data_loader_config: data_loader_lib.StreamingDataLoaderConfig, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, + data_prep_actor_name: str, tokenizer: PreTrainedTokenizer, - generation_config, - actor_manager, - model_dims: utils.ModelDims, ): super().__init__(world_size, rank, local_rank, master_addr, master_port) self.tokenizer = tokenizer self.pad_token_id = tokenizer.pad_token_id self.num_mini_batches = args.num_mini_batches - self.dataloader = 
iter(data_loader_config.build( - dataset=dataset, - inference_results_Q=inference_results_Q, - param_prompt_Q=param_prompt_Q, - tokenizer=tokenizer, - generation_config=generation_config, - dp_rank=rank, - fs_local_rank=self.local_rank, - num_training_steps=args.num_training_steps, - seed=args.seed, - per_device_train_batch_size=args.per_device_train_batch_size, - verbose=args.verbose, - work_dir=args.output_dir, - global_batch_size=args.num_unique_prompts_rollout, - dp_world_size=world_size, - max_possible_score=args.max_possible_score, - actor_manager=actor_manager, - model_dims=model_dims, - )) + self.dataloader = iter( + data_loader_config.build_dataloader( + data_prep_actor_name=data_prep_actor_name, + tokenizer=tokenizer, + dp_rank=rank, + fs_local_rank=self.local_rank, + num_training_steps=args.num_training_steps, + work_dir=args.output_dir, + global_batch_size=args.num_unique_prompts_rollout, + dp_world_size=world_size, + ) + ) def from_pretrained( self, @@ -1413,13 +1406,8 @@ def __init__( single_gpu_mode: bool, args: Args, data_loader_config: data_loader_lib.StreamingDataLoaderConfig, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, + data_prep_actor_name: str, tokenizer: PreTrainedTokenizer, - generation_config, - actor_manager, - model_dims: utils.ModelDims, ): self.pg = pg self.ray_process_cls = ray_process_cls @@ -1434,22 +1422,7 @@ def __init__( scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=self.pg, placement_group_bundle_index=0 ), - ).remote( - world_size, - 0, - 0, - None, - None, - args, - data_loader_config, - dataset, - inference_results_Q, - param_prompt_Q, - tokenizer, - generation_config, - actor_manager, - model_dims, - ) + ).remote(world_size, 0, 0, None, None, args, data_loader_config, data_prep_actor_name, tokenizer) self.models.append(master_policy) results, _ = ray_get_with_progress( @@ -1490,13 +1463,8 @@ def get_bundle_index(rank, num_gpus_per_node): master_port, args, data_loader_config, - dataset, - inference_results_Q, - param_prompt_Q, + data_prep_actor_name, tokenizer, - generation_config, - actor_manager, - model_dims, ) self.models.append(worker_policy) @@ -1788,6 +1756,28 @@ def create_model_and_optimizer( logger.info("[DEBUG] KV cache configuration complete") # Now create policy actors with all dependencies + logger.info("[DEBUG] Creating DataPreparationActor singleton...") + data_prep_actor_name = "data_prep_singleton" + _data_prep_actor = DataPreparationActor.options(name=data_prep_actor_name, num_cpus=2).remote( + dataset=train_dataset, + inference_results_Q=inference_results_Q, + param_prompt_Q=param_prompt_Q, + tokenizer=tokenizer, + config=data_loader_config, + generation_config=generation_config, + num_training_steps=args.num_training_steps, + seed=args.seed, + per_device_train_batch_size=args.per_device_train_batch_size, + global_batch_size=args.num_unique_prompts_rollout, + dp_world_size=args.world_size, + max_possible_score=args.max_possible_score, + actor_manager=actor_manager, + model_dims=model_dims, + verbose=args.verbose, + work_dir=args.output_dir, + ) + logger.info(f"[DEBUG] DataPreparationActor singleton created with name: {data_prep_actor_name}") + logger.info("[DEBUG] Creating ModelGroup with policy actors...") wandb_url = wandb.run.get_url() if args.with_tracking else None policy_group = ModelGroup( @@ -1797,13 +1787,8 @@ def create_model_and_optimizer( args.single_gpu_mode, args=args, data_loader_config=data_loader_config, - dataset=train_dataset, - 
inference_results_Q=inference_results_Q, - param_prompt_Q=param_prompt_Q, + data_prep_actor_name=data_prep_actor_name, tokenizer=tokenizer, - generation_config=generation_config, - actor_manager=actor_manager, - model_dims=model_dims, ) logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") diff --git a/open_instruct/test_vllm_utils.py b/open_instruct/test_vllm_utils.py index 9ff16e4b6..7a6b62e31 100644 --- a/open_instruct/test_vllm_utils.py +++ b/open_instruct/test_vllm_utils.py @@ -24,9 +24,7 @@ def test_process_outputs_with_tools(self): def create_mock_logprobs(token_ids): return [{tid: MagicMock(logprob=-0.1 * tid)} for tid in token_ids] - mock_request = PromptRequest( - prompt=[1, 2, 3], generation_config=None, is_eval=False, dataset_index=43039 - ) + mock_request = PromptRequest(prompt=[1, 2, 3], generation_config=None, is_eval=False, dataset_index=43039) request_id = make_request_id(mock_request) mock_output1 = MagicMock(spec=vllm.CompletionOutput) @@ -107,9 +105,7 @@ def test_process_outputs_without_tools(self): def create_mock_logprobs(token_ids): return [{tid: MagicMock(logprob=-0.1 * tid)} for tid in token_ids] - mock_request = PromptRequest( - prompt=[1, 2, 3], generation_config=None, is_eval=True, dataset_index=200 - ) + mock_request = PromptRequest(prompt=[1, 2, 3], generation_config=None, is_eval=True, dataset_index=200) request_id = make_request_id(mock_request) mock_output1 = MagicMock(spec=vllm.CompletionOutput) From 0336e9001bfc977b093063b5c6a870dc8b56f408 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Wed, 3 Dec 2025 16:37:57 -0700 Subject: [PATCH 64/96] Fix DataPreparationActor garbage collection bug MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The DataPreparationActor was being garbage collected because the reference to it was not returned from create_model_and_optimizer(). When the function returned, _data_prep_actor went out of scope and Ray cleaned up the actor. Now we return and capture the reference to keep it alive during training. 
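As background, the Ray behavior behind this bug in a minimal standalone sketch (the `Worker` actor and names below are illustrative, not the project's code): a non-detached actor is reference-counted, so once every handle to it goes out of scope Ray reclaims it, even if it was created with a name.

```python
import ray

@ray.remote
class Worker:
    def ping(self) -> str:
        return "pong"

def create_and_drop() -> None:
    # The only handle dies when this function returns, so Ray may reclaim
    # the actor shortly afterwards -- the failure mode described above.
    Worker.options(name="worker_singleton").remote()

def create_and_return() -> "ray.actor.ActorHandle":
    # Returning the handle lets the caller keep the actor alive for as long
    # as it holds a reference, which is the shape of the fix in this commit.
    return Worker.options(name="worker_singleton").remote()

if __name__ == "__main__":
    ray.init()
    handle = create_and_return()
    print(ray.get(handle.ping.remote()))  # "pong"
```

An alternative is creating the actor with `lifetime="detached"`, but a detached actor outlives the driver and must be killed explicitly, so holding the returned handle is the lighter-weight option.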
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 52 +++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 18 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 25c652202..1001b9c6e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -2328,7 +2328,16 @@ def create_model_and_optimizer( ) logger.info("======== ✅ model update group setup successfully =========") - return policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims + return ( + policy_group, + vllm_engines, + tool_objects, + resume_training_step, + episode, + actor_manager, + model_dims, + _data_prep_actor, + ) def create_generation_configs(args: Args, streaming_config: data_loader_lib.StreamingDataLoaderConfig): @@ -2948,23 +2957,30 @@ def main( ) generation_configs = create_generation_configs(args, streaming_config) - (policy_group, vllm_engines, tool_objects, resume_training_step, episode, actor_manager, model_dims) = ( - create_model_and_optimizer( - args, - tc, - model_config, - beaker_config, - wandb_url, - tokenizer, - inference_results_Q, - param_prompt_Q, - evaluation_inference_results_Q, - streaming_config, - train_dataset, - eval_dataset, - reward_config, - generation_configs["train"], - ) + ( + policy_group, + vllm_engines, + tool_objects, + resume_training_step, + episode, + actor_manager, + model_dims, + _data_prep_actor, + ) = create_model_and_optimizer( + args, + tc, + model_config, + beaker_config, + wandb_url, + tokenizer, + inference_results_Q, + param_prompt_Q, + evaluation_inference_results_Q, + streaming_config, + train_dataset, + eval_dataset, + reward_config, + generation_configs["train"], ) checkpoint_state = None From c100144d17ee7055539c00347fd1341209c9e094 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 10:18:23 -0700 Subject: [PATCH 65/96] Fix rebase conflicts: remove duplicate code from main MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After rebasing on main, duplicate accumulate_inference_batches, BatchStatistics, and data_preparation_thread code came back. This code was moved to data_loader.py in the oc-dataloader branch. Removed the duplicate local definitions and added missing imports for GenerationResult and Batch. 
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 514 ------------------------------------- 1 file changed, 514 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1001b9c6e..31c106bfd 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -48,7 +48,6 @@ # isort: on import asyncio import dataclasses -import json import logging import math import random @@ -1527,519 +1526,6 @@ def calculate_utilization_metrics( return utilization_metrics -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - no_resampled_prompts: int - total_prompts: int - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - args: Args, - generation_config: vllm_utils.SamplingConfig, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - prompt_dataset: Dataset, - data_loader: data_loader_lib.HFDataLoader | None = None, - param_prompt_Q: ray_queue.Queue | None = None, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - """Accumulate multiple inference results into a single training batch. - - Args: - inference_results_Q: Queue containing individual GenerationResult objects (one per prompt) - args: Arguments containing vllm_num_engines and batch size info - generation_config: Generation config containing n (number of samples per prompt) - num_prompts: Number of prompts to accumulate - data_loader: Iterator over the dataloader for replenishing prompts. Required when - replenish_prompts=True or no_resampling_pass_rate is set. Can be None for - evaluation where all prompts are pre-queued. - prompt_dataset: Dataset containing prompts - param_prompt_Q: Queue containing prompts to send to generator. Required when - replenish_prompts=True. Can be None for evaluation where no replenishment is needed. - timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely. - active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std - filter_zero_std_samples: Whether to filter samples with zero reward std - replenish_prompts: Add a prompt back onto the prompt_Q after receiving a finished result - no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate - and exclude them from further sampling - - Raises: - queue.Empty: If timeout is specified and no data is available within timeout. 
- - Returns: - Tuple of (combined_result, Batch with queries, ground_truths, datasets, prompt_lengths, response_lengths) - or (ShutdownSentinel, None, None, None) if shutdown signal received - """ - if no_resampling_pass_rate is not None: - assert data_loader is not None, "no_resampling requires data_loader" - - if replenish_prompts: - assert param_prompt_Q is not None and data_loader is not None and prompt_dataset is not None, ( - "replenish_prompts requires param_prompt_Q, data_loader, and prompt_dataset" - ) - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] - all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not args.verbose, - ) - num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - # Validate that each individual result has the expected number of responses - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. " - f"Prompt ID: {result.prompt_id}" - ) - - # Replenish generation queue with new prompt - if replenish_prompts: - add_prompt_to_generator(next(data_loader), param_prompt_Q, generation_config, is_eval=False) - - # TODO(finbarrtimbers): Move this to LLMRayActor. 
- for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - percent_solved = np.mean(result.reward_scores).item() / args.max_possible_score - # Don't resample prompt that was solved at more than no_resample_positive_rate - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - total_no_resampled += 1 - data_loader.exclude_index(result.dataset_index) - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, total no resampled: {total_no_resampled}" - ) - - # Filter out zero std prompts - if filter_zero_std_samples and np.std(result.reward_scores) == 0: - # If we're not active sampling, still count this as a sample - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if result.reward_scores[0] == 0: - filtered_prompt_zero += 1 - elif result.reward_scores[0] == args.max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - prompt_data = prompt_dataset[result.dataset_index] - all_queries.extend(repeat_each([prompt_data[INPUT_IDS_PROMPT_KEY]], generation_config.n)) - all_ground_truths.extend(repeat_each([prompt_data[GROUND_TRUTHS_KEY]], generation_config.n)) - all_datasets.extend(repeat_each([prompt_data[VERIFIER_SOURCE_KEY]], generation_config.n)) - all_raw_queries.extend(repeat_each([prompt_data[RAW_PROMPT_KEY]], generation_config.n)) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(result.reward_scores) - all_reward_metrics.append(result.reward_metrics) - all_percent_solved.append(percent_solved) - - if len(results) == 0: - logger.warning( - "[Data Preparation Thread] All prompts were filtered during accumulation. 
" - f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " - f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" - ) - return None, None, None, None - - # Combine all results into a single GenerationResult - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - # Use the maximum generation time across engines since they work in parallel - # This avoids including queue overhead and accumulation time in MFU/MBU calculations - total_generation_time = max_generation_time - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=total_generation_time, - earliest_start_time=earliest_start_time, - ) - - # Create combined RequestInfo - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - # Create combined GenerationResult - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - prompt_id=None, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - # Note: We don't have dataset_indices here, but they're not needed for the returned batch - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, # Not meaningful for combined results - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - filtered_prompts=total_filtered_prompts, - 
filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - return combined_result, batch, combined_reward_metrics, batch_stats - - -def data_preparation_thread( - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, - packed_sequences_Q: Queue, - args: Args, - tokenizer: PreTrainedTokenizer, - num_training_steps: int, - generation_config, - resume_training_step: int, - data_loader: data_loader_lib.HFDataLoader, - train_dataset: Dataset, - actor_manager=None, - model_dims: utils.ModelDims = None, -): - for training_step in range(resume_training_step, num_training_steps + 1): - # Streaming accumulation: collect results as they arrive - with Timer("🚀 [Data Preparation Thread] Getting response ids") as timer: - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - inference_results_Q, - args, - generation_config, - num_prompts=args.num_unique_prompts_rollout, - model_dims=model_dims, - tokenizer=tokenizer, - data_loader=data_loader, - prompt_dataset=train_dataset, - param_prompt_Q=param_prompt_Q, - actor_manager=actor_manager, - active_sampling=args.active_sampling, - filter_zero_std_samples=args.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=args.no_resampling_pass_rate, - ) - if isinstance(result, ShutdownSentinel): - logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting") - return - if result is None: - logger.info("[Data Preparation Thread] All prompts filtered, putting empty batch into queue") - packed_sequences = PackedSequences( - query_responses=[], - attention_masks=[], - response_masks=[], - original_responses=[], - advantages=[], - position_ids=[], - vllm_logprobs=[], - ) - collated_data = [] - packed_sequences_Q.put( - { - "packed_sequences": packed_sequences, - "collated_data": collated_data, - "metrics": {}, - "responses_count": 0, - "num_new_tokens": 0, - "B": 0, - "prompt_lengths": [], - "response_lengths": [], - "num_filtered_prompts": 0, - } - ) - continue - - getting_response_time = timer.duration - scores = np.array(batch.scores) - - good_outputs = [ - len(result.request_info.tool_outputs[i]) > 0 - and result.request_info.tool_calleds[i] - and not result.request_info.timeouts[i] - and not result.request_info.tool_errors[i] - for i in range(len(result.request_info.tool_outputs)) - ] - scores_per_prompt = scores.reshape(-1, args.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, args.num_samples_per_prompt_rollout, axis=0) - if args.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif args.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {args.advantage_normalization_type}") - - if args.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 0: - logger.info( - f"[Truncated 
completions filtering] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - with Timer("📦 [Data Preparation Thread] Packing sequences"): - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=args.pack_length, - pad_token_id=tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - mask_tool_use=args.mask_tool_use, - ) - num_new_tokens = sum(len(seq) for seq in packed_sequences.query_responses) - # Vectorized advantage calculation: create a lookup array where each index corresponds to a response mask value - # and each value is the corresponding advantage score: index 0 is set to 0 since response masks start from 1 (1-indexed) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - # if we have less batches than world size, we need to pad out so each world is fine - # ideally, you should avoid this since its wasting computation. - if args.allow_world_padding: - with Timer("🤺 [Data Preparation Thread] Padding sequences for world size"): - shortfall = args.world_size - len(packed_sequences.query_responses) - if shortfall > 0: - logger.warning( - f"Padding {shortfall} sequences for world size. In future, you should adjust your compute this." - ) - # construct "dummy" sequences for padding out the world size - dummy_qr = torch.tensor([tokenizer.pad_token_id, tokenizer.eos_token_id], dtype=torch.long) - dummy_attention = torch.tensor([1, 1], dtype=torch.long) - dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) - dummy_response_mask = torch.zeros_like(dummy_qr) - dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) - # pad out the world size - for _ in range(shortfall): - packed_sequences.query_responses.append(dummy_qr) - packed_sequences.attention_masks.append(dummy_attention) - packed_sequences.position_ids.append(dummy_position_ids) - packed_sequences.response_masks.append(dummy_response_mask) - packed_sequences.advantages.append(dummy_advantage) - - collated_data = prepare_collated_data_for_workers( - packed_sequences, args.world_size, args.per_device_train_batch_size, tokenizer.pad_token_id - ) - B = len(packed_sequences.query_responses) // args.world_size - - # Create a result package with metrics and data - if len(result.responses) == 0: - # Handle empty responses case - # in this case, we won't log metrics, so it should be fine. 
- metrics = {} - logger.warning(f"No responses in batch {training_step}.") - else: - real_num_responses = len(result.responses) - expected_num_responses = args.num_samples_per_prompt_rollout * args.num_unique_prompts_rollout - - unsolved_num_responses = (scores < args.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) if np.all(scores == 0) else np.array(sequence_lengths[scores == args.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) if np.all(scores == args.max_possible_score) else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(finish_reason == "stop") for finish_reason in result.finish_reasons) / len( - result.finish_reasons - ) - - batch_metrics = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} - - metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": None, - "val/total_reward_groups": real_num_responses / args.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - "val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array([len(item) > 0 for item in result.request_info.tool_errors]).mean(), - "val/good_outputs_rate": np.array(good_outputs).mean(), - "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - "time/getting_response": getting_response_time, - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - if args.save_traces: - traces = { - "scores": scores.tolist(), - "finish_reasons": result.finish_reasons, - "responses": result.responses, - "training_step": training_step, - **asdict(batch), # Unpack all batch fields - **reward_metrics, - } - os.makedirs(args.output_dir, exist_ok=True) - with open(f"{args.output_dir}/traces_{args.run_name}.jsonl", "a") as f: - json.dump(traces, f) - f.write("\n") - - # Put the packed sequences and metrics into the output queue - packed_sequences_Q.put( - { - "packed_sequences": packed_sequences, # for debugging purposes - "collated_data": collated_data, - "metrics": metrics, - "responses_count": len(result.responses), - "num_new_tokens": num_new_tokens, - "B": B, - "prompt_lengths": batch_stats.prompt_lengths, - "response_lengths": 
batch_stats.response_lengths, - "num_filtered_prompts": batch_stats.filtered_prompts, - } - ) - - def setup_runtime_variables(args: Args, streaming_config: data_loader_lib.StreamingDataLoaderConfig) -> Args: """Set up runtime variables for the experiment.""" args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" From 8073f61c53b77b767123d4609ae74c64045ae0cb Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 10:24:07 -0700 Subject: [PATCH 66/96] Fix args references in create_generation_configs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit After rebase, response_length and num_samples_per_prompt_rollout moved to StreamingDataLoaderConfig. Updated references to use streaming_config instead of args. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 31c106bfd..3dd1c19e8 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1831,8 +1831,8 @@ def create_generation_configs(args: Args, streaming_config: data_loader_lib.Stre generation_config = vllm_utils.SamplingConfig( temperature=args.temperature, top_p=args.vllm_top_p, - max_tokens=args.response_length, - n=args.num_samples_per_prompt_rollout, + max_tokens=streaming_config.response_length, + n=streaming_config.num_samples_per_prompt_rollout, stop=args.stop_strings, seed=args.seed, logprobs=1, From 3be6f6bec0eb6d2dbfd1eb6c7ca5fdafb0462898 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 12:38:21 -0700 Subject: [PATCH 67/96] Add debug logging for prompt flow investigation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/data_loader.py | 3 +++ open_instruct/vllm_utils.py | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index e90ef6e60..e0bcd1ab5 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -384,6 +384,9 @@ def add_prompt_to_generator( is_eval: bool, ) -> None: query = example[INPUT_IDS_PROMPT_KEY] + logger.info( + f"[add_prompt_to_generator] Adding prompt: dataset_index={example['dataset_index']}, epoch={epoch_number}, step={training_step}" + ) param_prompt_Q.put( PromptRequest( diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index efdade838..0875a2c6c 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -395,12 +395,17 @@ async def _check_health(port: int) -> None: def _prefetch_worker(actor: "LLMRayActor") -> None: + logger.info("[_prefetch_worker] Starting prefetch worker loop") while True: if actor._should_stop() or len(actor.active_tasks) >= actor.inference_batch_size: time.sleep(DRAIN_ACTIVE_TASKS_SLEEP_S) continue + logger.info(f"[_prefetch_worker] Waiting for request, active_tasks={len(actor.active_tasks)}") request = actor.prompt_queue.get() + logger.info( + f"[_prefetch_worker] Got request: dataset_index={request.dataset_index}, is_eval={request.is_eval}" + ) add_request(actor, request) From 402d58d9c27f1d83b41814926dda48e3875d8118 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 12:44:56 -0700 Subject: [PATCH 68/96] Add more debug logging for HTTP calls and completions MIME-Version: 1.0 Content-Type: text/plain; 
charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/vllm_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index 0875a2c6c..ef43154a7 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -444,6 +444,7 @@ def _create_server_args(model_path: str) -> argparse.Namespace: def accumulate_completions(actor: "LLMRayActor", sub_request: dict) -> None: base_request_id = sub_request["base_request_id"] expected_n = sub_request["expected_n"] + logger.info(f"[accumulate_completions] {base_request_id}: received sub-request") if base_request_id not in actor.request_outputs: actor.request_outputs[base_request_id] = { @@ -477,6 +478,7 @@ async def finalize_completed_request(actor: "LLMRayActor", base_request_id: str) dataset = actor.eval_dataset if is_eval else actor.train_dataset result.reward_scores, result.reward_metrics = await compute_rewards(actor, result, dataset, is_eval) results_queue = actor.eval_results_queue if is_eval else actor.results_queue + logger.info(f"[finalize_completed_request] {base_request_id}: Putting result in queue (is_eval={is_eval})") results_queue.put(result) @@ -803,6 +805,7 @@ async def process_request(actor: LLMRayActor, sub_request_id: str, sampling_para while True: current_sampling_params = dataclasses.replace(sampling_params, max_tokens=current_max_tokens) + logger.info(f"[process_request] {sub_request_id}: Making API call with max_tokens={current_max_tokens}") api_response = await actor.client.completions.create( model=actor.model_name, prompt=current_prompt, @@ -814,6 +817,7 @@ async def process_request(actor: LLMRayActor, sub_request_id: str, sampling_para }, **dataclasses.asdict(current_sampling_params), ) + logger.info(f"[process_request] {sub_request_id}: Got API response") output = api_response.choices[0] model_tokens = list(output.token_ids) From 037d3c87ce484fce764e5af178e51c611a037719 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 12:57:38 -0700 Subject: [PATCH 69/96] Add debug logging to trace data flow between DataPreparationActor and training MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix missing GenerationResult import in grpo_fast.py - Add logging in accumulate_inference_batches to trace when results are received - Add logging in DataPreparationActor._data_preparation_loop to trace startup 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/data_loader.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index e0bcd1ab5..c39b511d3 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -450,8 +450,13 @@ def accumulate_inference_batches( disable=not verbose, ) num_prompts_sampled = 0 + logger.info(f"[accumulate_inference_batches] Starting to accumulate {num_prompts} prompts") while num_prompts_sampled < num_prompts: + logger.info(f"[accumulate_inference_batches] Waiting for result {num_prompts_sampled + 1}/{num_prompts}") result = inference_results_Q.get(timeout=timeout) + logger.info( + f"[accumulate_inference_batches] Got result: {result.dataset_index if hasattr(result, 'dataset_index') else type(result)}" + ) if isinstance(result, ShutdownSentinel): return result, None, None, None @@ -770,6 +775,9 @@ def __init__( self._prep_future = 
self._executor.submit(self._data_preparation_loop) def _data_preparation_loop(self): + logger.info( + f"[DataPreparationActor] Starting data preparation loop, async_steps={self.config.async_steps}, global_batch_size={self.global_batch_size}" + ) for _ in range(self.config.async_steps * self.global_batch_size): add_prompt_to_generator( next(self.iter_dataloader), @@ -779,6 +787,9 @@ def _data_preparation_loop(self): self.generation_config, is_eval=False, ) + logger.info( + f"[DataPreparationActor] Initial prompts submitted, entering main loop for {self.num_training_steps} steps" + ) for step in range(self.training_step, self.num_training_steps): if self.shutdown_requested: From 181212837b3325a4767b3c3ad108435c26c1bb16 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 12:58:04 -0700 Subject: [PATCH 70/96] Fix GenerationResult import in grpo_fast.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3dd1c19e8..eb12f7562 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -101,7 +101,7 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import ShutdownSentinel +from open_instruct.queue_types import GenerationResult, ShutdownSentinel from open_instruct.rl_utils import PackedSequences, Timer, masked_mean from open_instruct.utils import ( ArgumentParserPlus, From c8dab3125aad3af805c32127c38b37d88da7323a Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:04:41 -0700 Subject: [PATCH 71/96] Clean up stale Ray sessions in ray_node_setup.sh MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 63c71f30f..bd3d82cd6 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -17,6 +17,7 @@ BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk RAY_NODE_PORT=8888 mkdir -p "$HOME/.triton/autotune" # Create Triton autotune cache directory to silence warnings ray stop --force +rm -rf /tmp/ray/session_* 2>/dev/null || true # Clean up stale Ray sessions if [ "$BEAKER_REPLICA_RANK" == "0" ]; then echo "Starting Ray head node" From cd9a292d7e90600e5cc7f9131a5906fbec7b676b Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:14:03 -0700 Subject: [PATCH 72/96] Add exception handling and debug logging to DataPreparationActor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Wrap _data_preparation_loop in try/except to catch and log exceptions - Add logging before/after accumulate_inference_batches call 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/data_loader.py | 11 +++++++++++ open_instruct/grpo_fast.py | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index c39b511d3..f91a5d051 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -775,6 +775,13 @@ def 
__init__( self._prep_future = self._executor.submit(self._data_preparation_loop) def _data_preparation_loop(self): + try: + self._data_preparation_loop_inner() + except Exception: + logger.exception("[DataPreparationActor] Exception in data preparation loop") + raise + + def _data_preparation_loop_inner(self): logger.info( f"[DataPreparationActor] Starting data preparation loop, async_steps={self.config.async_steps}, global_batch_size={self.global_batch_size}" ) @@ -796,6 +803,7 @@ def _data_preparation_loop(self): logger.info("[DataPreparationActor] Shutdown requested, exiting") return + logger.info(f"[DataPreparationActor] Step {step}: calling accumulate_inference_batches") result, batch, reward_metrics, batch_stats = accumulate_inference_batches( self.inference_results_Q, self.generation_config, @@ -814,6 +822,9 @@ def _data_preparation_loop(self): verbose=self.verbose, max_possible_score=self.config.max_possible_score, ) + logger.info( + f"[DataPreparationActor] Step {step}: accumulate_inference_batches returned, result is None: {result is None}" + ) if isinstance(result, ShutdownSentinel): logger.info("[DataPreparationActor] Received shutdown sentinel, exiting") diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index eb12f7562..3dd1c19e8 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -101,7 +101,7 @@ print_rich_table, push_folder_to_hub, ) -from open_instruct.queue_types import GenerationResult, ShutdownSentinel +from open_instruct.queue_types import ShutdownSentinel from open_instruct.rl_utils import PackedSequences, Timer, masked_mean from open_instruct.utils import ( ArgumentParserPlus, From 55719e3985335be9636f45f69f8d5210a5108f79 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:17:05 -0700 Subject: [PATCH 73/96] Improve Ray cleanup to remove entire /tmp/ray directory MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index bd3d82cd6..631ab1b1c 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -17,7 +17,8 @@ BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk RAY_NODE_PORT=8888 mkdir -p "$HOME/.triton/autotune" # Create Triton autotune cache directory to silence warnings ray stop --force -rm -rf /tmp/ray/session_* 2>/dev/null || true # Clean up stale Ray sessions +rm -rf /tmp/ray 2>/dev/null || true # Clean up entire Ray temp directory to avoid stale session issues +sleep 2 # Wait for cleanup to complete if [ "$BEAKER_REPLICA_RANK" == "0" ]; then echo "Starting Ray head node" From 0d45523b9c69c2194c6a41d6d459210baccc6710 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:26:06 -0700 Subject: [PATCH 74/96] Fix pin_memory() crash on CPU-only DataPreparationActor MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit DataPreparationActor runs without GPU (num_cpus=2), but collate_fn() called pin_memory() unconditionally, causing RuntimeError: No CUDA GPUs are available. 
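As a self-contained illustration of the failure and of the guard used in the fix below (the helper name is made up, not the project's `collate_fn`): `Tensor.pin_memory()` allocates page-locked host memory through CUDA, so calling it on a worker with no visible GPU raises a RuntimeError.

```python
import torch

def pad_and_maybe_pin(tensors: list[torch.Tensor], pad_token_id: int) -> torch.Tensor:
    padded = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=pad_token_id)
    # Only pin when CUDA is actually usable; the unconditional pin_memory()
    # call is what crashed the CPU-only actor.
    if torch.cuda.is_available():
        padded = padded.pin_memory()
    return padded

batch = pad_and_maybe_pin([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], pad_token_id=0)
print(batch)  # tensor([[1, 2, 3], [4, 5, 0]])
```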
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/data_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index f91a5d051..938ea286c 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -304,7 +304,7 @@ def shutdown(self): def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) - if pin_memory: + if pin_memory and torch.cuda.is_available(): padded_tensor = padded_tensor.pin_memory() return padded_tensor From fb77ef3c2bcd2e0837842bb92b2c74794fb4887e Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:29:16 -0700 Subject: [PATCH 75/96] Use unique temp-dir per experiment to avoid Ray session conflicts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Force kill all Ray processes, use per-experiment temp directory to avoid stale GCS session state from previous experiments. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 631ab1b1c..c922d051f 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -16,18 +16,23 @@ BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk RAY_NODE_PORT=8888 mkdir -p "$HOME/.triton/autotune" # Create Triton autotune cache directory to silence warnings -ray stop --force -rm -rf /tmp/ray 2>/dev/null || true # Clean up entire Ray temp directory to avoid stale session issues -sleep 2 # Wait for cleanup to complete +pkill -9 -f "ray::" 2>/dev/null || true +pkill -9 -f "gcs_server" 2>/dev/null || true +pkill -9 -f "raylet" 2>/dev/null || true +ray stop --force 2>/dev/null || true +rm -rf /tmp/ray* 2>/dev/null || true +sleep 2 + +RAY_TEMP_DIR="/tmp/ray_${BEAKER_EXPERIMENT_ID:-$$}" +mkdir -p "$RAY_TEMP_DIR" if [ "$BEAKER_REPLICA_RANK" == "0" ]; then - echo "Starting Ray head node" - ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 + echo "Starting Ray head node with temp dir: $RAY_TEMP_DIR" + ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" else - echo "Starting Ray worker node $BEAKER_REPLICA_RANK" + echo "Starting Ray worker node $BEAKER_REPLICA_RANK with temp dir: $RAY_TEMP_DIR" export RAY_ADDRESS="${BEAKER_LEADER_REPLICA_IP}:${RAY_NODE_PORT}" - # Start worker without --block so we can control lifecycle and exit code. 
- ray start --address="${RAY_ADDRESS}" --dashboard-host=0.0.0.0 + ray start --address="${RAY_ADDRESS}" --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" cleanup() { echo "[ray_node_setup] Cleanup: stopping Ray worker and exiting 0" From 2e91a63d3c8dc93848ae1b0f6b67b01516096da5 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:32:13 -0700 Subject: [PATCH 76/96] More aggressive Ray cleanup and use explicit --storage path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Kill redis-server and plasma_store processes - Clean up /dev/shm/plasma* and ~/.ray - Add timestamp to temp dir for uniqueness - Use --storage parameter for isolated GCS storage 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index c922d051f..427939524 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -15,20 +15,29 @@ echo PATH=$PATH BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk '{print $1}') RAY_NODE_PORT=8888 -mkdir -p "$HOME/.triton/autotune" # Create Triton autotune cache directory to silence warnings +mkdir -p "$HOME/.triton/autotune" + +echo "Cleaning up any existing Ray processes..." pkill -9 -f "ray::" 2>/dev/null || true pkill -9 -f "gcs_server" 2>/dev/null || true pkill -9 -f "raylet" 2>/dev/null || true +pkill -9 -f "redis-server" 2>/dev/null || true +pkill -9 -f "plasma_store" 2>/dev/null || true ray stop --force 2>/dev/null || true +sleep 3 + rm -rf /tmp/ray* 2>/dev/null || true +rm -rf /dev/shm/plasma* 2>/dev/null || true +rm -rf ~/.ray 2>/dev/null || true sleep 2 -RAY_TEMP_DIR="/tmp/ray_${BEAKER_EXPERIMENT_ID:-$$}" +RAY_TEMP_DIR="/tmp/ray_${BEAKER_EXPERIMENT_ID:-$$}_$(date +%s)" mkdir -p "$RAY_TEMP_DIR" if [ "$BEAKER_REPLICA_RANK" == "0" ]; then echo "Starting Ray head node with temp dir: $RAY_TEMP_DIR" - ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" + RAY_STORAGE_DIR="$RAY_TEMP_DIR/storage" + ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" --storage="$RAY_STORAGE_DIR" else echo "Starting Ray worker node $BEAKER_REPLICA_RANK with temp dir: $RAY_TEMP_DIR" export RAY_ADDRESS="${BEAKER_LEADER_REPLICA_IP}:${RAY_NODE_PORT}" From 03d7d8dc21e44b2038c9ce3741e43503bb919a49 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:34:57 -0700 Subject: [PATCH 77/96] Remove invalid --storage flag from ray start MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 427939524..540f8cba7 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -36,8 +36,7 @@ mkdir -p "$RAY_TEMP_DIR" if [ "$BEAKER_REPLICA_RANK" == "0" ]; then echo "Starting Ray head node with temp dir: $RAY_TEMP_DIR" - RAY_STORAGE_DIR="$RAY_TEMP_DIR/storage" - ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" --storage="$RAY_STORAGE_DIR" + ray start --head 
--port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" else echo "Starting Ray worker node $BEAKER_REPLICA_RANK with temp dir: $RAY_TEMP_DIR" export RAY_ADDRESS="${BEAKER_LEADER_REPLICA_IP}:${RAY_NODE_PORT}" From 206e1fc4d7ba6faed91de59cbd6bae6011b93d7f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:38:53 -0700 Subject: [PATCH 78/96] Debug Ray session conflict: random port, list /tmp, more cleanup MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Use random port to avoid conflicts with stale Ray processes - List /tmp before Ray start for debugging - Clean more directories including /dev/shm/* - Set RAY_TMPDIR environment variable - Add --disable-usage-stats 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 540f8cba7..407ad2a0a 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -14,7 +14,7 @@ echo PATH=$PATH BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk '{print $1}') -RAY_NODE_PORT=8888 +RAY_NODE_PORT=$((8800 + RANDOM % 100)) mkdir -p "$HOME/.triton/autotune" echo "Cleaning up any existing Ray processes..." @@ -23,20 +23,26 @@ pkill -9 -f "gcs_server" 2>/dev/null || true pkill -9 -f "raylet" 2>/dev/null || true pkill -9 -f "redis-server" 2>/dev/null || true pkill -9 -f "plasma_store" 2>/dev/null || true +pkill -9 -f "log_monitor" 2>/dev/null || true +pkill -9 -f "monitor.py" 2>/dev/null || true ray stop --force 2>/dev/null || true sleep 3 rm -rf /tmp/ray* 2>/dev/null || true -rm -rf /dev/shm/plasma* 2>/dev/null || true +rm -rf /dev/shm/* 2>/dev/null || true rm -rf ~/.ray 2>/dev/null || true +rm -rf /run/user/*/ray* 2>/dev/null || true sleep 2 RAY_TEMP_DIR="/tmp/ray_${BEAKER_EXPERIMENT_ID:-$$}_$(date +%s)" mkdir -p "$RAY_TEMP_DIR" +export RAY_TMPDIR="$RAY_TEMP_DIR" if [ "$BEAKER_REPLICA_RANK" == "0" ]; then echo "Starting Ray head node with temp dir: $RAY_TEMP_DIR" - ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" + echo "Listing /tmp before start:" + ls -la /tmp/ || true + ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" --disable-usage-stats else echo "Starting Ray worker node $BEAKER_REPLICA_RANK with temp dir: $RAY_TEMP_DIR" export RAY_ADDRESS="${BEAKER_LEADER_REPLICA_IP}:${RAY_NODE_PORT}" From 58003f4f1ef0fd780e04edfde892983130f6325c Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:41:55 -0700 Subject: [PATCH 79/96] Shorten Ray temp dir path to avoid Unix socket length limit MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit AF_UNIX socket paths must be <= 107 bytes. Use shorter temp dir name. 
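The limit comes from the kernel's `sockaddr_un.sun_path` buffer, which is 108 bytes on Linux including the terminating NUL. A rough way to see how much headroom a candidate `--temp-dir` leaves for the sockets Ray nests under it (the nested path below is only an approximation of Ray's session layout, and the long directory name is a made-up example):

```python
import time

AF_UNIX_PATH_MAX = 107  # usable bytes of sockaddr_un.sun_path on Linux

def socket_headroom(temp_dir: str, nested: str = "session_latest/sockets/plasma_store") -> int:
    # `nested` approximates the per-session socket paths Ray creates under --temp-dir.
    return AF_UNIX_PATH_MAX - len(f"{temp_dir}/{nested}".encode())

now = int(time.time())
print(socket_headroom(f"/tmp/ray_made_up_experiment_id_0123456789_{now}"))  # long, experiment-id style prefix
print(socket_headroom(f"/tmp/r_{now}"))                                     # the shortened /tmp/r_<epoch> form
```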
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 407ad2a0a..9d1bab0ea 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -34,14 +34,12 @@ rm -rf ~/.ray 2>/dev/null || true rm -rf /run/user/*/ray* 2>/dev/null || true sleep 2 -RAY_TEMP_DIR="/tmp/ray_${BEAKER_EXPERIMENT_ID:-$$}_$(date +%s)" +RAY_TEMP_DIR="/tmp/r_$(date +%s)" mkdir -p "$RAY_TEMP_DIR" export RAY_TMPDIR="$RAY_TEMP_DIR" if [ "$BEAKER_REPLICA_RANK" == "0" ]; then echo "Starting Ray head node with temp dir: $RAY_TEMP_DIR" - echo "Listing /tmp before start:" - ls -la /tmp/ || true ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" --disable-usage-stats else echo "Starting Ray worker node $BEAKER_REPLICA_RANK with temp dir: $RAY_TEMP_DIR" From 63903c560d1a327d54a962e453378409dcf35fe5 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 13:46:40 -0700 Subject: [PATCH 80/96] Fix missing time/getting_response metric key error MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use .get() with default value since StreamingDataLoader doesn't populate this metric yet. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3dd1c19e8..4bfba0064 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1967,7 +1967,7 @@ def one_training_step( step_time = time.perf_counter() - start_time total_training_time = time.perf_counter() - training_start_time - total_generation_time = average_metrics["time/getting_response"] + total_generation_time = data_thread_metrics.get("time/getting_response", 0.0) prompt_lengths = array_metrics[0]["batch/prompt_lengths"] response_lengths = array_metrics[0]["batch/response_lengths"] num_step_tokens = sum(prompt_lengths) + sum(response_lengths) From 12fe066e22052a4862af69cffc8ab49408e786c1 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Thu, 4 Dec 2025 16:42:58 -0700 Subject: [PATCH 81/96] Skip training step when batch is empty (matching main behavior) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When all training actors return empty metrics (indicating empty batch), skip the rest of the training step to avoid KeyError on batch/prompt_lengths. 
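Reduced to its core, the guard added in the diff below is just the following (illustrative only; `metrics` stands in for the per-rank dicts returned by the training actors):

```python
def should_skip_step(metrics: tuple[dict, ...]) -> bool:
    # Every rank returning an empty metrics dict signals that packing
    # produced no trainable data for this step.
    return all(len(m) == 0 for m in metrics)

print(should_skip_step(({}, {}, {})))         # True  -> skip the step
print(should_skip_step(({"loss": 0.1}, {})))  # False -> train as usual
```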
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4bfba0064..cb72c5bb1 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1927,6 +1927,9 @@ def one_training_step( desc=f"Running training step {training_step}", ) metrics, array_metrics = zip(*results) + if all(len(m) == 0 for m in metrics): + logger.warning("[Main Thread] 🤡 After packing, there is not enough data to train") + return if ( args.load_ref_policy and args.ref_policy_update_freq is not None From ea15d595261c622e0bd02c3e06d432db93ae6aab Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 08:41:01 -0700 Subject: [PATCH 82/96] Fix multi-node Ray: use fixed port instead of random MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The random port ($RANDOM) generates different values on each node, breaking multi-node Ray clusters. Worker nodes couldn't connect to the head because they used a different port. Reverts to fixed port 6379 (default Ray port). The aggressive cleanup added in 206e1fc4 handles stale processes. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 9d1bab0ea..6bdc84b82 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -14,7 +14,7 @@ echo PATH=$PATH BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk '{print $1}') -RAY_NODE_PORT=$((8800 + RANDOM % 100)) +RAY_NODE_PORT=6379 mkdir -p "$HOME/.triton/autotune" echo "Cleaning up any existing Ray processes..." From e40f7b478bac70ad580d52fba5b1c13a285dd141 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 09:11:59 -0700 Subject: [PATCH 83/96] Revert Ray port to 8888 (matching main branch) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 6bdc84b82..43f7121b4 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -14,7 +14,7 @@ echo PATH=$PATH BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk '{print $1}') -RAY_NODE_PORT=6379 +RAY_NODE_PORT=8888 mkdir -p "$HOME/.triton/autotune" echo "Cleaning up any existing Ray processes..." 
From c3961e0049a9597cff097e11cdde642eb87b2833 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 09:17:06 -0700 Subject: [PATCH 84/96] Add debug logging for Ray cluster resources before vLLM placement group MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/vllm_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index ef43154a7..38101824c 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -978,8 +978,15 @@ def create_vllm_engines( if not use_hybrid_engine: # Create a big placement group to ensure that all engines are packed bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)] + cluster_resources = ray.cluster_resources() + available_resources = ray.available_resources() + logger.info(f"[DEBUG] Cluster resources: {cluster_resources}") + logger.info(f"[DEBUG] Available resources: {available_resources}") + logger.info(f"[DEBUG] Creating vLLM placement group with {len(bundles)} bundles") pg = placement_group(bundles, strategy="PACK") + logger.info(f"[DEBUG] Waiting for vLLM placement group...") ray.get(pg.ready()) + logger.info(f"[DEBUG] vLLM placement group ready!") # ensure we use bundles on the same node where possible if tp>1. bundle_indices_list = get_bundle_indices_list(pg) From 5d9bb38640fd52ed06de9a5e6df307874c3339e1 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 09:22:37 -0700 Subject: [PATCH 85/96] Wait for all expected GPUs before creating placement groups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Worker nodes may take time to fully register their GPUs with Ray. This ensures all expected resources are available before proceeding. 
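The same idea as a standalone helper (a sketch, assuming learner and inference GPUs are all reported under Ray's single "GPU" resource key; the patch instead inlines this loop directly in create_model_and_optimizer):

    import time

    import ray

    def wait_for_gpus(expected_gpus: float, timeout_s: float = 60.0, poll_s: float = 1.0) -> float:
        """Poll ray.cluster_resources() until the cluster reports enough GPUs or the timeout expires."""
        deadline = time.monotonic() + timeout_s
        available = ray.cluster_resources().get("GPU", 0)
        while available < expected_gpus and time.monotonic() < deadline:
            time.sleep(poll_s)
            available = ray.cluster_resources().get("GPU", 0)
        return available

Note that the patch only warns (rather than failing) when the wait times out, so a partially registered cluster still proceeds and surfaces the real error at placement-group creation.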
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index cb72c5bb1..74a218b39 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1655,6 +1655,20 @@ def create_model_and_optimizer( generation_config, ) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]: """Create the model, optimizer, and vLLM engines.""" + # Wait for all expected GPUs to be available in the cluster + # This ensures worker nodes have fully registered their resources + expected_gpus = sum(args.num_learners_per_node) + args.vllm_num_engines * args.vllm_tensor_parallel_size + logger.info(f"[DEBUG] Waiting for {expected_gpus} GPUs to be available in cluster...") + for i in range(60): # Wait up to 60 seconds + cluster_resources = ray.cluster_resources() + available_gpus = cluster_resources.get("GPU", 0) + logger.info(f"[DEBUG] Cluster has {available_gpus} GPUs (need {expected_gpus})") + if available_gpus >= expected_gpus: + break + time.sleep(1) + else: + logger.warning(f"[WARNING] Only {available_gpus} GPUs available, expected {expected_gpus}") + # Create placement group bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node] pg = placement_group(bundles, strategy="STRICT_SPREAD") From 4f81e89c6a6b685a8cc174b98ca942be8056f6fc Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 09:27:02 -0700 Subject: [PATCH 86/96] Revert ray_node_setup.sh to main branch version MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverting to main's simpler setup to debug worker GPU detection issue. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- configs/beaker_configs/ray_node_setup.sh | 33 +++++------------------- 1 file changed, 7 insertions(+), 26 deletions(-) diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh index 43f7121b4..63c71f30f 100644 --- a/configs/beaker_configs/ray_node_setup.sh +++ b/configs/beaker_configs/ray_node_setup.sh @@ -15,36 +15,17 @@ echo PATH=$PATH BEAKER_LEADER_REPLICA_IP=$(getent hosts ${BEAKER_LEADER_REPLICA_HOSTNAME} | awk '{print $1}') RAY_NODE_PORT=8888 -mkdir -p "$HOME/.triton/autotune" - -echo "Cleaning up any existing Ray processes..." 
-pkill -9 -f "ray::" 2>/dev/null || true -pkill -9 -f "gcs_server" 2>/dev/null || true -pkill -9 -f "raylet" 2>/dev/null || true -pkill -9 -f "redis-server" 2>/dev/null || true -pkill -9 -f "plasma_store" 2>/dev/null || true -pkill -9 -f "log_monitor" 2>/dev/null || true -pkill -9 -f "monitor.py" 2>/dev/null || true -ray stop --force 2>/dev/null || true -sleep 3 - -rm -rf /tmp/ray* 2>/dev/null || true -rm -rf /dev/shm/* 2>/dev/null || true -rm -rf ~/.ray 2>/dev/null || true -rm -rf /run/user/*/ray* 2>/dev/null || true -sleep 2 - -RAY_TEMP_DIR="/tmp/r_$(date +%s)" -mkdir -p "$RAY_TEMP_DIR" -export RAY_TMPDIR="$RAY_TEMP_DIR" +mkdir -p "$HOME/.triton/autotune" # Create Triton autotune cache directory to silence warnings +ray stop --force if [ "$BEAKER_REPLICA_RANK" == "0" ]; then - echo "Starting Ray head node with temp dir: $RAY_TEMP_DIR" - ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" --disable-usage-stats + echo "Starting Ray head node" + ray start --head --port=$RAY_NODE_PORT --dashboard-host=0.0.0.0 else - echo "Starting Ray worker node $BEAKER_REPLICA_RANK with temp dir: $RAY_TEMP_DIR" + echo "Starting Ray worker node $BEAKER_REPLICA_RANK" export RAY_ADDRESS="${BEAKER_LEADER_REPLICA_IP}:${RAY_NODE_PORT}" - ray start --address="${RAY_ADDRESS}" --dashboard-host=0.0.0.0 --temp-dir="$RAY_TEMP_DIR" + # Start worker without --block so we can control lifecycle and exit code. + ray start --address="${RAY_ADDRESS}" --dashboard-host=0.0.0.0 cleanup() { echo "[ray_node_setup] Cleanup: stopping Ray worker and exiting 0" From 8c0ab3020ddf322e2b0a5b511911bc989830809a Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 09:52:36 -0700 Subject: [PATCH 87/96] Remove GPU wait loop and debug logging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove debugging additions that may have been interfering with multi-node Ray cluster setup. This reverts to cleaner code that should allow worker nodes to join properly. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- open_instruct/grpo_fast.py | 47 ------------------------------------- open_instruct/vllm_utils.py | 22 ----------------- 2 files changed, 69 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 74a218b39..cbcf5322d 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -641,11 +641,7 @@ def load(self, path: str, map_location=None): np.random.seed(worker_seed) random.seed(worker_seed) - logger.info( - f"[DEBUG] Rank {self.rank}: Initializing DeepSpeed distributed (timeout={args.backend_timeout} minutes)..." 
- ) deepspeed.init_distributed(timeout=timedelta(minutes=args.backend_timeout)) - logger.info(f"[DEBUG] Rank {self.rank}: DeepSpeed distributed initialized successfully") ds_config = get_train_ds_config( offload=args.deepspeed_offload_param, @@ -805,11 +801,8 @@ def forward( return logprob, entropy def setup_model_update_group(self, vllm_engines): - logger = logger_utils.setup_logger(__name__) - logger.info(f"[DEBUG] Rank {self.rank}: Entered setup_model_update_group") self.vllm_engines = vllm_engines if self.rank == 0: - logger.info(f"[DEBUG] Rank 0: Initializing process group for {len(vllm_engines)} vLLM engines") master_address = ray._private.services.get_node_ip_address() with socket.socket() as sock: sock.bind(("", 0)) @@ -820,10 +813,6 @@ def setup_model_update_group(self, vllm_engines): ) world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 backend = self.args.vllm_sync_backend - logger.info( - f"[DEBUG] Rank 0: master_address={master_address}, master_port={master_port}, " - f"world_size={world_size}, backend={backend}" - ) refs = [ engine.init_process_group.remote( master_address, @@ -844,15 +833,8 @@ def setup_model_update_group(self, vllm_engines): group_name="openrlhf", timeout=timedelta(minutes=self.args.backend_timeout), ) - logger.info( - f"[DEBUG] Rank 0: Waiting for {len(refs)} vLLM engines to initialize process groups (timeout=600s)..." - ) ray_get_with_progress(refs, desc="Initializing vLLM process groups", timeout=600) - logger.info("[DEBUG] Rank 0: All vLLM engines initialized, approaching barrier") - else: - logger.info(f"[DEBUG] Rank {self.rank}: Approaching barrier") torch.distributed.barrier() - logger.info(f"[DEBUG] Rank {self.rank}: Passed barrier successfully") def broadcast_to_vllm(self): # avoid OOM @@ -1655,20 +1637,6 @@ def create_model_and_optimizer( generation_config, ) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]: """Create the model, optimizer, and vLLM engines.""" - # Wait for all expected GPUs to be available in the cluster - # This ensures worker nodes have fully registered their resources - expected_gpus = sum(args.num_learners_per_node) + args.vllm_num_engines * args.vllm_tensor_parallel_size - logger.info(f"[DEBUG] Waiting for {expected_gpus} GPUs to be available in cluster...") - for i in range(60): # Wait up to 60 seconds - cluster_resources = ray.cluster_resources() - available_gpus = cluster_resources.get("GPU", 0) - logger.info(f"[DEBUG] Cluster has {available_gpus} GPUs (need {expected_gpus})") - if available_gpus >= expected_gpus: - break - time.sleep(1) - else: - logger.warning(f"[WARNING] Only {available_gpus} GPUs available, expected {expected_gpus}") - # Create placement group bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.num_learners_per_node] pg = placement_group(bundles, strategy="STRICT_SPREAD") @@ -1734,16 +1702,10 @@ def create_model_and_optimizer( train_dataset=train_dataset, eval_dataset=eval_dataset, ) - logger.info(f"[DEBUG] Created {len(vllm_engines)} vLLM engines") - # Get model dimensions from vLLM engine - logger.info("[DEBUG] Fetching model dimensions from first vLLM engine...") model_dims = ray.get(vllm_engines[0].get_model_dims.remote()) logger.info("======== ✅ vLLM engines and actor_manager initialized =========") - # Get and set KV cache max concurrency from the first engine (all engines have the same config) - # fp8 kv cache for now forces v0 engine and breaks this. 
- logger.info("[DEBUG] Setting up KV cache configuration...") if vllm_engines: kv_cache_max_concurrency = ray.get(vllm_engines[0].get_kv_cache_info.remote()) ray.get(actor_manager.set_kv_cache_max_concurrency.remote(kv_cache_max_concurrency)) @@ -1765,12 +1727,8 @@ def create_model_and_optimizer( f"You might want to use more inference nodes ({nodes_needed} nodes to generate the entire batch simultaneously)." ) else: - # dummy value ray.get(actor_manager.set_kv_cache_max_concurrency.remote(-1)) - logger.info("[DEBUG] KV cache configuration complete") - # Now create policy actors with all dependencies - logger.info("[DEBUG] Creating DataPreparationActor singleton...") data_prep_actor_name = "data_prep_singleton" _data_prep_actor = DataPreparationActor.options(name=data_prep_actor_name, num_cpus=2).remote( dataset=train_dataset, @@ -1790,9 +1748,7 @@ def create_model_and_optimizer( verbose=args.verbose, work_dir=args.output_dir, ) - logger.info(f"[DEBUG] DataPreparationActor singleton created with name: {data_prep_actor_name}") - logger.info("[DEBUG] Creating ModelGroup with policy actors...") wandb_url = wandb.run.get_url() if args.with_tracking else None policy_group = ModelGroup( pg, @@ -1804,9 +1760,7 @@ def create_model_and_optimizer( data_prep_actor_name=data_prep_actor_name, tokenizer=tokenizer, ) - logger.info(f"[DEBUG] ModelGroup created with {len(policy_group.models)} policy actors") - logger.info("[DEBUG] Starting model initialization across all ranks...") inits = [ model.from_pretrained.remote(args, model_config, beaker_config, wandb_url, tokenizer) for model in policy_group.models @@ -1821,7 +1775,6 @@ def create_model_and_optimizer( ) logger.info("======== ✅ all models initialized =========") - logger.info("[DEBUG] Setting up model update group across all ranks...") ray_get_with_progress( [m.setup_model_update_group.remote(vllm_engines=vllm_engines) for m in policy_group.models], desc="Setting up model update group", diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index 38101824c..035caf193 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -395,17 +395,12 @@ async def _check_health(port: int) -> None: def _prefetch_worker(actor: "LLMRayActor") -> None: - logger.info("[_prefetch_worker] Starting prefetch worker loop") while True: if actor._should_stop() or len(actor.active_tasks) >= actor.inference_batch_size: time.sleep(DRAIN_ACTIVE_TASKS_SLEEP_S) continue - logger.info(f"[_prefetch_worker] Waiting for request, active_tasks={len(actor.active_tasks)}") request = actor.prompt_queue.get() - logger.info( - f"[_prefetch_worker] Got request: dataset_index={request.dataset_index}, is_eval={request.is_eval}" - ) add_request(actor, request) @@ -444,7 +439,6 @@ def _create_server_args(model_path: str) -> argparse.Namespace: def accumulate_completions(actor: "LLMRayActor", sub_request: dict) -> None: base_request_id = sub_request["base_request_id"] expected_n = sub_request["expected_n"] - logger.info(f"[accumulate_completions] {base_request_id}: received sub-request") if base_request_id not in actor.request_outputs: actor.request_outputs[base_request_id] = { @@ -478,7 +472,6 @@ async def finalize_completed_request(actor: "LLMRayActor", base_request_id: str) dataset = actor.eval_dataset if is_eval else actor.train_dataset result.reward_scores, result.reward_metrics = await compute_rewards(actor, result, dataset, is_eval) results_queue = actor.eval_results_queue if is_eval else actor.results_queue - logger.info(f"[finalize_completed_request] 
{base_request_id}: Putting result in queue (is_eval={is_eval})") results_queue.put(result) @@ -805,7 +798,6 @@ async def process_request(actor: LLMRayActor, sub_request_id: str, sampling_para while True: current_sampling_params = dataclasses.replace(sampling_params, max_tokens=current_max_tokens) - logger.info(f"[process_request] {sub_request_id}: Making API call with max_tokens={current_max_tokens}") api_response = await actor.client.completions.create( model=actor.model_name, prompt=current_prompt, @@ -817,7 +809,6 @@ async def process_request(actor: LLMRayActor, sub_request_id: str, sampling_para }, **dataclasses.asdict(current_sampling_params), ) - logger.info(f"[process_request] {sub_request_id}: Got API response") output = api_response.choices[0] model_tokens = list(output.token_ids) @@ -976,24 +967,14 @@ def create_vllm_engines( logger.info(f"num_gpus: {num_gpus}") if not use_hybrid_engine: - # Create a big placement group to ensure that all engines are packed bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)] - cluster_resources = ray.cluster_resources() - available_resources = ray.available_resources() - logger.info(f"[DEBUG] Cluster resources: {cluster_resources}") - logger.info(f"[DEBUG] Available resources: {available_resources}") - logger.info(f"[DEBUG] Creating vLLM placement group with {len(bundles)} bundles") pg = placement_group(bundles, strategy="PACK") - logger.info(f"[DEBUG] Waiting for vLLM placement group...") ray.get(pg.ready()) - logger.info(f"[DEBUG] vLLM placement group ready!") # ensure we use bundles on the same node where possible if tp>1. bundle_indices_list = get_bundle_indices_list(pg) - logger.info(f"[DEBUG] Creating {num_engines} vLLM engines with tensor_parallel_size={tensor_parallel_size}") for i in range(num_engines): - logger.info(f"[DEBUG] Creating vLLM engine {i + 1}/{num_engines}") bundle_indices = None bundle_indices = bundle_indices_list[i * tensor_parallel_size : (i + 1) * tensor_parallel_size] @@ -1045,12 +1026,9 @@ def create_vllm_engines( eval_dataset=eval_dataset, ) ) - logger.info(f"[DEBUG] vLLM engine {i + 1}/{num_engines} actor created") - logger.info(f"[DEBUG] All {num_engines} vLLM engine actors created, waiting for ready() (timeout=1200s)...") ray_get_with_progress( [engine.ready.remote() for engine in vllm_engines], "Initializing vLLM engines", timeout=1200 ) - logger.info(f"[DEBUG] All {num_engines} vLLM engines ready!") return vllm_engines From cb49956a4b428bd50eaae0f2f7845f7962a8d1ad Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 10:23:27 -0700 Subject: [PATCH 88/96] Cleaned up code. 
--- open_instruct/data_loader.py | 31 ++----------------------------- open_instruct/test_grpo_fast.py | 3 +++ 2 files changed, 5 insertions(+), 29 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 938ea286c..7b0b0398e 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -383,14 +383,9 @@ def add_prompt_to_generator( generation_config, is_eval: bool, ) -> None: - query = example[INPUT_IDS_PROMPT_KEY] - logger.info( - f"[add_prompt_to_generator] Adding prompt: dataset_index={example['dataset_index']}, epoch={epoch_number}, step={training_step}" - ) - param_prompt_Q.put( PromptRequest( - prompt=query, + prompt=example[INPUT_IDS_PROMPT_KEY], generation_config=generation_config, epoch_number=epoch_number, training_step=training_step, @@ -419,8 +414,6 @@ def accumulate_inference_batches( verbose: bool = False, max_possible_score: float = 1.0, ) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - import ray - if no_resampling_pass_rate is not None: assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" @@ -450,13 +443,8 @@ def accumulate_inference_batches( disable=not verbose, ) num_prompts_sampled = 0 - logger.info(f"[accumulate_inference_batches] Starting to accumulate {num_prompts} prompts") while num_prompts_sampled < num_prompts: - logger.info(f"[accumulate_inference_batches] Waiting for result {num_prompts_sampled + 1}/{num_prompts}") result = inference_results_Q.get(timeout=timeout) - logger.info( - f"[accumulate_inference_batches] Got result: {result.dataset_index if hasattr(result, 'dataset_index') else type(result)}" - ) if isinstance(result, ShutdownSentinel): return result, None, None, None @@ -581,12 +569,10 @@ def accumulate_inference_batches( total_response_tokens += result.token_statistics.num_response_tokens max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - total_generation_time = max_generation_time - accumulated_stats = TokenStatistics( num_prompt_tokens=total_prompt_tokens, num_response_tokens=total_response_tokens, - generation_time=total_generation_time, + generation_time=max_generation_time, earliest_start_time=earliest_start_time, ) @@ -782,9 +768,6 @@ def _data_preparation_loop(self): raise def _data_preparation_loop_inner(self): - logger.info( - f"[DataPreparationActor] Starting data preparation loop, async_steps={self.config.async_steps}, global_batch_size={self.global_batch_size}" - ) for _ in range(self.config.async_steps * self.global_batch_size): add_prompt_to_generator( next(self.iter_dataloader), @@ -794,16 +777,11 @@ def _data_preparation_loop_inner(self): self.generation_config, is_eval=False, ) - logger.info( - f"[DataPreparationActor] Initial prompts submitted, entering main loop for {self.num_training_steps} steps" - ) for step in range(self.training_step, self.num_training_steps): if self.shutdown_requested: - logger.info("[DataPreparationActor] Shutdown requested, exiting") return - logger.info(f"[DataPreparationActor] Step {step}: calling accumulate_inference_batches") result, batch, reward_metrics, batch_stats = accumulate_inference_batches( self.inference_results_Q, self.generation_config, @@ -822,16 +800,11 @@ def _data_preparation_loop_inner(self): verbose=self.verbose, max_possible_score=self.config.max_possible_score, ) - logger.info( - f"[DataPreparationActor] Step {step}: accumulate_inference_batches returned, result is None: {result is None}" - ) if isinstance(result, ShutdownSentinel): - 
logger.info("[DataPreparationActor] Received shutdown sentinel, exiting") return if result is None: - logger.info("[DataPreparationActor] All prompts filtered, yielding empty batch") empty_data = [ { "collated_query_responses": [], diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py index 4d7c232c8..29b40fbf2 100644 --- a/open_instruct/test_grpo_fast.py +++ b/open_instruct/test_grpo_fast.py @@ -1,5 +1,8 @@ import gc import os + +os.environ["HF_HUB_OFFLINE"] = "1" + import random import threading import time From 1fcef959aa807c0dcaee93ac871bcb2756db5aaf Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 12:17:28 -0700 Subject: [PATCH 89/96] Added claude /command to fix scripts --- .claude/commands/test-and-fix.md | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .claude/commands/test-and-fix.md diff --git a/.claude/commands/test-and-fix.md b/.claude/commands/test-and-fix.md new file mode 100644 index 000000000..97761e2c1 --- /dev/null +++ b/.claude/commands/test-and-fix.md @@ -0,0 +1,7 @@ +Run, in order: + + 1. @scripts/train/debug/single_gpu_on_beaker.sh + 2. @scripts/train/debug/tool_grpo_fast.sh + 3. @scripts/train/debug/large_test_script.sh + +Wait for the previous one to finish before starting the next one. Fix any errors that come up. From 9d581743105e37c94a851bd97bfd4cab04ea2bd3 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 12:17:50 -0700 Subject: [PATCH 90/96] Fix import: queue_types -> data_types in data_loader.py --- open_instruct/data_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 7b0b0398e..09fd9ef69 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -39,7 +39,7 @@ VERIFIER_SOURCE_KEY, ) from open_instruct.model_utils import Batch -from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics +from open_instruct.data_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics from open_instruct.rl_utils import PackedSequences, pack_sequences from open_instruct.utils import combine_reward_metrics, repeat_each From 2bef918c5de862700586b821dc8274d2f5be9648 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 12:49:55 -0700 Subject: [PATCH 91/96] Fix data_types.py to match main: prompt_id as field, not property --- open_instruct/data_types.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/open_instruct/data_types.py b/open_instruct/data_types.py index c87778bec..fe0d7c92a 100644 --- a/open_instruct/data_types.py +++ b/open_instruct/data_types.py @@ -39,9 +39,8 @@ class GenerationResult: finish_reasons: list[str] masks: list[list[int]] request_info: RequestInfo - dataset_index: int | None = None - prompt_id: str | None = None - epoch_number: int = 0 + dataset_index: int | None + prompt_id: str | None token_statistics: TokenStatistics | None = None start_time: float | None = None logprobs: list[list[float]] | None = None @@ -61,14 +60,9 @@ class PromptRequest: prompt: list[int] generation_config: Any dataset_index: int - epoch_number: int = 0 - training_step: int = 0 + prompt_id: str is_eval: bool = False - @property - def prompt_id(self) -> str: - return f"{self.epoch_number}_{self.dataset_index}" - @dataclass class CollatedBatchData: From c5b9c1e252b1b64de190fed7a16cbf08fbc3522f Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 
12:54:41 -0700 Subject: [PATCH 92/96] Align vllm_utils, data_loader, benchmark_generators with main (remove epoch_number) --- open_instruct/data_loader.py | 844 ----------------------------------- open_instruct/vllm_utils.py | 3 +- 2 files changed, 1 insertion(+), 846 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 09fd9ef69..be2e901ea 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -1,49 +1,8 @@ -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import threading -import time from collections.abc import Iterable -from concurrent.futures import ThreadPoolExecutor -from dataclasses import asdict, dataclass -from pathlib import Path from typing import Any -import numpy as np -import ray -import torch -import vllm from datasets import Dataset from olmo_core.data import data_loader -from ray.util import queue as ray_queue -from tqdm import tqdm -from transformers import PreTrainedTokenizer - -from open_instruct import utils -from open_instruct.dataset_transformation import ( - GROUND_TRUTHS_KEY, - INPUT_IDS_PROMPT_KEY, - RAW_PROMPT_KEY, - VERIFIER_SOURCE_KEY, -) -from open_instruct.model_utils import Batch -from open_instruct.data_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics -from open_instruct.rl_utils import PackedSequences, pack_sequences -from open_instruct.utils import combine_reward_metrics, repeat_each - -logger = logging.getLogger(__name__) class HFDataLoader(data_loader.DataLoaderBase): @@ -167,806 +126,3 @@ def get_mock_batch(self) -> dict[str, Any]: The first item from the dataset. """ return self.dataset[0] - - -@dataclass -class StreamingDataLoaderConfig: - max_prompt_token_length: int = 256 - response_length: int = 256 - async_steps: int = 1 - num_samples_per_prompt_rollout: int = 4 - active_sampling: bool = False - filter_zero_std_samples: bool = True - no_resampling_pass_rate: float | None = None - advantage_normalization_type: str = "standard" - mask_truncated_completions: bool = False - pack_length: int = 512 - - def __post_init__(self): - assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( - "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" - ) - assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" - if self.num_samples_per_prompt_rollout == 1: - logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") - - if self.active_sampling: - assert self.async_steps > 1, ( - "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " - "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " - "prompt will cause the trainer to stall waiting for more data . 
" - ) - assert self.filter_zero_std_samples, ( - "filter_zero_std_samples must be True when active_sampling is True. " - "Active sampling requires filtering to work correctly." - ) - if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: - raise ValueError( - "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " - "as the reward standard deviation will always be 0, causing all samples to be filtered." - ) - if self.async_steps < 1: - raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") - - def build_dataloader( - self, - data_prep_actor_name: str, - tokenizer: PreTrainedTokenizer, - dp_rank: int, - fs_local_rank: int, - num_training_steps: int, - work_dir: Path | str, - global_batch_size: int, - dp_world_size: int, - ) -> "StreamingDataLoader": - """Build a thin wrapper dataloader that pulls from the DataPreparationActor singleton.""" - return StreamingDataLoader( - data_prep_actor_name=data_prep_actor_name, - tokenizer=tokenizer, - work_dir=work_dir, - global_batch_size=global_batch_size, - num_training_steps=num_training_steps, - dp_world_size=dp_world_size, - dp_rank=dp_rank, - fs_local_rank=fs_local_rank, - ) - - -class StreamingDataLoader(data_loader.DataLoaderBase): - """Thin wrapper dataloader that pulls pre-prepared data from the DataPreparationActor singleton.""" - - def __init__( - self, - *, - data_prep_actor_name: str, - tokenizer: PreTrainedTokenizer, - work_dir: Path | str, - global_batch_size: int, - num_training_steps: int = 0, - dp_world_size: int = 1, - dp_rank: int = 0, - fs_local_rank: int = 0, - ): - super().__init__( - work_dir=work_dir, - global_batch_size=global_batch_size, - dp_world_size=dp_world_size, - dp_rank=dp_rank, - fs_local_rank=fs_local_rank, - ) - - self.data_prep_actor = ray.get_actor(data_prep_actor_name) - self.tokenizer = tokenizer - self.num_training_steps = num_training_steps - self.training_step = 0 - self.current_epoch = 0 - - @property - def total_batches(self) -> int | None: - return self.num_training_steps - - def state_dict(self) -> dict[str, Any]: - return {"training_step": self.training_step, "current_epoch": self.current_epoch} - - def load_state_dict(self, state_dict: dict[str, Any]): - self.training_step = state_dict["training_step"] - self.current_epoch = state_dict.get("current_epoch", 0) - - def reshuffle(self, epoch: int | None = None, **kwargs): - if epoch is not None: - self.current_epoch = epoch - - def get_mock_batch(self) -> dict[str, Any]: - dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) - dummy_attention = torch.tensor([1, 1], dtype=torch.long) - dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) - dummy_response_mask = torch.zeros_like(dummy_qr) - dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) - - return { - "collated_query_responses": [dummy_qr], - "collated_attention_masks": [dummy_attention], - "collated_position_ids": [dummy_position_ids], - "collated_advantages": [dummy_advantage], - "collated_response_masks": [dummy_response_mask], - "collated_vllm_logprobs": [torch.zeros_like(dummy_qr, dtype=torch.float)], - } - - def _iter_batches(self) -> Iterable[dict[str, Any]]: - for step in range(self.training_step, self.num_training_steps): - batch_data = ray.get(self.data_prep_actor.get_data.remote(rank=self.dp_rank, step=step)) - self.training_step = step + 1 - yield batch_data - - def shutdown(self): - pass - - -def 
collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: - padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) - if pin_memory and torch.cuda.is_available(): - padded_tensor = padded_tensor.pin_memory() - return padded_tensor - - -@dataclass -class BatchStatistics: - prompt_lengths: list[int] - response_lengths: list[int] - filtered_prompts: int - filtered_prompts_zero: int - filtered_prompts_solved: int - filtered_prompts_nonzero: int - percent_solved_mean: float - percent_solved_hist: np.ndarray - no_resampled_prompts: int - total_prompts: int - - -class PendingQueriesMap: - def __init__(self): - self._map = {} - self._lock = threading.Lock() - - def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): - with self._lock: - if dataset_idx in self._map: - existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ - dataset_idx - ] - self._map[dataset_idx] = ( - existing_query, - existing_ground_truth, - existing_dataset, - existing_raw_query, - count + 1, - ) - else: - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) - - def pop(self, dataset_idx): - with self._lock: - if dataset_idx not in self._map: - raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") - - query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] - - if count > 1: - self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) - else: - del self._map[dataset_idx] - - return query, ground_truth, dataset, raw_query - - def __len__(self): - with self._lock: - return len(self._map) - - def __contains__(self, dataset_idx): - with self._lock: - return dataset_idx in self._map - - def __getitem__(self, dataset_idx): - with self._lock: - return self._map[dataset_idx] - - def keys(self): - with self._lock: - return list(self._map.keys()) - - -def add_prompt_to_generator( - example: dict[str, Any], - epoch_number: int, - training_step: int, - param_prompt_Q: ray_queue.Queue, - generation_config, - is_eval: bool, -) -> None: - param_prompt_Q.put( - PromptRequest( - prompt=example[INPUT_IDS_PROMPT_KEY], - generation_config=generation_config, - epoch_number=epoch_number, - training_step=training_step, - dataset_index=example["dataset_index"], - is_eval=is_eval, - ) - ) - - -def accumulate_inference_batches( - inference_results_Q: ray_queue.Queue, - generation_config: vllm.SamplingParams, - num_prompts: int, - model_dims: utils.ModelDims, - tokenizer: PreTrainedTokenizer, - dataset: Dataset, - actor_manager=None, - timeout: float | None = None, - active_sampling: bool = False, - filter_zero_std_samples: bool = False, - replenish_prompts: bool = False, - no_resampling_pass_rate: float | None = None, - iter_dataloader: HFDataLoader | None = None, - param_prompt_Q: ray_queue.Queue | None = None, - training_step: int = None, - verbose: bool = False, - max_possible_score: float = 1.0, -) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: - if no_resampling_pass_rate is not None: - assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" - - if replenish_prompts: - assert param_prompt_Q is not None and iter_dataloader is not None and dataset is not None, ( - "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset" - ) - - results = [] - all_queries = [] - all_ground_truths = [] - all_datasets = [] - all_raw_queries = [] - all_decoded_responses = [] 
- all_reward_metrics = [] - all_scores = [] - all_percent_solved = [] - total_filtered_prompts = 0 - filtered_prompt_zero = 0 - filtered_prompt_solved = 0 - filtered_prompt_nonzero = 0 - total_no_resampled = 0 - progress_bar = tqdm( - total=num_prompts, - desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", - bar_format="{l_bar}{bar}{r_bar}\n", - disable=not verbose, - ) - num_prompts_sampled = 0 - while num_prompts_sampled < num_prompts: - result = inference_results_Q.get(timeout=timeout) - - if isinstance(result, ShutdownSentinel): - return result, None, None, None - - assert len(result.responses) == generation_config.n, ( - f"Mismatch: individual prompt result has {len(result.responses)} responses " - f"but expected {generation_config.n} samples per prompt. " - f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" - ) - - example = dataset[result.dataset_index] - query = example[INPUT_IDS_PROMPT_KEY] - ground_truth = example[GROUND_TRUTHS_KEY] - dataset_name = example[VERIFIER_SOURCE_KEY] - raw_query = example[RAW_PROMPT_KEY] - - if replenish_prompts: - example = next(iter_dataloader) - add_prompt_to_generator( - example, iter_dataloader._epoch, training_step, param_prompt_Q, generation_config, is_eval=False - ) - - for i in range(len(result.finish_reasons)): - if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: - result.responses[i].append(tokenizer.eos_token_id) - result.masks[i].append(1) - result.logprobs[i].append(float("nan")) - - decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) - - k_queries = repeat_each([query], generation_config.n) - k_ground_truths = repeat_each([ground_truth], generation_config.n) - k_datasets = repeat_each([dataset_name], generation_config.n) - k_raw_queries = repeat_each([raw_query], generation_config.n) - - percent_solved = np.mean(result.reward_scores).item() / max_possible_score - if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: - iter_dataloader.exclude_index(result.dataset_index) - total_no_resampled += 1 - logging.debug( - f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" - ) - - if filter_zero_std_samples and np.std(result.reward_scores) == 0: - if not active_sampling: - num_prompts_sampled += 1 - progress_bar.update(1) - - total_filtered_prompts += 1 - if result.reward_scores[0] == 0: - filtered_prompt_zero += 1 - elif result.reward_scores[0] == max_possible_score: - filtered_prompt_solved += 1 - else: - filtered_prompt_nonzero += 1 - logging.debug( - f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" - ) - continue - else: - num_prompts_sampled += 1 - progress_bar.update(1) - - results.append(result) - all_queries.extend(k_queries) - all_ground_truths.extend(k_ground_truths) - all_datasets.extend(k_datasets) - all_raw_queries.extend(k_raw_queries) - all_decoded_responses.extend(decoded_responses) - all_scores.extend(result.reward_scores) - all_reward_metrics.append(result.reward_metrics) - all_percent_solved.append(percent_solved) - - if len(results) == 0: - logging.warning( - "[Data Preparation Thread] All prompts were filtered during accumulation. 
" - f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " - f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" - ) - return None, None, None, None - - combined_responses = [] - combined_finish_reasons = [] - combined_masks = [] - combined_num_calls = [] - combined_timeouts = [] - combined_tool_errors = [] - combined_tool_outputs = [] - combined_tool_runtimes = [] - combined_tool_calleds = [] - combined_logprobs = [] - - earliest_start_time = float("inf") - prompt_lengths = [] - response_lengths = [] - - total_prompt_tokens = 0 - total_response_tokens = 0 - max_generation_time = 0 - - for i, result in enumerate(results): - combined_responses.extend(result.responses) - combined_finish_reasons.extend(result.finish_reasons) - combined_masks.extend(result.masks) - combined_num_calls.extend(result.request_info.num_calls) - combined_timeouts.extend(result.request_info.timeouts) - combined_tool_errors.extend(result.request_info.tool_errors) - combined_tool_outputs.extend(result.request_info.tool_outputs) - combined_tool_runtimes.extend(result.request_info.tool_runtimes) - combined_tool_calleds.extend(result.request_info.tool_calleds) - - combined_logprobs.extend(result.logprobs) - - earliest_start_time = min(earliest_start_time, result.start_time) - - prompt_lengths.append(len(all_queries[i * generation_config.n])) - - for response in result.responses: - response_lengths.append(len(response)) - - total_prompt_tokens += result.token_statistics.num_prompt_tokens - total_response_tokens += result.token_statistics.num_response_tokens - max_generation_time = max(max_generation_time, result.token_statistics.generation_time) - - accumulated_stats = TokenStatistics( - num_prompt_tokens=total_prompt_tokens, - num_response_tokens=total_response_tokens, - generation_time=max_generation_time, - earliest_start_time=earliest_start_time, - ) - - combined_request_info = RequestInfo( - num_calls=combined_num_calls, - timeouts=combined_timeouts, - tool_errors=combined_tool_errors, - tool_outputs=combined_tool_outputs, - tool_runtimes=combined_tool_runtimes, - tool_calleds=combined_tool_calleds, - ) - - combined_result = GenerationResult( - responses=combined_responses, - finish_reasons=combined_finish_reasons, - masks=combined_masks, - request_info=combined_request_info, - dataset_index=None, - epoch_number=results[0].epoch_number, - token_statistics=accumulated_stats, - logprobs=combined_logprobs, - ) - - if actor_manager is not None: - ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) - - batch = Batch( - queries=all_queries, - ground_truths=all_ground_truths, - datasets=all_datasets, - raw_queries=all_raw_queries, - decoded_responses=all_decoded_responses, - indices=None, - scores=all_scores, - ) - - combined_reward_metrics = combine_reward_metrics(all_reward_metrics) - percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 - - batch_stats = BatchStatistics( - prompt_lengths=prompt_lengths, - response_lengths=response_lengths, - filtered_prompts=total_filtered_prompts, - filtered_prompts_zero=filtered_prompt_zero, - filtered_prompts_solved=filtered_prompt_solved, - filtered_prompts_nonzero=filtered_prompt_nonzero, - percent_solved_mean=percent_solved_mean, - percent_solved_hist=np.array(all_percent_solved), - no_resampled_prompts=total_no_resampled, - total_prompts=len(results), - ) - return combined_result, batch, combined_reward_metrics, batch_stats - - -def prepare_collated_data_for_workers( - packed_sequences: 
PackedSequences, - world_size: int, - per_device_train_batch_size: int, - pad_token_id: int, - pin_memory: bool = True, -) -> list[dict[str, list[torch.Tensor]]]: - """Distributes and collates packed sequences for distributed training. - - Splits packed sequences across workers, randomly shuffles each worker's data, - and collates into micro-batches for training. - - Args: - packed_sequences: Packed training sequences containing query responses, - attention masks, position IDs, advantages, response masks, - and vllm logprobs. - world_size: Number of distributed workers. - per_device_train_batch_size: Batch size for each device's micro-batch. - pad_token_id: Token ID used for padding sequences. - pin_memory: Whether to pin memory for faster data transfer to GPU. - - Returns: - List of dictionaries, one per worker, each containing collated tensors - for query_responses, attention_masks, position_ids, - advantages, response_masks, and vllm_logprobs. - """ - B = len(packed_sequences.query_responses) // world_size - collated_data = [] - for i in range(world_size): - per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)] - per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)] - per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)] - per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)] - per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)] - per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs[B * i : B * (i + 1)] - - b_inds = np.random.permutation(len(per_device_packed_query_responses)) - collated_query_responses = [] - collated_attention_masks = [] - collated_position_ids = [] - collated_response_masks = [] - collated_advantages = [] - collated_vllm_logprobs = [] - for j in range(0, len(per_device_packed_query_responses), per_device_train_batch_size): - micro_range = b_inds[j : j + per_device_train_batch_size] - collated_query_responses.append( - collate_fn([per_device_packed_query_responses[idx] for idx in micro_range], pad_token_id, pin_memory) - ) - collated_attention_masks.append( - collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, pin_memory) - ) - collated_position_ids.append( - collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, pin_memory) - ) - collated_response_masks.append( - collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, pin_memory) - ) - collated_advantages.append( - collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, pin_memory) - ) - collated_vllm_logprobs.append( - collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, pin_memory) - ) - collated_data.append( - { - "collated_query_responses": collated_query_responses, - "collated_attention_masks": collated_attention_masks, - "collated_position_ids": collated_position_ids, - "collated_advantages": collated_advantages, - "collated_response_masks": collated_response_masks, - "collated_vllm_logprobs": collated_vllm_logprobs, - } - ) - return collated_data - - -@ray.remote -class DataPreparationActor: - """Ray actor singleton that handles centralized data preparation for all ranks. - - This actor runs a background thread that continuously prepares training data, - ensuring all ranks receive the same number of micro-batches (preventing deadlock - from uneven filtering). 
- """ - - def __init__( - self, - dataset: Dataset, - inference_results_Q: ray_queue.Queue, - param_prompt_Q: ray_queue.Queue, - tokenizer: PreTrainedTokenizer, - config: StreamingDataLoaderConfig, - generation_config, - num_training_steps: int, - seed: int, - per_device_train_batch_size: int, - global_batch_size: int, - dp_world_size: int, - max_possible_score: float, - actor_manager, - model_dims: utils.ModelDims, - verbose: bool, - work_dir: str, - ): - self.inference_results_Q = inference_results_Q - self.param_prompt_Q = param_prompt_Q - self.tokenizer = tokenizer - self.config = config - self.config.max_possible_score = max_possible_score - self.generation_config = generation_config - self.num_training_steps = num_training_steps - self.per_device_train_batch_size = per_device_train_batch_size - self.global_batch_size = global_batch_size - self.dp_world_size = dp_world_size - self.actor_manager = actor_manager - self.model_dims = model_dims - self.verbose = verbose - self.dataset = dataset - - self.iter_dataloader = HFDataLoader( - dataset=dataset, batch_size=1, seed=seed, rank=0, world_size=1, work_dir=work_dir, automatic_reshuffle=True - ) - - self.prepared_data: dict[int, list[dict]] = {} - self.metrics: dict[int, dict] = {} - self.current_prepared_step = -1 - self.lock = threading.Lock() - self.shutdown_requested = False - self.training_step = 0 - - self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="DataPrepActor") - self._prep_future = self._executor.submit(self._data_preparation_loop) - - def _data_preparation_loop(self): - try: - self._data_preparation_loop_inner() - except Exception: - logger.exception("[DataPreparationActor] Exception in data preparation loop") - raise - - def _data_preparation_loop_inner(self): - for _ in range(self.config.async_steps * self.global_batch_size): - add_prompt_to_generator( - next(self.iter_dataloader), - self.iter_dataloader._epoch, - self.training_step, - self.param_prompt_Q, - self.generation_config, - is_eval=False, - ) - - for step in range(self.training_step, self.num_training_steps): - if self.shutdown_requested: - return - - result, batch, reward_metrics, batch_stats = accumulate_inference_batches( - self.inference_results_Q, - self.generation_config, - num_prompts=self.global_batch_size, - model_dims=self.model_dims, - tokenizer=self.tokenizer, - dataset=self.dataset, - actor_manager=self.actor_manager, - active_sampling=self.config.active_sampling, - filter_zero_std_samples=self.config.filter_zero_std_samples, - replenish_prompts=True, - no_resampling_pass_rate=self.config.no_resampling_pass_rate, - iter_dataloader=self.iter_dataloader, - param_prompt_Q=self.param_prompt_Q, - training_step=step, - verbose=self.verbose, - max_possible_score=self.config.max_possible_score, - ) - - if isinstance(result, ShutdownSentinel): - return - - if result is None: - empty_data = [ - { - "collated_query_responses": [], - "collated_attention_masks": [], - "collated_position_ids": [], - "collated_advantages": [], - "collated_response_masks": [], - "collated_vllm_logprobs": [], - } - for _ in range(self.dp_world_size) - ] - with self.lock: - self.prepared_data[step] = empty_data - self.metrics[step] = {} - self.current_prepared_step = step - continue - - scores = np.array(batch.scores) - scores_per_prompt = scores.reshape(-1, self.config.num_samples_per_prompt_rollout) - mean_grouped_rewards = scores_per_prompt.mean(axis=-1) - mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) 
- std_grouped_rewards = scores_per_prompt.std(axis=-1) - std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) - - if self.config.advantage_normalization_type == "standard": - advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) - elif self.config.advantage_normalization_type == "centered": - advantages = scores - mean_grouped_rewards - else: - raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") - - if self.config.mask_truncated_completions: - stop_idxes = torch.tensor( - [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] - ) - num_truncated = len(result.finish_reasons) - len(stop_idxes) - if num_truncated > 0: - logger.info( - f"[DataPreparationActor] Filtered {num_truncated} responses that didn't finish with 'stop'. " - f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" - ) - scores = scores[stop_idxes] - advantages = advantages[stop_idxes] - batch = batch[stop_idxes.tolist()] - result.responses = [result.responses[i] for i in stop_idxes] - result.masks = [result.masks[i] for i in stop_idxes] - result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] - result.logprobs = [result.logprobs[i] for i in stop_idxes] - - packed_sequences = pack_sequences( - queries=batch.queries, - responses=result.responses, - masks=result.masks, - pack_length=self.config.pack_length, - pad_token_id=self.tokenizer.pad_token_id, - vllm_logprobs=result.logprobs, - ) - lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) - lookup_advantages[1:] = advantages - packed_advantages = [ - torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) - for packed_mask in packed_sequences.response_masks - ] - packed_sequences.advantages = packed_advantages - - collated_data = prepare_collated_data_for_workers( - packed_sequences, self.dp_world_size, self.per_device_train_batch_size, self.tokenizer.pad_token_id - ) - - if len(result.responses) == 0: - step_metrics = {} - else: - real_num_responses = len(result.responses) - expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size - unsolved_num_responses = (scores < self.config.max_possible_score).sum() - sequence_lengths = np.array([len(response) for response in result.responses]) - sequence_length_solved = ( - np.array([]) - if np.all(scores == 0) - else np.array(sequence_lengths[scores == self.config.max_possible_score]) - ) - sequence_length_unsolved = ( - np.array([]) - if np.all(scores == self.config.max_possible_score) - else np.array(sequence_lengths[scores == 0]) - ) - stop_rate = sum(int(fr == "stop") for fr in result.finish_reasons) / len(result.finish_reasons) - - batch_metrics_dict = asdict(batch_stats) - batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics_dict.items()} - - step_metrics = { - "scores": scores.mean(), - "real_batch_size_ratio": real_num_responses / expected_num_responses, - "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, - "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, - "val/solve_rate_hist": batch_stats.percent_solved_hist, - "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, - "val/sequence_lengths": sequence_lengths.mean(), - "val/sequence_lengths_min": sequence_lengths.min(), - "val/sequence_lengths_max": sequence_lengths.max(), - "val/sequence_lengths_unsolved": ( - 0 if 
len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() - ), - "val/sequence_lengths_solved": ( - 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() - ), - "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, - "val/sequence_lengths_solved_hist": sequence_length_solved, - "val/stop_rate": stop_rate, - "val/advantages_mean": advantages.mean(), - "val/advantages_min": advantages.min(), - "val/advantages_max": advantages.max(), - "val/advantages_hist": advantages, - "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), - "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), - "val/tool_errors_rate": np.array( - [len(item) > 0 for item in result.request_info.tool_errors] - ).mean(), - "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), - "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), - **reward_metrics, - **batch_metrics_prefixed, - } - - total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens - step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time - - with self.lock: - self.prepared_data[step] = collated_data - self.metrics[step] = step_metrics - self.current_prepared_step = step - - def get_data(self, rank: int, step: int) -> dict: - """Called by each rank's StreamingDataLoader. Blocks until data ready.""" - while True: - with self.lock: - if step <= self.current_prepared_step: - data = self.prepared_data[step][rank].copy() - data["metrics"] = self.metrics[step] - self._cleanup_old_steps(step) - return data - time.sleep(0.01) - - def _cleanup_old_steps(self, current_step: int): - """Remove old step data to prevent memory leak.""" - steps_to_remove = [s for s in self.prepared_data if s < current_step - 1] - for s in steps_to_remove: - del self.prepared_data[s] - if s in self.metrics: - del self.metrics[s] - - def shutdown(self): - self.shutdown_requested = True - self._executor.shutdown(wait=True) - - def get_state(self) -> dict: - return { - "training_step": self.current_prepared_step + 1, - "iter_dataloader_state": self.iter_dataloader.state_dict(), - } - - def set_state(self, state: dict): - self.training_step = state["training_step"] - self.iter_dataloader.load_state_dict(state["iter_dataloader_state"]) diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index a0e371d07..cb1ca4441 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -296,7 +296,6 @@ def process_completed_request(request_id, outs, current_time, tools, request_met ), dataset_index=metadata["dataset_index"], prompt_id=metadata["prompt_id"], - epoch_number=metadata.get("epoch_number", 0), token_statistics=TokenStatistics( num_prompt_tokens=len(metadata["prompt_token_ids"]), num_response_tokens=total_generation_tokens, @@ -412,7 +411,6 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: "is_eval": request.is_eval, "dataset_index": request.dataset_index, "prompt_id": request.prompt_id, - "epoch_number": request.epoch_number, "sampling_params": sampling_params, "original_sampling_params": request.generation_config, "prompt_token_ids": list(request.prompt), @@ -967,6 +965,7 @@ def create_vllm_engines( logger.info(f"num_gpus: {num_gpus}") if not use_hybrid_engine: + # Create a big placement group to ensure that all engines are packed bundles = [{"GPU": 1, "CPU": 1} for _ in range(num_engines * tensor_parallel_size)] pg = 
placement_group(bundles, strategy="PACK") ray.get(pg.ready()) From 028bb5e73b24753b185b4b9f783210d30e46558c Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 14:32:00 -0700 Subject: [PATCH 93/96] Restore StreamingDataLoader and reconcile epoch_number/prompt_id APIs - Restored data_loader.py with StreamingDataLoader and StreamingDataLoaderConfig - Updated PromptRequest to support both prompt_id (passed directly) and epoch_number (computed prompt_id) via get_prompt_id() method - Added epoch_number back to GenerationResult - Updated vllm_utils.py to use get_prompt_id() instead of prompt_id directly --- open_instruct/data_loader.py | 844 +++++++++++++++++++++++++++++++++++ open_instruct/data_types.py | 16 +- open_instruct/vllm_utils.py | 4 +- 3 files changed, 859 insertions(+), 5 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index be2e901ea..09fd9ef69 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -1,8 +1,49 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time from collections.abc import Iterable +from concurrent.futures import ThreadPoolExecutor +from dataclasses import asdict, dataclass +from pathlib import Path from typing import Any +import numpy as np +import ray +import torch +import vllm from datasets import Dataset from olmo_core.data import data_loader +from ray.util import queue as ray_queue +from tqdm import tqdm +from transformers import PreTrainedTokenizer + +from open_instruct import utils +from open_instruct.dataset_transformation import ( + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + RAW_PROMPT_KEY, + VERIFIER_SOURCE_KEY, +) +from open_instruct.model_utils import Batch +from open_instruct.data_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics +from open_instruct.rl_utils import PackedSequences, pack_sequences +from open_instruct.utils import combine_reward_metrics, repeat_each + +logger = logging.getLogger(__name__) class HFDataLoader(data_loader.DataLoaderBase): @@ -126,3 +167,806 @@ def get_mock_batch(self) -> dict[str, Any]: The first item from the dataset. """ return self.dataset[0] + + +@dataclass +class StreamingDataLoaderConfig: + max_prompt_token_length: int = 256 + response_length: int = 256 + async_steps: int = 1 + num_samples_per_prompt_rollout: int = 4 + active_sampling: bool = False + filter_zero_std_samples: bool = True + no_resampling_pass_rate: float | None = None + advantage_normalization_type: str = "standard" + mask_truncated_completions: bool = False + pack_length: int = 512 + + def __post_init__(self): + assert self.pack_length >= self.max_prompt_token_length + self.response_length, ( + "The `pack_length` needs to be greater than the sum of `max_prompt_token_length` and `response_length`!" + ) + assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" 
+ if self.num_samples_per_prompt_rollout == 1: + logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") + + if self.active_sampling: + assert self.async_steps > 1, ( + "With active_sampling, you should set async_steps > 1 to account for filtering of the first batch. " + "Otherwise, your generator only generates only one batch worth of prompts and a single filtered " + "prompt will cause the trainer to stall waiting for more data . " + ) + assert self.filter_zero_std_samples, ( + "filter_zero_std_samples must be True when active_sampling is True. " + "Active sampling requires filtering to work correctly." + ) + if self.num_samples_per_prompt_rollout == 1 and self.filter_zero_std_samples: + raise ValueError( + "`filter_zero_std_samples` cannot be True when `num_samples_per_prompt_rollout` is 1, " + "as the reward standard deviation will always be 0, causing all samples to be filtered." + ) + if self.async_steps < 1: + raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") + + def build_dataloader( + self, + data_prep_actor_name: str, + tokenizer: PreTrainedTokenizer, + dp_rank: int, + fs_local_rank: int, + num_training_steps: int, + work_dir: Path | str, + global_batch_size: int, + dp_world_size: int, + ) -> "StreamingDataLoader": + """Build a thin wrapper dataloader that pulls from the DataPreparationActor singleton.""" + return StreamingDataLoader( + data_prep_actor_name=data_prep_actor_name, + tokenizer=tokenizer, + work_dir=work_dir, + global_batch_size=global_batch_size, + num_training_steps=num_training_steps, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + +class StreamingDataLoader(data_loader.DataLoaderBase): + """Thin wrapper dataloader that pulls pre-prepared data from the DataPreparationActor singleton.""" + + def __init__( + self, + *, + data_prep_actor_name: str, + tokenizer: PreTrainedTokenizer, + work_dir: Path | str, + global_batch_size: int, + num_training_steps: int = 0, + dp_world_size: int = 1, + dp_rank: int = 0, + fs_local_rank: int = 0, + ): + super().__init__( + work_dir=work_dir, + global_batch_size=global_batch_size, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + fs_local_rank=fs_local_rank, + ) + + self.data_prep_actor = ray.get_actor(data_prep_actor_name) + self.tokenizer = tokenizer + self.num_training_steps = num_training_steps + self.training_step = 0 + self.current_epoch = 0 + + @property + def total_batches(self) -> int | None: + return self.num_training_steps + + def state_dict(self) -> dict[str, Any]: + return {"training_step": self.training_step, "current_epoch": self.current_epoch} + + def load_state_dict(self, state_dict: dict[str, Any]): + self.training_step = state_dict["training_step"] + self.current_epoch = state_dict.get("current_epoch", 0) + + def reshuffle(self, epoch: int | None = None, **kwargs): + if epoch is not None: + self.current_epoch = epoch + + def get_mock_batch(self) -> dict[str, Any]: + dummy_qr = torch.tensor([self.tokenizer.pad_token_id, self.tokenizer.eos_token_id], dtype=torch.long) + dummy_attention = torch.tensor([1, 1], dtype=torch.long) + dummy_position_ids = torch.arange(len(dummy_qr), dtype=torch.long) + dummy_response_mask = torch.zeros_like(dummy_qr) + dummy_advantage = torch.zeros_like(dummy_qr, dtype=torch.float) + + return { + "collated_query_responses": [dummy_qr], + "collated_attention_masks": [dummy_attention], + "collated_position_ids": [dummy_position_ids], + "collated_advantages": 
[dummy_advantage], + "collated_response_masks": [dummy_response_mask], + "collated_vllm_logprobs": [torch.zeros_like(dummy_qr, dtype=torch.float)], + } + + def _iter_batches(self) -> Iterable[dict[str, Any]]: + for step in range(self.training_step, self.num_training_steps): + batch_data = ray.get(self.data_prep_actor.get_data.remote(rank=self.dp_rank, step=step)) + self.training_step = step + 1 + yield batch_data + + def shutdown(self): + pass + + +def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: + padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) + if pin_memory and torch.cuda.is_available(): + padded_tensor = padded_tensor.pin_memory() + return padded_tensor + + +@dataclass +class BatchStatistics: + prompt_lengths: list[int] + response_lengths: list[int] + filtered_prompts: int + filtered_prompts_zero: int + filtered_prompts_solved: int + filtered_prompts_nonzero: int + percent_solved_mean: float + percent_solved_hist: np.ndarray + no_resampled_prompts: int + total_prompts: int + + +class PendingQueriesMap: + def __init__(self): + self._map = {} + self._lock = threading.Lock() + + def insert(self, dataset_idx, query, ground_truth, dataset, raw_query): + with self._lock: + if dataset_idx in self._map: + existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[ + dataset_idx + ] + self._map[dataset_idx] = ( + existing_query, + existing_ground_truth, + existing_dataset, + existing_raw_query, + count + 1, + ) + else: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1) + + def pop(self, dataset_idx): + with self._lock: + if dataset_idx not in self._map: + raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map") + + query, ground_truth, dataset, raw_query, count = self._map[dataset_idx] + + if count > 1: + self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1) + else: + del self._map[dataset_idx] + + return query, ground_truth, dataset, raw_query + + def __len__(self): + with self._lock: + return len(self._map) + + def __contains__(self, dataset_idx): + with self._lock: + return dataset_idx in self._map + + def __getitem__(self, dataset_idx): + with self._lock: + return self._map[dataset_idx] + + def keys(self): + with self._lock: + return list(self._map.keys()) + + +def add_prompt_to_generator( + example: dict[str, Any], + epoch_number: int, + training_step: int, + param_prompt_Q: ray_queue.Queue, + generation_config, + is_eval: bool, +) -> None: + param_prompt_Q.put( + PromptRequest( + prompt=example[INPUT_IDS_PROMPT_KEY], + generation_config=generation_config, + epoch_number=epoch_number, + training_step=training_step, + dataset_index=example["dataset_index"], + is_eval=is_eval, + ) + ) + + +def accumulate_inference_batches( + inference_results_Q: ray_queue.Queue, + generation_config: vllm.SamplingParams, + num_prompts: int, + model_dims: utils.ModelDims, + tokenizer: PreTrainedTokenizer, + dataset: Dataset, + actor_manager=None, + timeout: float | None = None, + active_sampling: bool = False, + filter_zero_std_samples: bool = False, + replenish_prompts: bool = False, + no_resampling_pass_rate: float | None = None, + iter_dataloader: HFDataLoader | None = None, + param_prompt_Q: ray_queue.Queue | None = None, + training_step: int = None, + verbose: bool = False, + max_possible_score: float = 1.0, +) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: + if 
no_resampling_pass_rate is not None: + assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" + + if replenish_prompts: + assert param_prompt_Q is not None and iter_dataloader is not None and dataset is not None, ( + "replenish_prompts requires param_prompt_Q and iter_dataloader and dataset" + ) + + results = [] + all_queries = [] + all_ground_truths = [] + all_datasets = [] + all_raw_queries = [] + all_decoded_responses = [] + all_reward_metrics = [] + all_scores = [] + all_percent_solved = [] + total_filtered_prompts = 0 + filtered_prompt_zero = 0 + filtered_prompt_solved = 0 + filtered_prompt_nonzero = 0 + total_no_resampled = 0 + progress_bar = tqdm( + total=num_prompts, + desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", + bar_format="{l_bar}{bar}{r_bar}\n", + disable=not verbose, + ) + num_prompts_sampled = 0 + while num_prompts_sampled < num_prompts: + result = inference_results_Q.get(timeout=timeout) + + if isinstance(result, ShutdownSentinel): + return result, None, None, None + + assert len(result.responses) == generation_config.n, ( + f"Mismatch: individual prompt result has {len(result.responses)} responses " + f"but expected {generation_config.n} samples per prompt. " + f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" + ) + + example = dataset[result.dataset_index] + query = example[INPUT_IDS_PROMPT_KEY] + ground_truth = example[GROUND_TRUTHS_KEY] + dataset_name = example[VERIFIER_SOURCE_KEY] + raw_query = example[RAW_PROMPT_KEY] + + if replenish_prompts: + example = next(iter_dataloader) + add_prompt_to_generator( + example, iter_dataloader._epoch, training_step, param_prompt_Q, generation_config, is_eval=False + ) + + for i in range(len(result.finish_reasons)): + if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: + result.responses[i].append(tokenizer.eos_token_id) + result.masks[i].append(1) + result.logprobs[i].append(float("nan")) + + decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True) + + k_queries = repeat_each([query], generation_config.n) + k_ground_truths = repeat_each([ground_truth], generation_config.n) + k_datasets = repeat_each([dataset_name], generation_config.n) + k_raw_queries = repeat_each([raw_query], generation_config.n) + + percent_solved = np.mean(result.reward_scores).item() / max_possible_score + if no_resampling_pass_rate is not None and percent_solved >= no_resampling_pass_rate: + iter_dataloader.exclude_index(result.dataset_index) + total_no_resampled += 1 + logging.debug( + f"[Data Preparation Thread] Prompt solved at {percent_solved}, will be excluded from resampling, total no resampled: {total_no_resampled}" + ) + + if filter_zero_std_samples and np.std(result.reward_scores) == 0: + if not active_sampling: + num_prompts_sampled += 1 + progress_bar.update(1) + + total_filtered_prompts += 1 + if result.reward_scores[0] == 0: + filtered_prompt_zero += 1 + elif result.reward_scores[0] == max_possible_score: + filtered_prompt_solved += 1 + else: + filtered_prompt_nonzero += 1 + logging.debug( + f"[Data Preparation Thread] Filtered prompt with reward std 0, total filtered {total_filtered_prompts}" + ) + continue + else: + num_prompts_sampled += 1 + progress_bar.update(1) + + results.append(result) + all_queries.extend(k_queries) + all_ground_truths.extend(k_ground_truths) + all_datasets.extend(k_datasets) + all_raw_queries.extend(k_raw_queries) + all_decoded_responses.extend(decoded_responses) + 
all_scores.extend(result.reward_scores) + all_reward_metrics.append(result.reward_metrics) + all_percent_solved.append(percent_solved) + + if len(results) == 0: + logging.warning( + "[Data Preparation Thread] All prompts were filtered during accumulation. " + f"Filtered: {total_filtered_prompts} (zero std: {filtered_prompt_zero}, " + f"solved: {filtered_prompt_solved}, nonzero: {filtered_prompt_nonzero})" + ) + return None, None, None, None + + combined_responses = [] + combined_finish_reasons = [] + combined_masks = [] + combined_num_calls = [] + combined_timeouts = [] + combined_tool_errors = [] + combined_tool_outputs = [] + combined_tool_runtimes = [] + combined_tool_calleds = [] + combined_logprobs = [] + + earliest_start_time = float("inf") + prompt_lengths = [] + response_lengths = [] + + total_prompt_tokens = 0 + total_response_tokens = 0 + max_generation_time = 0 + + for i, result in enumerate(results): + combined_responses.extend(result.responses) + combined_finish_reasons.extend(result.finish_reasons) + combined_masks.extend(result.masks) + combined_num_calls.extend(result.request_info.num_calls) + combined_timeouts.extend(result.request_info.timeouts) + combined_tool_errors.extend(result.request_info.tool_errors) + combined_tool_outputs.extend(result.request_info.tool_outputs) + combined_tool_runtimes.extend(result.request_info.tool_runtimes) + combined_tool_calleds.extend(result.request_info.tool_calleds) + + combined_logprobs.extend(result.logprobs) + + earliest_start_time = min(earliest_start_time, result.start_time) + + prompt_lengths.append(len(all_queries[i * generation_config.n])) + + for response in result.responses: + response_lengths.append(len(response)) + + total_prompt_tokens += result.token_statistics.num_prompt_tokens + total_response_tokens += result.token_statistics.num_response_tokens + max_generation_time = max(max_generation_time, result.token_statistics.generation_time) + + accumulated_stats = TokenStatistics( + num_prompt_tokens=total_prompt_tokens, + num_response_tokens=total_response_tokens, + generation_time=max_generation_time, + earliest_start_time=earliest_start_time, + ) + + combined_request_info = RequestInfo( + num_calls=combined_num_calls, + timeouts=combined_timeouts, + tool_errors=combined_tool_errors, + tool_outputs=combined_tool_outputs, + tool_runtimes=combined_tool_runtimes, + tool_calleds=combined_tool_calleds, + ) + + combined_result = GenerationResult( + responses=combined_responses, + finish_reasons=combined_finish_reasons, + masks=combined_masks, + request_info=combined_request_info, + dataset_index=None, + epoch_number=results[0].epoch_number, + token_statistics=accumulated_stats, + logprobs=combined_logprobs, + ) + + if actor_manager is not None: + ray.get(actor_manager.report_token_statistics.remote(accumulated_stats)) + + batch = Batch( + queries=all_queries, + ground_truths=all_ground_truths, + datasets=all_datasets, + raw_queries=all_raw_queries, + decoded_responses=all_decoded_responses, + indices=None, + scores=all_scores, + ) + + combined_reward_metrics = combine_reward_metrics(all_reward_metrics) + percent_solved_mean = np.mean(all_percent_solved) if all_percent_solved else 0.0 + + batch_stats = BatchStatistics( + prompt_lengths=prompt_lengths, + response_lengths=response_lengths, + filtered_prompts=total_filtered_prompts, + filtered_prompts_zero=filtered_prompt_zero, + filtered_prompts_solved=filtered_prompt_solved, + filtered_prompts_nonzero=filtered_prompt_nonzero, + percent_solved_mean=percent_solved_mean, + 
percent_solved_hist=np.array(all_percent_solved), + no_resampled_prompts=total_no_resampled, + total_prompts=len(results), + ) + return combined_result, batch, combined_reward_metrics, batch_stats + + +def prepare_collated_data_for_workers( + packed_sequences: PackedSequences, + world_size: int, + per_device_train_batch_size: int, + pad_token_id: int, + pin_memory: bool = True, +) -> list[dict[str, list[torch.Tensor]]]: + """Distributes and collates packed sequences for distributed training. + + Splits packed sequences across workers, randomly shuffles each worker's data, + and collates into micro-batches for training. + + Args: + packed_sequences: Packed training sequences containing query responses, + attention masks, position IDs, advantages, response masks, + and vllm logprobs. + world_size: Number of distributed workers. + per_device_train_batch_size: Batch size for each device's micro-batch. + pad_token_id: Token ID used for padding sequences. + pin_memory: Whether to pin memory for faster data transfer to GPU. + + Returns: + List of dictionaries, one per worker, each containing collated tensors + for query_responses, attention_masks, position_ids, + advantages, response_masks, and vllm_logprobs. + """ + B = len(packed_sequences.query_responses) // world_size + collated_data = [] + for i in range(world_size): + per_device_packed_query_responses = packed_sequences.query_responses[B * i : B * (i + 1)] + per_device_packed_attention_masks = packed_sequences.attention_masks[B * i : B * (i + 1)] + per_device_packed_position_ids = packed_sequences.position_ids[B * i : B * (i + 1)] + per_device_packed_advantages = packed_sequences.advantages[B * i : B * (i + 1)] + per_device_packed_response_masks = packed_sequences.response_masks[B * i : B * (i + 1)] + per_device_packed_vllm_logprobs = packed_sequences.vllm_logprobs[B * i : B * (i + 1)] + + b_inds = np.random.permutation(len(per_device_packed_query_responses)) + collated_query_responses = [] + collated_attention_masks = [] + collated_position_ids = [] + collated_response_masks = [] + collated_advantages = [] + collated_vllm_logprobs = [] + for j in range(0, len(per_device_packed_query_responses), per_device_train_batch_size): + micro_range = b_inds[j : j + per_device_train_batch_size] + collated_query_responses.append( + collate_fn([per_device_packed_query_responses[idx] for idx in micro_range], pad_token_id, pin_memory) + ) + collated_attention_masks.append( + collate_fn([per_device_packed_attention_masks[idx] for idx in micro_range], 0, pin_memory) + ) + collated_position_ids.append( + collate_fn([per_device_packed_position_ids[idx] for idx in micro_range], 0, pin_memory) + ) + collated_response_masks.append( + collate_fn([per_device_packed_response_masks[idx] for idx in micro_range], 0, pin_memory) + ) + collated_advantages.append( + collate_fn([per_device_packed_advantages[idx] for idx in micro_range], 0, pin_memory) + ) + collated_vllm_logprobs.append( + collate_fn([per_device_packed_vllm_logprobs[idx] for idx in micro_range], 0, pin_memory) + ) + collated_data.append( + { + "collated_query_responses": collated_query_responses, + "collated_attention_masks": collated_attention_masks, + "collated_position_ids": collated_position_ids, + "collated_advantages": collated_advantages, + "collated_response_masks": collated_response_masks, + "collated_vllm_logprobs": collated_vllm_logprobs, + } + ) + return collated_data + + +@ray.remote +class DataPreparationActor: + """Ray actor singleton that handles centralized data preparation for all ranks. 
+ + This actor runs a background thread that continuously prepares training data, + ensuring all ranks receive the same number of micro-batches (preventing deadlock + from uneven filtering). + """ + + def __init__( + self, + dataset: Dataset, + inference_results_Q: ray_queue.Queue, + param_prompt_Q: ray_queue.Queue, + tokenizer: PreTrainedTokenizer, + config: StreamingDataLoaderConfig, + generation_config, + num_training_steps: int, + seed: int, + per_device_train_batch_size: int, + global_batch_size: int, + dp_world_size: int, + max_possible_score: float, + actor_manager, + model_dims: utils.ModelDims, + verbose: bool, + work_dir: str, + ): + self.inference_results_Q = inference_results_Q + self.param_prompt_Q = param_prompt_Q + self.tokenizer = tokenizer + self.config = config + self.config.max_possible_score = max_possible_score + self.generation_config = generation_config + self.num_training_steps = num_training_steps + self.per_device_train_batch_size = per_device_train_batch_size + self.global_batch_size = global_batch_size + self.dp_world_size = dp_world_size + self.actor_manager = actor_manager + self.model_dims = model_dims + self.verbose = verbose + self.dataset = dataset + + self.iter_dataloader = HFDataLoader( + dataset=dataset, batch_size=1, seed=seed, rank=0, world_size=1, work_dir=work_dir, automatic_reshuffle=True + ) + + self.prepared_data: dict[int, list[dict]] = {} + self.metrics: dict[int, dict] = {} + self.current_prepared_step = -1 + self.lock = threading.Lock() + self.shutdown_requested = False + self.training_step = 0 + + self._executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="DataPrepActor") + self._prep_future = self._executor.submit(self._data_preparation_loop) + + def _data_preparation_loop(self): + try: + self._data_preparation_loop_inner() + except Exception: + logger.exception("[DataPreparationActor] Exception in data preparation loop") + raise + + def _data_preparation_loop_inner(self): + for _ in range(self.config.async_steps * self.global_batch_size): + add_prompt_to_generator( + next(self.iter_dataloader), + self.iter_dataloader._epoch, + self.training_step, + self.param_prompt_Q, + self.generation_config, + is_eval=False, + ) + + for step in range(self.training_step, self.num_training_steps): + if self.shutdown_requested: + return + + result, batch, reward_metrics, batch_stats = accumulate_inference_batches( + self.inference_results_Q, + self.generation_config, + num_prompts=self.global_batch_size, + model_dims=self.model_dims, + tokenizer=self.tokenizer, + dataset=self.dataset, + actor_manager=self.actor_manager, + active_sampling=self.config.active_sampling, + filter_zero_std_samples=self.config.filter_zero_std_samples, + replenish_prompts=True, + no_resampling_pass_rate=self.config.no_resampling_pass_rate, + iter_dataloader=self.iter_dataloader, + param_prompt_Q=self.param_prompt_Q, + training_step=step, + verbose=self.verbose, + max_possible_score=self.config.max_possible_score, + ) + + if isinstance(result, ShutdownSentinel): + return + + if result is None: + empty_data = [ + { + "collated_query_responses": [], + "collated_attention_masks": [], + "collated_position_ids": [], + "collated_advantages": [], + "collated_response_masks": [], + "collated_vllm_logprobs": [], + } + for _ in range(self.dp_world_size) + ] + with self.lock: + self.prepared_data[step] = empty_data + self.metrics[step] = {} + self.current_prepared_step = step + continue + + scores = np.array(batch.scores) + scores_per_prompt = scores.reshape(-1, 
self.config.num_samples_per_prompt_rollout) + mean_grouped_rewards = scores_per_prompt.mean(axis=-1) + mean_grouped_rewards = np.repeat(mean_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + std_grouped_rewards = scores_per_prompt.std(axis=-1) + std_grouped_rewards = np.repeat(std_grouped_rewards, self.config.num_samples_per_prompt_rollout, axis=0) + + if self.config.advantage_normalization_type == "standard": + advantages = (scores - mean_grouped_rewards) / (std_grouped_rewards + 1e-8) + elif self.config.advantage_normalization_type == "centered": + advantages = scores - mean_grouped_rewards + else: + raise ValueError(f"Invalid advantage normalization type: {self.config.advantage_normalization_type}") + + if self.config.mask_truncated_completions: + stop_idxes = torch.tensor( + [i for i in range(len(result.finish_reasons)) if result.finish_reasons[i] == "stop"] + ) + num_truncated = len(result.finish_reasons) - len(stop_idxes) + if num_truncated > 0: + logger.info( + f"[DataPreparationActor] Filtered {num_truncated} responses that didn't finish with 'stop'. " + f"Retention rate: {len(stop_idxes) / len(result.finish_reasons):.2%}" + ) + scores = scores[stop_idxes] + advantages = advantages[stop_idxes] + batch = batch[stop_idxes.tolist()] + result.responses = [result.responses[i] for i in stop_idxes] + result.masks = [result.masks[i] for i in stop_idxes] + result.finish_reasons = [result.finish_reasons[i] for i in stop_idxes] + result.logprobs = [result.logprobs[i] for i in stop_idxes] + + packed_sequences = pack_sequences( + queries=batch.queries, + responses=result.responses, + masks=result.masks, + pack_length=self.config.pack_length, + pad_token_id=self.tokenizer.pad_token_id, + vllm_logprobs=result.logprobs, + ) + lookup_advantages = np.zeros(len(advantages) + 1, dtype=np.float32) + lookup_advantages[1:] = advantages + packed_advantages = [ + torch.tensor(lookup_advantages[packed_mask], dtype=torch.float32) + for packed_mask in packed_sequences.response_masks + ] + packed_sequences.advantages = packed_advantages + + collated_data = prepare_collated_data_for_workers( + packed_sequences, self.dp_world_size, self.per_device_train_batch_size, self.tokenizer.pad_token_id + ) + + if len(result.responses) == 0: + step_metrics = {} + else: + real_num_responses = len(result.responses) + expected_num_responses = self.config.num_samples_per_prompt_rollout * self.global_batch_size + unsolved_num_responses = (scores < self.config.max_possible_score).sum() + sequence_lengths = np.array([len(response) for response in result.responses]) + sequence_length_solved = ( + np.array([]) + if np.all(scores == 0) + else np.array(sequence_lengths[scores == self.config.max_possible_score]) + ) + sequence_length_unsolved = ( + np.array([]) + if np.all(scores == self.config.max_possible_score) + else np.array(sequence_lengths[scores == 0]) + ) + stop_rate = sum(int(fr == "stop") for fr in result.finish_reasons) / len(result.finish_reasons) + + batch_metrics_dict = asdict(batch_stats) + batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics_dict.items()} + + step_metrics = { + "scores": scores.mean(), + "real_batch_size_ratio": real_num_responses / expected_num_responses, + "unsolved_batch_size_ratio": unsolved_num_responses / real_num_responses, + "packed_ratio": len(packed_sequences.query_responses) / real_num_responses, + "val/solve_rate_hist": batch_stats.percent_solved_hist, + "val/total_reward_groups": real_num_responses / self.config.num_samples_per_prompt_rollout, + 
"val/sequence_lengths": sequence_lengths.mean(), + "val/sequence_lengths_min": sequence_lengths.min(), + "val/sequence_lengths_max": sequence_lengths.max(), + "val/sequence_lengths_unsolved": ( + 0 if len(sequence_length_unsolved) == 0 else sequence_length_unsolved.mean() + ), + "val/sequence_lengths_solved": ( + 0 if len(sequence_length_solved) == 0 else sequence_length_solved.mean() + ), + "val/sequence_lengths_unsolved_hist": sequence_length_unsolved, + "val/sequence_lengths_solved_hist": sequence_length_solved, + "val/stop_rate": stop_rate, + "val/advantages_mean": advantages.mean(), + "val/advantages_min": advantages.min(), + "val/advantages_max": advantages.max(), + "val/advantages_hist": advantages, + "val/num_calls_rate": np.array(result.request_info.num_calls).mean(), + "val/timeouts_rate": np.array(result.request_info.timeouts).mean(), + "val/tool_errors_rate": np.array( + [len(item) > 0 for item in result.request_info.tool_errors] + ).mean(), + "val/tool_runtimes_rate": np.array(result.request_info.tool_runtimes).mean(), + "val/tool_calleds_rate": np.array(result.request_info.tool_calleds).mean(), + **reward_metrics, + **batch_metrics_prefixed, + } + + total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens + step_metrics["val/actor_tokens_per_second"] = total_tokens / result.token_statistics.generation_time + + with self.lock: + self.prepared_data[step] = collated_data + self.metrics[step] = step_metrics + self.current_prepared_step = step + + def get_data(self, rank: int, step: int) -> dict: + """Called by each rank's StreamingDataLoader. Blocks until data ready.""" + while True: + with self.lock: + if step <= self.current_prepared_step: + data = self.prepared_data[step][rank].copy() + data["metrics"] = self.metrics[step] + self._cleanup_old_steps(step) + return data + time.sleep(0.01) + + def _cleanup_old_steps(self, current_step: int): + """Remove old step data to prevent memory leak.""" + steps_to_remove = [s for s in self.prepared_data if s < current_step - 1] + for s in steps_to_remove: + del self.prepared_data[s] + if s in self.metrics: + del self.metrics[s] + + def shutdown(self): + self.shutdown_requested = True + self._executor.shutdown(wait=True) + + def get_state(self) -> dict: + return { + "training_step": self.current_prepared_step + 1, + "iter_dataloader_state": self.iter_dataloader.state_dict(), + } + + def set_state(self, state: dict): + self.training_step = state["training_step"] + self.iter_dataloader.load_state_dict(state["iter_dataloader_state"]) diff --git a/open_instruct/data_types.py b/open_instruct/data_types.py index fe0d7c92a..2d42d2d02 100644 --- a/open_instruct/data_types.py +++ b/open_instruct/data_types.py @@ -39,8 +39,9 @@ class GenerationResult: finish_reasons: list[str] masks: list[list[int]] request_info: RequestInfo - dataset_index: int | None - prompt_id: str | None + dataset_index: int | None = None + prompt_id: str | None = None + epoch_number: int = 0 token_statistics: TokenStatistics | None = None start_time: float | None = None logprobs: list[list[float]] | None = None @@ -55,14 +56,23 @@ class PromptRequest: Note: We intentionally type `generation_config` as `Any` to avoid importing heavy dependencies (e.g., vLLM) at import time in deserializers like Ray's `_QueueActor`. + + prompt_id can be passed directly, or computed from epoch_number and dataset_index. 
""" prompt: list[int] generation_config: Any dataset_index: int - prompt_id: str + prompt_id: str | None = None + epoch_number: int = 0 + training_step: int = 0 is_eval: bool = False + def get_prompt_id(self) -> str: + if self.prompt_id is not None: + return self.prompt_id + return f"{self.epoch_number}_{self.dataset_index}" + @dataclass class CollatedBatchData: diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index cb1ca4441..8f508cd91 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -183,7 +183,7 @@ def get_bundle_indices_list(placement_group: ray.util.placement_group) -> list[i def make_request_id(request: PromptRequest) -> str: """Generate a unique tracking key for a request.""" prefix = "eval" if request.is_eval else "train" - return f"{prefix}_{request.prompt_id}" + return f"{prefix}_{request.get_prompt_id()}" def split_request_id(full_request_id: str) -> dict: @@ -410,7 +410,7 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: actor.request_metadata[request_id] = { "is_eval": request.is_eval, "dataset_index": request.dataset_index, - "prompt_id": request.prompt_id, + "prompt_id": request.get_prompt_id(), "sampling_params": sampling_params, "original_sampling_params": request.generation_config, "prompt_token_ids": list(request.prompt), From 1c5104f0a6ed4b15ecfa2f40b5eb49488b0e7a12 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 15:03:21 -0700 Subject: [PATCH 94/96] redid changes --- open_instruct/data_loader.py | 25 +++++++++---------------- open_instruct/data_types.py | 22 +++++++++------------- open_instruct/vllm_utils.py | 4 ++-- 3 files changed, 20 insertions(+), 31 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index 09fd9ef69..d50cd86d8 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -32,6 +32,7 @@ from transformers import PreTrainedTokenizer from open_instruct import utils +from open_instruct.data_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics from open_instruct.dataset_transformation import ( GROUND_TRUTHS_KEY, INPUT_IDS_PROMPT_KEY, @@ -39,7 +40,6 @@ VERIFIER_SOURCE_KEY, ) from open_instruct.model_utils import Batch -from open_instruct.data_types import GenerationResult, PromptRequest, RequestInfo, ShutdownSentinel, TokenStatistics from open_instruct.rl_utils import PackedSequences, pack_sequences from open_instruct.utils import combine_reward_metrics, repeat_each @@ -376,20 +376,16 @@ def keys(self): def add_prompt_to_generator( - example: dict[str, Any], - epoch_number: int, - training_step: int, - param_prompt_Q: ray_queue.Queue, - generation_config, - is_eval: bool, + example: dict[str, Any], epoch_number: int, param_prompt_Q: ray_queue.Queue, generation_config, is_eval: bool ) -> None: + dataset_index = example["dataset_index"] + prompt_id = f"{epoch_number}_{dataset_index}" param_prompt_Q.put( PromptRequest( prompt=example[INPUT_IDS_PROMPT_KEY], generation_config=generation_config, - epoch_number=epoch_number, - training_step=training_step, - dataset_index=example["dataset_index"], + dataset_index=dataset_index, + prompt_id=prompt_id, is_eval=is_eval, ) ) @@ -452,7 +448,7 @@ def accumulate_inference_batches( assert len(result.responses) == generation_config.n, ( f"Mismatch: individual prompt result has {len(result.responses)} responses " f"but expected {generation_config.n} samples per prompt. 
" - f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}" + f"Dataset index: {result.dataset_index}, Epoch: {result.epoch()}" ) example = dataset[result.dataset_index] @@ -463,9 +459,7 @@ def accumulate_inference_batches( if replenish_prompts: example = next(iter_dataloader) - add_prompt_to_generator( - example, iter_dataloader._epoch, training_step, param_prompt_Q, generation_config, is_eval=False - ) + add_prompt_to_generator(example, iter_dataloader._epoch, param_prompt_Q, generation_config, is_eval=False) for i in range(len(result.finish_reasons)): if result.finish_reasons[i] == "stop" and len(result.responses[i]) == 0: @@ -591,7 +585,7 @@ def accumulate_inference_batches( masks=combined_masks, request_info=combined_request_info, dataset_index=None, - epoch_number=results[0].epoch_number, + prompt_id=results[0].prompt_id, token_statistics=accumulated_stats, logprobs=combined_logprobs, ) @@ -772,7 +766,6 @@ def _data_preparation_loop_inner(self): add_prompt_to_generator( next(self.iter_dataloader), self.iter_dataloader._epoch, - self.training_step, self.param_prompt_Q, self.generation_config, is_eval=False, diff --git a/open_instruct/data_types.py b/open_instruct/data_types.py index 2d42d2d02..7a1eee9a3 100644 --- a/open_instruct/data_types.py +++ b/open_instruct/data_types.py @@ -39,15 +39,20 @@ class GenerationResult: finish_reasons: list[str] masks: list[list[int]] request_info: RequestInfo - dataset_index: int | None = None - prompt_id: str | None = None - epoch_number: int = 0 + dataset_index: int | None + prompt_id: str | None token_statistics: TokenStatistics | None = None start_time: float | None = None logprobs: list[list[float]] | None = None reward_scores: list[float] | None = None reward_metrics: dict[str, Any] | None = None + def epoch(self) -> int: + """Extract epoch number from prompt_id (format: '{epoch}_{dataset_index}').""" + if self.prompt_id is None: + return 0 + return int(self.prompt_id.split("_")[0]) + @dataclass class PromptRequest: @@ -56,23 +61,14 @@ class PromptRequest: Note: We intentionally type `generation_config` as `Any` to avoid importing heavy dependencies (e.g., vLLM) at import time in deserializers like Ray's `_QueueActor`. - - prompt_id can be passed directly, or computed from epoch_number and dataset_index. 
""" prompt: list[int] generation_config: Any dataset_index: int - prompt_id: str | None = None - epoch_number: int = 0 - training_step: int = 0 + prompt_id: str is_eval: bool = False - def get_prompt_id(self) -> str: - if self.prompt_id is not None: - return self.prompt_id - return f"{self.epoch_number}_{self.dataset_index}" - @dataclass class CollatedBatchData: diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py index 8f508cd91..cb1ca4441 100644 --- a/open_instruct/vllm_utils.py +++ b/open_instruct/vllm_utils.py @@ -183,7 +183,7 @@ def get_bundle_indices_list(placement_group: ray.util.placement_group) -> list[i def make_request_id(request: PromptRequest) -> str: """Generate a unique tracking key for a request.""" prefix = "eval" if request.is_eval else "train" - return f"{prefix}_{request.get_prompt_id()}" + return f"{prefix}_{request.prompt_id}" def split_request_id(full_request_id: str) -> dict: @@ -410,7 +410,7 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: actor.request_metadata[request_id] = { "is_eval": request.is_eval, "dataset_index": request.dataset_index, - "prompt_id": request.get_prompt_id(), + "prompt_id": request.prompt_id, "sampling_params": sampling_params, "original_sampling_params": request.generation_config, "prompt_token_ids": list(request.prompt), From 3b74d82258f18a89d55f49c2a53bd223cdd80113 Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 15:04:24 -0700 Subject: [PATCH 95/96] review changes --- .claude/commands/test-and-fix.md | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 .claude/commands/test-and-fix.md diff --git a/.claude/commands/test-and-fix.md b/.claude/commands/test-and-fix.md deleted file mode 100644 index 97761e2c1..000000000 --- a/.claude/commands/test-and-fix.md +++ /dev/null @@ -1,7 +0,0 @@ -Run, in order: - - 1. @scripts/train/debug/single_gpu_on_beaker.sh - 2. @scripts/train/debug/tool_grpo_fast.sh - 3. @scripts/train/debug/large_test_script.sh - -Wait for the previous one to finish before starting the next one. Fix any errors that come up. From 9ae9b95febe32df4e5b7da1a2b8749faf9e741df Mon Sep 17 00:00:00 2001 From: Finbarr Timbers Date: Fri, 5 Dec 2025 15:46:20 -0700 Subject: [PATCH 96/96] Cleane dup code. --- open_instruct/data_loader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/open_instruct/data_loader.py b/open_instruct/data_loader.py index d50cd86d8..ce23b2677 100644 --- a/open_instruct/data_loader.py +++ b/open_instruct/data_loader.py @@ -379,13 +379,12 @@ def add_prompt_to_generator( example: dict[str, Any], epoch_number: int, param_prompt_Q: ray_queue.Queue, generation_config, is_eval: bool ) -> None: dataset_index = example["dataset_index"] - prompt_id = f"{epoch_number}_{dataset_index}" param_prompt_Q.put( PromptRequest( prompt=example[INPUT_IDS_PROMPT_KEY], generation_config=generation_config, dataset_index=dataset_index, - prompt_id=prompt_id, + prompt_id=f"{epoch_number}_{dataset_index}", is_eval=is_eval, ) )