-
Notifications
You must be signed in to change notification settings - Fork 459
added staleness metrics #1204
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
added staleness metrics #1204
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1642,6 +1642,9 @@ class BatchStatistics: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| percent_solved_mean: float | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_resampled_prompts: int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_prompts: int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_queue_steps: list[int | None] = field(default_factory=list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_generation_start_steps: list[int | None] = field(default_factory=list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_consumed_step: int | None = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def accumulate_inference_batches( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1661,7 +1664,7 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_resampling_pass_rate: float | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| iter_dataloader: ShufflingIterator | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_dataset: Dataset = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q: ray_queue.Queue | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q: ray_queue.Queue | None = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_step: int = None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[GenerationResult, Batch, dict, BatchStatistics]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Accumulate multiple inference results into a single training batch. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1679,7 +1682,7 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| and exclude them from further sampling | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| iter_dataloader: Optional, used for no_resampling_pass_rate | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Raises: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queue.Empty: If timeout is specified and no data is available within timeout. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1692,8 +1695,8 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if replenish_prompts: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "replenish_prompts requires prompt_Q and iter_dataloader and prompt_dataset" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1710,6 +1713,8 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| filtered_prompt_solved = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| filtered_prompt_nonzero = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_no_resampled = 0 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_queue_steps: list[int | None] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_generation_start_steps: list[int | None] = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| progress_bar = tqdm( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total=num_prompts, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| desc=f"Accumulating Responses and Rewarding {num_prompts} prompts", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1723,6 +1728,11 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(result, ShutdownSentinel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return result, None, None, None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queued_step = result.queued_training_step | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inline these? |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_started_step = result.generation_started_training_step | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_queue_steps.append(queued_step) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_generation_start_steps.append(generation_started_step) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Validate that each individual result has the expected number of responses | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert len(result.responses) == generation_config.n, ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Mismatch: individual prompt result has {len(result.responses)} responses " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1741,7 +1751,7 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| iter_dataloader.epoch_number, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_step, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_eval=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1918,6 +1928,9 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| percent_solved_mean=percent_solved_mean, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_resampled_prompts=total_no_resampled, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_prompts=len(results), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_queue_steps=staleness_queue_steps if staleness_queue_steps else [], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If staleness_queue_steps is Falsey, then it's [], right? So we can remove the ternary |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_generation_start_steps=staleness_generation_start_steps if staleness_generation_start_steps else [], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+1931
to
+1932
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The conditional assignments here are redundant. The variables
Suggested change
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment here about ternary |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_consumed_step=training_step, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logging.info( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1929,7 +1942,7 @@ def accumulate_inference_batches( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def data_preparation_thread( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reward_fn: Callable, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q: ray_queue.Queue, # Ray queue | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_sequences_Q: Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map: dict, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args: Args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -1961,7 +1974,7 @@ def data_preparation_thread( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| no_resampling_pass_rate=args.no_resampling_pass_rate, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| iter_dataloader=iter_dataloader, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_dataset=train_dataset, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q=param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q=prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_step=training_step, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(result, ShutdownSentinel): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2081,8 +2094,26 @@ def data_preparation_thread( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_metrics = asdict(batch_stats) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # calculate staleness metrics | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make this a function please! staleness_metrics = get_staleness_metrics(...) |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queue_steps = batch_stats.pop("staleness_queue_steps", []) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_steps = batch_stats.pop("staleness_generation_start_steps", []) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| consumed_steps = batch_stats.pop("staleness_consumed_step", []) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queue_to_generation = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_to_consume = [] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for queue_step, generation_step, consume_step in zip(queue_steps, generation_steps, consumed_steps): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queue_to_generation.append(generation_step - queue_step) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_to_consume.append(consume_step - generation_step) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Zip Error: Single Value Not IterableThe There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bug: Staleness Calculation: Robustness Against Missing DataThe staleness calculation performs arithmetic on |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| staleness_metrics = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "staleness/queue_to_generation": queue_to_generation, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "staleness/queue_to_generation_mean": np.mean(queue_to_generation), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "staleness/generation_to_consume": generation_to_consume, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "staleness/generation_to_consume_mean": np.mean(generation_to_consume), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+2098
to
+2114
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are a few issues in this block that will cause runtime errors and incorrect behavior:
I've provided a suggestion that fixes these issues by using
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| metrics = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "scores": scores.mean(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "real_batch_size_ratio": real_num_responses / expected_num_responses, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2115,6 +2146,7 @@ def data_preparation_thread( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "time/getting_response": getting_response_time, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **reward_metrics, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **batch_metrics_prefixed, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| **staleness_metrics, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2263,7 +2295,7 @@ def create_model_and_optimizer( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wandb_url: str, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer: PreTrainedTokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| evaluation_inference_results_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Create the model, optimizer, and vLLM engines.""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2308,7 +2340,7 @@ def create_model_and_optimizer( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queues_to_monitor = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Inference Results Queue": inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Param Prompt Queue": param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Param Prompt Queue": prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| "Evaluation Queue": evaluation_inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2329,7 +2361,7 @@ def create_model_and_optimizer( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pg=pg if args.single_gpu_mode else None, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tools=tool_objects, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| max_tool_calls=args.max_tool_calls, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_queue=param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_queue=prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| results_queue=inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| eval_results_queue=evaluation_inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actor_manager=actor_manager, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2396,7 +2428,7 @@ def add_prompt_to_generator( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| epoch_number: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_step: int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map: PendingQueriesMap, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q: ray_queue.Queue, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_eval: bool, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2407,7 +2439,7 @@ def add_prompt_to_generator( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raw_query = example[RAW_PROMPT_KEY] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q.put( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q.put( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PromptRequest( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt=query, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_config=generation_config, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2935,7 +2967,7 @@ def run_training( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_event, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| executor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| evaluation_inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_sequences_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -2972,7 +3004,7 @@ def run_training( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| data_preparation_thread, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| reward_fn, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_sequences_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3003,7 +3035,7 @@ def health_check_fn(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| iter_dataloader.epoch_number, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| resume_training_step, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_configs["train"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_eval=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3015,6 +3047,7 @@ def health_check_fn(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_start_time = time.perf_counter() # Track overall training start time | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for training_step in range(resume_training_step, args.num_training_steps + 1): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| actor_manager.set_current_training_step.remote(training_step) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| start_time = time.perf_counter() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3056,7 +3089,7 @@ def health_check_fn(): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| iter_dataloader.epoch_number, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| training_step, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| eval_pending_queries_map, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| generation_configs["eval"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| is_eval=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3178,7 +3211,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # all prompts from async_steps + 1 training steps | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q = ray_queue.Queue(maxsize=queue_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q = ray_queue.Queue(maxsize=queue_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q = ray_queue.Queue(maxsize=queue_size) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # We don't care if we ever hit the max, so we let the queue be unbounded. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| evaluation_inference_results_Q = ray_queue.Queue() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3191,7 +3224,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wandb_url, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| evaluation_inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3249,7 +3282,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_event, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| executor, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| param_prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| evaluation_inference_results_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| packed_sequences_Q, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pending_queries_map, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -3266,7 +3299,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| finally: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| cleanup_training_resources( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| stop_event, executor, [inference_results_Q, prompt_Q, evaluation_inference_results_Q], actor_manager | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Ai2 logic: we use /output to store the artifacts of the job, so we | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -383,6 +383,8 @@ def process_completed_request(request_id, outs, current_time, tools, request_met | |
| ), | ||
| start_time=metadata["start_time"], | ||
| logprobs=logprobs, | ||
| queued_training_step=metadata.get("training_step"), | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| generation_started_training_step=metadata.get("generation_started_training_step"), | ||
| ) | ||
| return result, metadata["is_eval"] | ||
|
|
||
|
|
@@ -470,10 +472,13 @@ def _prefetch_worker(actor: "LLMRayActor") -> None: | |
| continue | ||
|
|
||
| request = actor.prompt_queue.get() | ||
| add_request(actor, request) | ||
| generation_started_training_step = actor._get_current_training_step() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just inline this if it's called once? Also don't wrap in a try/except, let's make the actor manager have a default value |
||
| add_request(actor, request, generation_started_training_step) | ||
|
|
||
|
|
||
| def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: | ||
| def add_request( | ||
| actor: "LLMRayActor", request: PromptRequest, generation_started_training_step: int | None = None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we make this not be done? |
||
| ) -> None: | ||
| request_id = make_request_id(request) | ||
|
|
||
| sampling_params = request.generation_config.clone() | ||
|
|
@@ -488,6 +493,7 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None: | |
| "original_sampling_params": request.generation_config, | ||
| "prompt_token_ids": list(request.prompt), | ||
| "start_time": time.perf_counter(), | ||
| "generation_started_training_step": generation_started_training_step, | ||
| } | ||
|
|
||
| tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id) | ||
|
|
@@ -550,6 +556,14 @@ def _init_queues(self, prompt_queue, results_queue, eval_results_queue, actor_ma | |
| self._last_should_stop_update = float("-inf") | ||
| self._should_stop_value = False | ||
|
|
||
| def _get_current_training_step(self) -> int | None: | ||
| """Fetch the learner's current training step for staleness accounting.""" | ||
| try: | ||
| return ray.get(self.actor_manager.get_current_training_step.remote()) | ||
| except Exception as exc: # pragma: no cover - log and fall back gracefully | ||
| logger.warning(f"Failed to fetch current training step from ActorManager: {exc}") | ||
| return None | ||
|
|
||
| def _init_executor(self) -> None: | ||
| max_workers = NUM_PREFETCH_WORKERS + (NUM_TOOL_WORKERS if self.tools else 0) | ||
| self.executor = futures.ThreadPoolExecutor(max_workers=max_workers) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the type is a bit complicated here. Can you make it list[int]?