diff --git a/open_instruct/actor_manager.py b/open_instruct/actor_manager.py
index ae4cc6508..29932fc3a 100644
--- a/open_instruct/actor_manager.py
+++ b/open_instruct/actor_manager.py
@@ -54,6 +54,7 @@ def __init__(self, queues: dict, args):
         self._total_decode_tokens = 0
         self._training_step_history = collections.deque(maxlen=self._sample_window)
         self._generation_batch_history = collections.deque(maxlen=self._sample_window)
+        self._current_training_step = 0
         self._kv_cache_max_concurrency = None
         self._args = args
         if self._args.enable_queue_dashboard:
@@ -171,6 +172,14 @@ def report_batch_generation_time(self, duration: float):
         """Report the time taken to generate a batch of data."""
         self._generation_batch_history.append(duration)
 
+    def set_current_training_step(self, training_step: int):
+        """Record the most recent training step observed by the learner."""
+        self._current_training_step = training_step
+
+    def get_current_training_step(self) -> int:
+        """Return the latest training step recorded by the learner."""
+        return self._current_training_step
+
     def set_kv_cache_max_concurrency(self, max_concurrency: int):
         """Set the KV cache max concurrency value."""
         self._kv_cache_max_concurrency = max_concurrency
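The two new `ActorManager` methods form a write/read pair across the learner and the vLLM actors. A minimal sketch of the intended call pattern, assuming `actor_manager` is the Ray actor handle that `create_model_and_optimizer` creates below:

```python
import ray

# Sketch only; ActorManager, queues_to_monitor, args, and training_step are
# the names used in grpo_fast.py below.
actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)

# Learner side, once per iteration of the training loop: fire-and-forget,
# so publishing the step never blocks training.
actor_manager.set_current_training_step.remote(training_step)

# Generation side, when a prompt is pulled off the queue: a blocking read
# of whatever step the learner last published.
generation_started_training_step = ray.get(actor_manager.get_current_training_step.remote())
```

Because the write is fire-and-forget and the read is best-effort, the recorded step is only approximately the learner's true position, which is acceptable for staleness accounting.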
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index cac12986e..36a7aff5f 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -1642,6 +1642,9 @@ class BatchStatistics:
     percent_solved_mean: float
     no_resampled_prompts: int
     total_prompts: int
+    staleness_queue_steps: list[int | None] = field(default_factory=list)
+    staleness_generation_start_steps: list[int | None] = field(default_factory=list)
+    staleness_consumed_step: int | None = None
 
 
 def accumulate_inference_batches(
@@ -1661,7 +1664,7 @@ def accumulate_inference_batches(
     no_resampling_pass_rate: float | None = None,
     iter_dataloader: ShufflingIterator | None = None,
     prompt_dataset: Dataset = None,
-    param_prompt_Q: ray_queue.Queue | None = None,
+    prompt_Q: ray_queue.Queue | None = None,
     training_step: int = None,
 ) -> tuple[GenerationResult, Batch, dict, BatchStatistics]:
     """Accumulate multiple inference results into a single training batch.
@@ -1679,7 +1682,7 @@ def accumulate_inference_batches(
         no_resampling_pass_rate: Optional rate at which to note samples solved at
             greater than this rate and exclude them from further sampling
         iter_dataloader: Optional, used for no_resampling_pass_rate
-        param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts
+        prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts
 
     Raises:
         queue.Empty: If timeout is specified and no data is available within timeout.
@@ -1692,8 +1695,8 @@ def accumulate_inference_batches(
         assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed"
 
     if replenish_prompts:
-        assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, (
-            "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset"
+        assert prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, (
+            "replenish_prompts requires prompt_Q and iter_dataloader and prompt_dataset"
         )
 
     results = []
@@ -1710,6 +1713,8 @@ def accumulate_inference_batches(
     filtered_prompt_solved = 0
     filtered_prompt_nonzero = 0
     total_no_resampled = 0
+    staleness_queue_steps: list[int | None] = []
+    staleness_generation_start_steps: list[int | None] = []
    progress_bar = tqdm(
         total=num_prompts,
         desc=f"Accumulating Responses and Rewarding {num_prompts} prompts",
@@ -1723,6 +1728,11 @@ def accumulate_inference_batches(
         if isinstance(result, ShutdownSentinel):
             return result, None, None, None
 
+        queued_step = result.queued_training_step
+        generation_started_step = result.generation_started_training_step
+        staleness_queue_steps.append(queued_step)
+        staleness_generation_start_steps.append(generation_started_step)
+
         # Validate that each individual result has the expected number of responses
         assert len(result.responses) == generation_config.n, (
             f"Mismatch: individual prompt result has {len(result.responses)} responses "
@@ -1741,7 +1751,7 @@ def accumulate_inference_batches(
             iter_dataloader.epoch_number,
             training_step,
             pending_queries_map,
-            param_prompt_Q,
+            prompt_Q,
             generation_config,
             is_eval=False,
         )
@@ -1918,6 +1928,9 @@ def accumulate_inference_batches(
         percent_solved_mean=percent_solved_mean,
         no_resampled_prompts=total_no_resampled,
         total_prompts=len(results),
+        staleness_queue_steps=staleness_queue_steps,
+        staleness_generation_start_steps=staleness_generation_start_steps,
+        staleness_consumed_step=training_step,
     )
     logging.info(
         f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds"
@@ -1929,7 +1942,7 @@
 def data_preparation_thread(
     reward_fn: Callable,
     inference_results_Q: ray_queue.Queue,  # Ray queue
-    param_prompt_Q: ray_queue.Queue,
+    prompt_Q: ray_queue.Queue,
     packed_sequences_Q: Queue,
     pending_queries_map: dict,
     args: Args,
@@ -1961,7 +1974,7 @@ def data_preparation_thread(
             no_resampling_pass_rate=args.no_resampling_pass_rate,
             iter_dataloader=iter_dataloader,
             prompt_dataset=train_dataset,
-            param_prompt_Q=param_prompt_Q,
+            prompt_Q=prompt_Q,
             training_step=training_step,
         )
         if isinstance(result, ShutdownSentinel):
@@ -2081,8 +2094,26 @@ def data_preparation_thread(
         )
 
         batch_metrics = asdict(batch_stats)
-        batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()}
+        # Calculate staleness metrics. Pop the raw step fields from the dict
+        # (not from the dataclass) so they are not logged twice under batch/.
+        # Steps may be None for results that predate step tracking.
+        queue_steps = batch_metrics.pop("staleness_queue_steps", [])
+        generation_steps = batch_metrics.pop("staleness_generation_start_steps", [])
+        consumed_step = batch_metrics.pop("staleness_consumed_step", None)
+
+        queue_to_generation = []
+        generation_to_consume = []
+        for queue_step, generation_step in zip(queue_steps, generation_steps):
+            if queue_step is None or generation_step is None or consumed_step is None:
+                continue
+            queue_to_generation.append(generation_step - queue_step)
+            generation_to_consume.append(consumed_step - generation_step)
+
+        staleness_metrics = {
+            "staleness/queue_to_generation": queue_to_generation,
+            "staleness/queue_to_generation_mean": np.mean(queue_to_generation) if queue_to_generation else 0.0,
+            "staleness/generation_to_consume": generation_to_consume,
+            "staleness/generation_to_consume_mean": np.mean(generation_to_consume) if generation_to_consume else 0.0,
+        }
+
+        batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()}
         metrics = {
             "scores": scores.mean(),
             "real_batch_size_ratio": real_num_responses / expected_num_responses,
@@ -2115,6 +2146,7 @@ def data_preparation_thread(
             "time/getting_response": getting_response_time,
             **reward_metrics,
             **batch_metrics_prefixed,
+            **staleness_metrics,
         }
 
         total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens
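With hypothetical numbers, the two deltas computed above behave as follows: a prompt queued while the learner was on step 3, picked up by vLLM at step 5, and consumed into a training batch at step 7 contributes 2 to each distribution:

```python
# Hypothetical steps for one prompt.
queue_step, generation_step, consumed_step = 3, 5, 7
queue_to_generation = generation_step - queue_step       # 2: time spent waiting in prompt_Q
generation_to_consume = consumed_step - generation_step  # 2: generating plus waiting in inference_results_Q
```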
@@ -2263,7 +2295,7 @@ def create_model_and_optimizer(
     wandb_url: str,
     tokenizer: PreTrainedTokenizer,
     inference_results_Q: ray_queue.Queue,
-    param_prompt_Q: ray_queue.Queue,
+    prompt_Q: ray_queue.Queue,
     evaluation_inference_results_Q: ray_queue.Queue,
 ) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]:
     """Create the model, optimizer, and vLLM engines."""
@@ -2308,7 +2340,7 @@ def create_model_and_optimizer(
     queues_to_monitor = {
         "Inference Results Queue": inference_results_Q,
-        "Param Prompt Queue": param_prompt_Q,
+        "Prompt Queue": prompt_Q,
         "Evaluation Queue": evaluation_inference_results_Q,
     }
     actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)
@@ -2329,7 +2361,7 @@ def create_model_and_optimizer(
             pg=pg if args.single_gpu_mode else None,
             tools=tool_objects,
             max_tool_calls=args.max_tool_calls,
-            prompt_queue=param_prompt_Q,
+            prompt_queue=prompt_Q,
             results_queue=inference_results_Q,
             eval_results_queue=evaluation_inference_results_Q,
             actor_manager=actor_manager,
@@ -2396,7 +2428,7 @@ def add_prompt_to_generator(
     epoch_number: int,
     training_step: int,
     pending_queries_map: PendingQueriesMap,
-    param_prompt_Q: ray_queue.Queue,
+    prompt_Q: ray_queue.Queue,
     generation_config,
     is_eval: bool,
 ) -> None:
@@ -2407,7 +2439,7 @@ def add_prompt_to_generator(
     raw_query = example[RAW_PROMPT_KEY]
     pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query)
 
-    param_prompt_Q.put(
+    prompt_Q.put(
         PromptRequest(
             prompt=query,
             generation_config=generation_config,
@@ -2935,7 +2967,7 @@ def run_training(
     stop_event,
     executor,
     inference_results_Q,
-    param_prompt_Q,
+    prompt_Q,
     evaluation_inference_results_Q,
     packed_sequences_Q,
     pending_queries_map,
@@ -2972,7 +3004,7 @@ def run_training(
         data_preparation_thread,
         reward_fn,
         inference_results_Q,
-        param_prompt_Q,
+        prompt_Q,
         packed_sequences_Q,
         pending_queries_map,
         args,
@@ -3003,7 +3035,7 @@ def health_check_fn():
             iter_dataloader.epoch_number,
             resume_training_step,
             pending_queries_map,
-            param_prompt_Q,
+            prompt_Q,
             generation_configs["train"],
             is_eval=False,
         )
@@ -3015,6 +3047,7 @@ def health_check_fn():
     training_start_time = time.perf_counter()  # Track overall training start time
 
     for training_step in range(resume_training_step, args.num_training_steps + 1):
+        actor_manager.set_current_training_step.remote(training_step)
         start_time = time.perf_counter()
 
         if (
@@ -3056,7 +3089,7 @@ def health_check_fn():
                 iter_dataloader.epoch_number,
                 training_step,
                 eval_pending_queries_map,
-                param_prompt_Q,
+                prompt_Q,
                 generation_configs["eval"],
                 is_eval=True,
             )
@@ -3178,7 +3211,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
     # all prompts from async_steps + 1 training steps
     queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout
     inference_results_Q = ray_queue.Queue(maxsize=queue_size)
-    param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
+    prompt_Q = ray_queue.Queue(maxsize=queue_size)
     # We don't care if we ever hit the max, so we let the queue be unbounded.
     evaluation_inference_results_Q = ray_queue.Queue()
@@ -3191,7 +3224,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
             wandb_url,
             tokenizer,
             inference_results_Q,
-            param_prompt_Q,
+            prompt_Q,
             evaluation_inference_results_Q,
         )
     )
@@ -3249,7 +3282,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
             stop_event,
             executor,
             inference_results_Q,
-            param_prompt_Q,
+            prompt_Q,
             evaluation_inference_results_Q,
             packed_sequences_Q,
             pending_queries_map,
@@ -3266,7 +3299,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
         raise
     finally:
         cleanup_training_resources(
-            stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
+            stop_event, executor, [inference_results_Q, prompt_Q, evaluation_inference_results_Q], actor_manager
        )
 
     # Ai2 logic: we use /output to store the artifacts of the job, so we
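For context on the rename in `main()` above: `prompt_Q` keeps the same `(async_steps + 1) * num_unique_prompts_rollout` capacity. With hypothetical values, the queue holds prompts for at most five training steps:

```python
# Hypothetical arg values; the formula is the one used in main() above.
async_steps, num_unique_prompts_rollout = 4, 64
queue_size = (async_steps + 1) * num_unique_prompts_rollout  # 320 prompts
```

This capacity also loosely bounds `staleness/queue_to_generation`: a prompt can sit in `prompt_Q` for roughly `async_steps + 1` steps at most before the prefetch worker picks it up.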
diff --git a/open_instruct/queue_types.py b/open_instruct/queue_types.py
index 0cc047bca..bf2de9531 100644
--- a/open_instruct/queue_types.py
+++ b/open_instruct/queue_types.py
@@ -37,6 +37,8 @@ class GenerationResult:
     token_statistics: TokenStatistics | None = None
     start_time: float | None = None
     logprobs: list[list[float]] | None = None
+    queued_training_step: int | None = None
+    generation_started_training_step: int | None = None
 
 
 @dataclass
diff --git a/open_instruct/test_vllm_utils.py b/open_instruct/test_vllm_utils.py
index d73d7c29e..57349171c 100644
--- a/open_instruct/test_vllm_utils.py
+++ b/open_instruct/test_vllm_utils.py
@@ -72,6 +72,7 @@ def create_mock_logprobs(token_ids):
             "dataset_index": 43039,
             "epoch_number": 0,
             "training_step": 1,
+            "generation_started_training_step": 2,
             "prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
             "start_time": 1000.0,
         }
@@ -107,6 +108,8 @@ def create_mock_logprobs(token_ids):
         self.assertEqual(result.request_info.tool_outputs, ["result1", "result2"])
         self.assertEqual(result.request_info.tool_runtimes, [0.5, 0.3])
         self.assertEqual(result.request_info.tool_calleds, [True, True])
+        self.assertEqual(result.queued_training_step, 1)
+        self.assertEqual(result.generation_started_training_step, 2)
 
     def test_process_outputs_without_tools(self):
         """Test that process_completed_request correctly handles outputs without tool attributes."""
@@ -143,6 +146,7 @@ def create_mock_logprobs(token_ids):
             "dataset_index": 200,
             "epoch_number": 0,
             "training_step": 2,
+            "generation_started_training_step": 3,
             "prompt_token_ids": [1, 2, 3, 4, 5],
             "start_time": 2000.0,
         }
@@ -181,6 +185,8 @@ def create_mock_logprobs(token_ids):
         self.assertEqual(result.request_info.tool_outputs, ["", ""])
         self.assertEqual(result.request_info.tool_runtimes, [0.0, 0.0])
         self.assertEqual(result.request_info.tool_calleds, [False, False])
+        self.assertEqual(result.queued_training_step, 2)
+        self.assertEqual(result.generation_started_training_step, 3)
 
 
 if __name__ == "__main__":
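Both new `GenerationResult` fields default to `None`, so results produced before the first `set_current_training_step` call (or by older producers) still pass through untouched. Consumers therefore have to guard before doing arithmetic; a minimal sketch (the helper name is hypothetical):

```python
from open_instruct.queue_types import GenerationResult

def staleness_in_steps(result: GenerationResult, consumed_step: int) -> int | None:
    """Steps elapsed between generation start and consumption, or None if untracked."""
    if result.generation_started_training_step is None:
        return None
    return consumed_step - result.generation_started_training_step
```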
diff --git a/open_instruct/vllm_utils.py b/open_instruct/vllm_utils.py
index 4635ee097..c09790d3f 100644
--- a/open_instruct/vllm_utils.py
+++ b/open_instruct/vllm_utils.py
@@ -383,6 +383,8 @@ def process_completed_request(request_id, outs, current_time, tools, request_met
         ),
         start_time=metadata["start_time"],
         logprobs=logprobs,
+        queued_training_step=metadata.get("training_step"),
+        generation_started_training_step=metadata.get("generation_started_training_step"),
     )
 
     return result, metadata["is_eval"]
@@ -470,10 +472,13 @@ def _prefetch_worker(actor: "LLMRayActor") -> None:
             continue
         request = actor.prompt_queue.get()
-        add_request(actor, request)
+        generation_started_training_step = actor._get_current_training_step()
+        add_request(actor, request, generation_started_training_step)
 
 
-def add_request(actor: "LLMRayActor", request: PromptRequest) -> None:
+def add_request(
+    actor: "LLMRayActor", request: PromptRequest, generation_started_training_step: int | None = None
+) -> None:
     request_id = make_request_id(request)
     sampling_params = request.generation_config.clone()
 
@@ -488,6 +493,7 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None:
         "original_sampling_params": request.generation_config,
         "prompt_token_ids": list(request.prompt),
         "start_time": time.perf_counter(),
+        "generation_started_training_step": generation_started_training_step,
     }
 
     tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id)
@@ -550,6 +556,14 @@ def _init_queues(self, prompt_queue, results_queue, eval_results_queue, actor_ma
         self._last_should_stop_update = float("-inf")
         self._should_stop_value = False
 
+    def _get_current_training_step(self) -> int | None:
+        """Fetch the learner's current training step for staleness accounting."""
+        try:
+            return ray.get(self.actor_manager.get_current_training_step.remote())
+        except Exception as exc:  # pragma: no cover - log and fall back gracefully
+            logger.warning(f"Failed to fetch current training step from ActorManager: {exc}")
+            return None
+
     def _init_executor(self) -> None:
         max_workers = NUM_PREFETCH_WORKERS + (NUM_TOOL_WORKERS if self.tools else 0)
         self.executor = futures.ThreadPoolExecutor(max_workers=max_workers)
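One design note on `_get_current_training_step`: it issues a blocking `ray.get` per prompt inside `_prefetch_worker`. If that round-trip ever shows up in profiles, the value could be cached with a short TTL, mirroring the `_last_should_stop_update` pattern already present in `_init_queues`. A sketch under that assumption (the TTL and attribute names are hypothetical, not part of this diff):

```python
import time

STEP_REFRESH_SECONDS = 1.0  # hypothetical TTL

def _get_current_training_step_cached(self) -> int | None:
    """At most one ActorManager round-trip per TTL window."""
    now = time.perf_counter()
    if now - getattr(self, "_last_step_update", float("-inf")) >= STEP_REFRESH_SECONDS:
        # Reuses the _get_current_training_step method added in this diff.
        self._cached_training_step = self._get_current_training_step()
        self._last_step_update = now
    return self._cached_training_step
```

Given that training steps are minutes long while the TTL is a second, the cached value would be stale by at most a negligible fraction of a step.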