9 changes: 9 additions & 0 deletions open_instruct/actor_manager.py
@@ -54,6 +54,7 @@ def __init__(self, queues: dict, args):
self._total_decode_tokens = 0
self._training_step_history = collections.deque(maxlen=self._sample_window)
self._generation_batch_history = collections.deque(maxlen=self._sample_window)
self._current_training_step = 0
self._kv_cache_max_concurrency = None
self._args = args
if self._args.enable_queue_dashboard:
@@ -171,6 +172,14 @@ def report_batch_generation_time(self, duration: float):
"""Report the time taken to generate a batch of data."""
self._generation_batch_history.append(duration)

def set_current_training_step(self, training_step: int):
"""Record the most recent training step observed by the learner."""
self._current_training_step = training_step

def get_current_training_step(self) -> int:
"""Return the latest training step recorded by the learner."""
return self._current_training_step

def set_kv_cache_max_concurrency(self, max_concurrency: int):
"""Set the KV cache max concurrency value."""
self._kv_cache_max_concurrency = max_concurrency
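
For context on how these two accessors are meant to be used together: the learner publishes its step once per iteration, and the vLLM-side code reads it back over Ray when a prompt is dequeued (see the grpo_fast.py and vllm_utils.py changes below). A self-contained sketch of that pattern, with StepBoard as a hypothetical stand-in for ActorManager:

import ray


@ray.remote
class StepBoard:
    """Minimal stand-in for ActorManager's new step-tracking methods."""

    def __init__(self) -> None:
        self._current_training_step = 0  # same default the PR adds to ActorManager.__init__

    def set_current_training_step(self, training_step: int) -> None:
        self._current_training_step = training_step

    def get_current_training_step(self) -> int:
        return self._current_training_step


if __name__ == "__main__":
    ray.init()
    board = StepBoard.remote()
    # Learner side: fire-and-forget publish at the top of each training iteration.
    board.set_current_training_step.remote(7)
    # Generation side: read the step back when dequeuing a prompt, so the result can be
    # stamped with generation_started_training_step for staleness accounting.
    print(ray.get(board.get_current_training_step.remote()))  # -> 7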
75 changes: 54 additions & 21 deletions open_instruct/grpo_fast.py
@@ -1642,6 +1642,9 @@ class BatchStatistics:
percent_solved_mean: float
no_resampled_prompts: int
total_prompts: int
staleness_queue_steps: list[int | None] = field(default_factory=list)
staleness_generation_start_steps: list[int | None] = field(default_factory=list)
staleness_consumed_step: int | None = None


def accumulate_inference_batches(
@@ -1661,7 +1664,7 @@ def accumulate_inference_batches(
no_resampling_pass_rate: float | None = None,
iter_dataloader: ShufflingIterator | None = None,
prompt_dataset: Dataset = None,
param_prompt_Q: ray_queue.Queue | None = None,
prompt_Q: ray_queue.Queue | None = None,
training_step: int = None,
) -> tuple[GenerationResult, Batch, dict, BatchStatistics]:
"""Accumulate multiple inference results into a single training batch.
@@ -1679,7 +1682,7 @@ def accumulate_inference_batches(
no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate
and exclude them from further sampling
iter_dataloader: Optional, used for no_resampling_pass_rate
param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts
prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts

Raises:
queue.Empty: If timeout is specified and no data is available within timeout.
@@ -1692,8 +1695,8 @@ def accumulate_inference_batches(
assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed"

if replenish_prompts:
assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, (
"replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset"
assert prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, (
"replenish_prompts requires prompt_Q and iter_dataloader and prompt_dataset"
)

results = []
@@ -1710,6 +1713,8 @@ def accumulate_inference_batches(
filtered_prompt_solved = 0
filtered_prompt_nonzero = 0
total_no_resampled = 0
staleness_queue_steps: list[int | None] = []
Collaborator: the type is a bit complicated here. Can you make it list[int]?

staleness_generation_start_steps: list[int | None] = []
progress_bar = tqdm(
total=num_prompts,
desc=f"Accumulating Responses and Rewarding {num_prompts} prompts",
@@ -1723,6 +1728,11 @@
if isinstance(result, ShutdownSentinel):
return result, None, None, None

queued_step = result.queued_training_step
Collaborator: inline these?

generation_started_step = result.generation_started_training_step
staleness_queue_steps.append(queued_step)
staleness_generation_start_steps.append(generation_started_step)

# Validate that each individual result has the expected number of responses
assert len(result.responses) == generation_config.n, (
f"Mismatch: individual prompt result has {len(result.responses)} responses "
@@ -1741,7 +1751,7 @@
iter_dataloader.epoch_number,
training_step,
pending_queries_map,
param_prompt_Q,
prompt_Q,
generation_config,
is_eval=False,
)
@@ -1918,6 +1928,9 @@ def accumulate_inference_batches(
percent_solved_mean=percent_solved_mean,
no_resampled_prompts=total_no_resampled,
total_prompts=len(results),
staleness_queue_steps=staleness_queue_steps if staleness_queue_steps else [],
Collaborator: If staleness_queue_steps is falsy, then it's [], right? So we can remove the ternary.

staleness_generation_start_steps=staleness_generation_start_steps if staleness_generation_start_steps else [],
Contributor (on lines +1931 to +1932), medium: The conditional assignments here are redundant. The variables staleness_queue_steps and staleness_generation_start_steps are initialized as empty lists. If they remain empty, the expression ... if [] else [] evaluates to []. If they are populated, the expression ... if [1, 2] else [] evaluates to [1, 2]. In both cases, the result is the same as the original variable. You can simplify this by directly assigning the variables.

Suggested change:
-staleness_queue_steps=staleness_queue_steps if staleness_queue_steps else [],
-staleness_generation_start_steps=staleness_generation_start_steps if staleness_generation_start_steps else [],
+staleness_queue_steps=staleness_queue_steps,
+staleness_generation_start_steps=staleness_generation_start_steps,

Collaborator: Same comment here about ternary.

staleness_consumed_step=training_step,
)
logging.info(
f"[Data Preparation Thread] Calculating rewards took {combined_reward_metrics['time/reward']} seconds"
@@ -1929,7 +1942,7 @@
def data_preparation_thread(
reward_fn: Callable,
inference_results_Q: ray_queue.Queue, # Ray queue
param_prompt_Q: ray_queue.Queue,
prompt_Q: ray_queue.Queue,
packed_sequences_Q: Queue,
pending_queries_map: dict,
args: Args,
@@ -1961,7 +1974,7 @@ def data_preparation_thread(
no_resampling_pass_rate=args.no_resampling_pass_rate,
iter_dataloader=iter_dataloader,
prompt_dataset=train_dataset,
param_prompt_Q=param_prompt_Q,
prompt_Q=prompt_Q,
training_step=training_step,
)
if isinstance(result, ShutdownSentinel):
@@ -2081,8 +2094,26 @@ def data_preparation_thread(
)

batch_metrics = asdict(batch_stats)
batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()}

# calculate staleness metrics
Collaborator: Make this a function please! staleness_metrics = get_staleness_metrics(...)

queue_steps = batch_stats.pop("staleness_queue_steps", [])
generation_steps = batch_stats.pop("staleness_generation_start_steps", [])
consumed_steps = batch_stats.pop("staleness_consumed_step", [])

queue_to_generation = []
generation_to_consume = []
for queue_step, generation_step, consume_step in zip(queue_steps, generation_steps, consumed_steps):
queue_to_generation.append(generation_step - queue_step)
generation_to_consume.append(consume_step - generation_step)
Bug: Zip Error: Single Value Not Iterable

The staleness_consumed_step field in BatchStatistics is a single int | None value, but the code tries to zip() it with two lists (queue_steps and generation_steps). This causes a runtime error because zip() expects an iterable, not a single integer. The consumed step should be repeated for each prompt or the code should be restructured to handle the single value properly.

Bug: Staleness Calculation: Robustness Against Missing Data

The staleness calculation performs arithmetic on queue_step, generation_step, and consume_step values that can be None. When _get_current_training_step() fails or when metadata keys are missing, these values default to None, causing TypeError: unsupported operand type(s) for -: 'NoneType' and 'int' when attempting subtraction. The code needs to handle None values before performing arithmetic operations.


staleness_metrics = {
"staleness/queue_to_generation": queue_to_generation,
"staleness/queue_to_generation_mean": np.mean(queue_to_generation),
"staleness/generation_to_consume": generation_to_consume,
"staleness/generation_to_consume_mean": np.mean(generation_to_consume),
}
Contributor (on lines +2098 to +2114), critical: There are a few issues in this block that will cause runtime errors and incorrect behavior:

1. AttributeError: You are calling .pop() on batch_stats, which is a dataclass instance and does not have a pop method. You should be using batch_metrics, which is the dictionary created from asdict(batch_stats).
2. TypeError: batch_metrics.pop("staleness_consumed_step", []) will return an integer. The subsequent zip call attempts to iterate over this integer (consumed_steps), which will raise a TypeError. The consumed step is the same for all items in the batch and should be treated as a scalar.
3. None values: The queue_steps and generation_steps lists can contain None values. The subtraction generation_step - queue_step will raise a TypeError if either value is None. These cases should be handled.
4. RuntimeWarning from np.mean: If queue_to_generation or generation_to_consume is an empty list (e.g., if all steps were None), np.mean will return NaN and raise a RuntimeWarning. It's better to handle this case and return 0.0.

I've provided a suggestion that fixes these issues by using batch_metrics, handling the scalar consumed_step correctly, checking for None before subtraction, and safely calculating the mean.

Suggested change (replacing the staleness block above):

# calculate staleness metrics
queue_steps = batch_metrics.pop("staleness_queue_steps", [])
generation_steps = batch_metrics.pop("staleness_generation_start_steps", [])
consumed_step = batch_metrics.pop("staleness_consumed_step", None)
queue_to_generation = []
generation_to_consume = []
if consumed_step is not None:
    for q_step, g_step in zip(queue_steps, generation_steps):
        if q_step is not None and g_step is not None:
            queue_to_generation.append(g_step - q_step)
            generation_to_consume.append(consumed_step - g_step)
staleness_metrics = {
    "staleness/queue_to_generation": queue_to_generation,
    "staleness/queue_to_generation_mean": np.mean(queue_to_generation) if queue_to_generation else 0.0,
    "staleness/generation_to_consume": generation_to_consume,
    "staleness/generation_to_consume_mean": np.mean(generation_to_consume) if generation_to_consume else 0.0,
}


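Putting the "make this a function" request together with the fixes above, a possible sketch of such a helper follows. The name get_staleness_metrics is taken from the review comment; the exact signature and the skip-on-None behavior are assumptions rather than part of the PR:

import numpy as np  # grpo_fast.py already imports this as np


def get_staleness_metrics(
    queue_steps: list[int | None],
    generation_steps: list[int | None],
    consumed_step: int | None,
) -> dict:
    """Compute per-prompt step deltas between queueing, generation start, and consumption."""
    queue_to_generation: list[int] = []
    generation_to_consume: list[int] = []
    if consumed_step is not None:
        for q_step, g_step in zip(queue_steps, generation_steps):
            # Skip prompts whose step metadata was never recorded.
            if q_step is None or g_step is None:
                continue
            queue_to_generation.append(g_step - q_step)
            generation_to_consume.append(consumed_step - g_step)
    return {
        "staleness/queue_to_generation": queue_to_generation,
        "staleness/queue_to_generation_mean": float(np.mean(queue_to_generation)) if queue_to_generation else 0.0,
        "staleness/generation_to_consume": generation_to_consume,
        "staleness/generation_to_consume_mean": float(np.mean(generation_to_consume)) if generation_to_consume else 0.0,
    }

The call site in data_preparation_thread would then reduce to the three pops on batch_metrics followed by staleness_metrics = get_staleness_metrics(queue_steps, generation_steps, consumed_step).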
batch_metrics_prefixed = {f"batch/{k}": v for k, v in batch_metrics.items()}
metrics = {
"scores": scores.mean(),
"real_batch_size_ratio": real_num_responses / expected_num_responses,
@@ -2115,6 +2146,7 @@ def data_preparation_thread(
"time/getting_response": getting_response_time,
**reward_metrics,
**batch_metrics_prefixed,
**staleness_metrics,
}

total_tokens = result.token_statistics.num_prompt_tokens + result.token_statistics.num_response_tokens
@@ -2263,7 +2295,7 @@ def create_model_and_optimizer(
wandb_url: str,
tokenizer: PreTrainedTokenizer,
inference_results_Q: ray_queue.Queue,
param_prompt_Q: ray_queue.Queue,
prompt_Q: ray_queue.Queue,
evaluation_inference_results_Q: ray_queue.Queue,
) -> tuple[ModelGroup, list[vllm_utils.LLMRayActor], dict, int, int]:
"""Create the model, optimizer, and vLLM engines."""
@@ -2308,7 +2340,7 @@

queues_to_monitor = {
"Inference Results Queue": inference_results_Q,
"Param Prompt Queue": param_prompt_Q,
"Param Prompt Queue": prompt_Q,
"Evaluation Queue": evaluation_inference_results_Q,
}
actor_manager = ray.remote(ActorManager).remote(queues_to_monitor, args)
@@ -2329,7 +2361,7 @@
pg=pg if args.single_gpu_mode else None,
tools=tool_objects,
max_tool_calls=args.max_tool_calls,
prompt_queue=param_prompt_Q,
prompt_queue=prompt_Q,
results_queue=inference_results_Q,
eval_results_queue=evaluation_inference_results_Q,
actor_manager=actor_manager,
@@ -2396,7 +2428,7 @@ def add_prompt_to_generator(
epoch_number: int,
training_step: int,
pending_queries_map: PendingQueriesMap,
param_prompt_Q: ray_queue.Queue,
prompt_Q: ray_queue.Queue,
generation_config,
is_eval: bool,
) -> None:
@@ -2407,7 +2439,7 @@
raw_query = example[RAW_PROMPT_KEY]
pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query)

param_prompt_Q.put(
prompt_Q.put(
PromptRequest(
prompt=query,
generation_config=generation_config,
@@ -2935,7 +2967,7 @@ def run_training(
stop_event,
executor,
inference_results_Q,
param_prompt_Q,
prompt_Q,
evaluation_inference_results_Q,
packed_sequences_Q,
pending_queries_map,
@@ -2972,7 +3004,7 @@ def run_training(
data_preparation_thread,
reward_fn,
inference_results_Q,
param_prompt_Q,
prompt_Q,
packed_sequences_Q,
pending_queries_map,
args,
@@ -3003,7 +3035,7 @@ def health_check_fn():
iter_dataloader.epoch_number,
resume_training_step,
pending_queries_map,
param_prompt_Q,
prompt_Q,
generation_configs["train"],
is_eval=False,
)
@@ -3015,6 +3047,7 @@

training_start_time = time.perf_counter() # Track overall training start time
for training_step in range(resume_training_step, args.num_training_steps + 1):
actor_manager.set_current_training_step.remote(training_step)
start_time = time.perf_counter()

if (
@@ -3056,7 +3089,7 @@ def health_check_fn():
iter_dataloader.epoch_number,
training_step,
eval_pending_queries_map,
param_prompt_Q,
prompt_Q,
generation_configs["eval"],
is_eval=True,
)
@@ -3178,7 +3211,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
# all prompts from async_steps + 1 training steps
queue_size = (args.async_steps + 1) * args.num_unique_prompts_rollout
inference_results_Q = ray_queue.Queue(maxsize=queue_size)
param_prompt_Q = ray_queue.Queue(maxsize=queue_size)
prompt_Q = ray_queue.Queue(maxsize=queue_size)
# We don't care if we ever hit the max, so we let the queue be unbounded.
evaluation_inference_results_Q = ray_queue.Queue()

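As a quick sanity check on the sizing comment above, with illustrative values (not defaults from this repo): args.async_steps = 4 and args.num_unique_prompts_rollout = 32 give queue_size = (4 + 1) * 32 = 160, i.e., enough room for the prompts of five training steps to be in flight at once.
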
@@ -3191,7 +3224,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
wandb_url,
tokenizer,
inference_results_Q,
param_prompt_Q,
prompt_Q,
evaluation_inference_results_Q,
)
)
@@ -3249,7 +3282,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
stop_event,
executor,
inference_results_Q,
param_prompt_Q,
prompt_Q,
evaluation_inference_results_Q,
packed_sequences_Q,
pending_queries_map,
@@ -3266,7 +3299,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
raise
finally:
cleanup_training_resources(
stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager
stop_event, executor, [inference_results_Q, prompt_Q, evaluation_inference_results_Q], actor_manager
)

# Ai2 logic: we use /output to store the artifacts of the job, so we
2 changes: 2 additions & 0 deletions open_instruct/queue_types.py
@@ -37,6 +37,8 @@ class GenerationResult:
token_statistics: TokenStatistics | None = None
start_time: float | None = None
logprobs: list[list[float]] | None = None
queued_training_step: int | None = None
generation_started_training_step: int | None = None


@dataclass
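
To make the semantics of the two new fields concrete, here is a small worked example of the staleness deltas that grpo_fast.py derives from them downstream (the step numbers are made up):

# Hypothetical step values observed for a single prompt; field names match this PR.
queued_training_step = 3              # learner step when the prompt was put on prompt_Q
generation_started_training_step = 5  # learner step when a vLLM actor dequeued it
consumed_step = 7                     # learner step whose batch consumed the result

queue_to_generation = generation_started_training_step - queued_training_step  # 5 - 3 = 2 steps queued
generation_to_consume = consumed_step - generation_started_training_step       # 7 - 5 = 2 steps generating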
6 changes: 6 additions & 0 deletions open_instruct/test_vllm_utils.py
@@ -72,6 +72,7 @@ def create_mock_logprobs(token_ids):
"dataset_index": 43039,
"epoch_number": 0,
"training_step": 1,
"generation_started_training_step": 2,
"prompt_token_ids": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"start_time": 1000.0,
}
@@ -107,6 +108,8 @@ def create_mock_logprobs(token_ids):
self.assertEqual(result.request_info.tool_outputs, ["result1", "result2"])
self.assertEqual(result.request_info.tool_runtimes, [0.5, 0.3])
self.assertEqual(result.request_info.tool_calleds, [True, True])
self.assertEqual(result.queued_training_step, 1)
self.assertEqual(result.generation_started_training_step, 2)

def test_process_outputs_without_tools(self):
"""Test that process_completed_request correctly handles outputs without tool attributes."""
@@ -143,6 +146,7 @@ def create_mock_logprobs(token_ids):
"dataset_index": 200,
"epoch_number": 0,
"training_step": 2,
"generation_started_training_step": 3,
"prompt_token_ids": [1, 2, 3, 4, 5],
"start_time": 2000.0,
}
@@ -181,6 +185,8 @@ def create_mock_logprobs(token_ids):
self.assertEqual(result.request_info.tool_outputs, ["", ""])
self.assertEqual(result.request_info.tool_runtimes, [0.0, 0.0])
self.assertEqual(result.request_info.tool_calleds, [False, False])
self.assertEqual(result.queued_training_step, 2)
self.assertEqual(result.generation_started_training_step, 3)


if __name__ == "__main__":
18 changes: 16 additions & 2 deletions open_instruct/vllm_utils.py
@@ -383,6 +383,8 @@ def process_completed_request(request_id, outs, current_time, tools, request_met
),
start_time=metadata["start_time"],
logprobs=logprobs,
queued_training_step=metadata.get("training_step"),
Collaborator: metadata["training_step"]?

generation_started_training_step=metadata.get("generation_started_training_step"),
)
return result, metadata["is_eval"]

@@ -470,10 +472,13 @@ def _prefetch_worker(actor: "LLMRayActor") -> None:
continue

request = actor.prompt_queue.get()
add_request(actor, request)
generation_started_training_step = actor._get_current_training_step()
Collaborator: just inline this if it's called once? Also don't wrap it in a try/except; let's make the actor manager have a default value.

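In suggested-change form, the inlined variant the reviewer is asking for might read as follows (a sketch; it assumes actor.actor_manager is the handle stored in _init_queues and relies on ActorManager's default step of 0 instead of a try/except fallback):

-generation_started_training_step = actor._get_current_training_step()
-add_request(actor, request, generation_started_training_step)
+add_request(actor, request, ray.get(actor.actor_manager.get_current_training_step.remote()))
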
add_request(actor, request, generation_started_training_step)


def add_request(actor: "LLMRayActor", request: PromptRequest) -> None:
def add_request(
actor: "LLMRayActor", request: PromptRequest, generation_started_training_step: int | None = None
Collaborator: Can we make this not be done?

) -> None:
request_id = make_request_id(request)

sampling_params = request.generation_config.clone()
Expand All @@ -488,6 +493,7 @@ def add_request(actor: "LLMRayActor", request: PromptRequest) -> None:
"original_sampling_params": request.generation_config,
"prompt_token_ids": list(request.prompt),
"start_time": time.perf_counter(),
"generation_started_training_step": generation_started_training_step,
}

tokens_prompt = vllm.TokensPrompt(prompt_token_ids=request.prompt, cache_salt=request_id)
@@ -550,6 +556,14 @@ def _init_queues(self, prompt_queue, results_queue, eval_results_queue, actor_ma
self._last_should_stop_update = float("-inf")
self._should_stop_value = False

def _get_current_training_step(self) -> int | None:
"""Fetch the learner's current training step for staleness accounting."""
try:
return ray.get(self.actor_manager.get_current_training_step.remote())
except Exception as exc: # pragma: no cover - log and fall back gracefully
logger.warning(f"Failed to fetch current training step from ActorManager: {exc}")
return None

def _init_executor(self) -> None:
max_workers = NUM_PREFETCH_WORKERS + (NUM_TOOL_WORKERS if self.tools else 0)
self.executor = futures.ThreadPoolExecutor(max_workers=max_workers)