
Commit 2c1ce31

replenish prompts directly after accumulating (#1174)

* replenish prompts directly after accumulating
* add back replenish prompts
* fix dataset var
* fix tests
* another fix
* removed space
1 parent 5c00225 commit 2c1ce31

File tree

2 files changed: +149 -134 lines changed

open_instruct/grpo_fast.py

Lines changed: 83 additions & 74 deletions
@@ -66,6 +66,7 @@
 import torch.utils.data
 import vllm
 import wandb
+from datasets import Dataset
 from huggingface_hub import HfApi
 from peft import PeftModel, get_peft_model_state_dict
 from ray.util import queue as ray_queue
@@ -539,20 +540,6 @@ def __post_init__(self):
         )
 
 
-def next_batch(dataset_indices: list[int], dataset: datasets.Dataset) -> Batch:
-    """Extract next batch of data based on indices."""
-    data_next = dataset[dataset_indices]
-    return Batch(
-        queries=data_next[INPUT_IDS_PROMPT_KEY],
-        ground_truths=data_next[GROUND_TRUTHS_KEY],
-        datasets=data_next[VERIFIER_SOURCE_KEY],
-        raw_queries=data_next[RAW_PROMPT_KEY],
-        indices=dataset_indices,
-        decoded_responses=None,
-        scores=None,
-    )
-
-
 def masked_mean(
     values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
 ) -> torch.Tensor:
@@ -673,7 +660,8 @@ def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None):
     def __iter__(self) -> Iterator[list[int]]:
         return self
 
-    def __next__(self) -> list[int]:
+    def __next__(self) -> list[int] | int:
+        """Return a list of next indices or a single index if batch size is 1"""
         if self.index >= self.effective_size:
             self.index = 0
             self._update_effective_size()
@@ -682,6 +670,8 @@ def __next__(self) -> list[int]:
 
         end_index = self.index + self.batch_size
         batch = self.data[self.index : end_index].tolist()
+        if self.batch_size == 1:
+            batch = batch[0]
         self.index = end_index
 
         return batch
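
With batch_size set to 1, the iterator now yields bare integer indices instead of one-element lists, which is what lets callers index the dataset directly via train_dataset[next(iter_dataloader)]. Below is a minimal, self-contained sketch of that unwrapping behaviour; it is a simplified stand-in (shuffling and epoch handling compressed), not the repo's full ShufflingIterator:

import numpy as np
from collections.abc import Iterator

class TinyShufflingIterator:
    """Simplified stand-in: yields index lists, or a bare int when batch_size is 1."""

    def __init__(self, data: np.ndarray, batch_size: int, seed: int | None = None):
        self.data = data.copy()
        self.batch_size = batch_size
        self.index = 0
        self.epoch_number = 0
        self.rng = np.random.default_rng(seed)
        self.rng.shuffle(self.data)
        # Drop the ragged tail so every batch is full-sized.
        self.effective_size = len(self.data) - (len(self.data) % batch_size)

    def __iter__(self) -> Iterator[list[int] | int]:
        return self

    def __next__(self) -> list[int] | int:
        if self.index >= self.effective_size:
            # End of epoch: reshuffle and start over.
            self.index = 0
            self.epoch_number += 1
            self.rng.shuffle(self.data)
        end_index = self.index + self.batch_size
        batch = self.data[self.index : end_index].tolist()
        if self.batch_size == 1:
            batch = batch[0]  # unwrap the single index
        self.index = end_index
        return batch

it = TinyShufflingIterator(np.arange(10), batch_size=1, seed=0)
print(next(it), next(it))  # two bare ints, usable directly as dataset indices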
@@ -1672,8 +1662,12 @@ def accumulate_inference_batches(
     timeout: float | None = None,
     active_sampling: bool = False,
     filter_zero_std_samples: bool = False,
+    replenish_prompts: bool = False,
     no_resampling_pass_rate: float | None = None,
     iter_dataloader: ShufflingIterator | None = None,
+    prompt_dataset: Dataset = None,
+    param_prompt_Q: ray_queue.Queue | None = None,
+    training_step: int = None,
 ) -> tuple[GenerationResult, Batch, dict, BatchStatistics]:
     """Accumulate multiple inference results into a single training batch.
 
@@ -1686,9 +1680,11 @@ def accumulate_inference_batches(
         timeout: Optional timeout in seconds for queue get operations. If None, blocks indefinitely.
         active_sampling: Whether to continue sampling until we have sampled num_prompts prompts with non-zero std
         filter_zero_std_samples: Whether to filter samples with zero reward std
+        replenish_prompts: Add a prompt back onto the prompt_Q after receiving a finished result
         no_resampling_pass_rate: Optional rate at which to note samples solved at greater than this rate
             and exclude them from further sampling
         iter_dataloader: Optional, used for no_resampling_pass_rate
+        param_prompt_Q: Queue containing prompts to send to generator, used to replenish used prompts
 
     Raises:
         queue.Empty: If timeout is specified and no data is available within timeout.
@@ -1697,6 +1693,14 @@ def accumulate_inference_batches(
         Tuple of (combined_result, Batch with queries, ground_truths, datasets, prompt_lengths, response_lengths)
         or (ShutdownSentinel, None, None, None) if shutdown signal received
     """
+    if no_resampling_pass_rate is not None:
+        assert iter_dataloader is not None, "no_resampling requires the iter_dataloader passed"
+
+    if replenish_prompts:
+        assert param_prompt_Q is not None and iter_dataloader is not None and prompt_dataset is not None, (
+            "replenish_prompts requires param_prompt_Q and iter_dataloader and prompt_dataset"
+        )
+
     results = []
     all_queries = []
     all_ground_truths = []
@@ -1731,7 +1735,21 @@ def accumulate_inference_batches(
                     f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}"
                 )
 
-        query, ground_truth, dataset, raw_query = pending_queries_map.pop(result.dataset_index)
+        query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index)
+
+        # Replenish generation queue with new prompt
+        if replenish_prompts:
+            dataset_index = next(iter_dataloader)
+            add_prompt_to_generator(
+                prompt_dataset[dataset_index],
+                dataset_index,
+                iter_dataloader.epoch_number,
+                training_step,
+                pending_queries_map,
+                param_prompt_Q,
+                generation_config,
+                is_eval=False,
+            )
 
         # TODO(finbarrtimbers): Move this to LLMRayActor.
         for i in range(len(result.finish_reasons)):
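
The replenish branch above is the heart of the change: each time a finished result is pulled off the inference queue and its entry popped from pending_queries_map, one fresh index is drawn from iter_dataloader and a new prompt is pushed onto param_prompt_Q, so prompts filtered out during active sampling no longer drain the generator. A rough, self-contained sketch of that consume-one/replenish-one loop using plain standard-library queues (hypothetical names, not the repo's Ray queues or actors):

import queue
import random

def accumulate(results_q: queue.Queue, prompt_q: queue.Queue, dataset: list[str],
               num_needed: int, replenish_prompts: bool = True) -> list[tuple[str, float]]:
    """Consume finished results; for every result consumed, enqueue one fresh prompt."""
    batch = []
    while len(batch) < num_needed:
        result = results_q.get()  # (prompt, reward) from the generator
        if replenish_prompts:
            # Replenish immediately, even if this result gets filtered below,
            # so the generator always has work queued.
            prompt_q.put(random.choice(dataset))
        if result[1] == 0.0:      # stand-in for zero-reward-std filtering
            continue
        batch.append(result)
    return batch

results_q, prompt_q = queue.Queue(), queue.Queue()
for r in [("p1", 1.0), ("p2", 0.0), ("p3", 0.5)]:
    results_q.put(r)
print(accumulate(results_q, prompt_q, ["a", "b", "c"], num_needed=2))
print("prompts replenished:", prompt_q.qsize())  # 3: one per result consumed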
@@ -1745,7 +1763,7 @@ def accumulate_inference_batches(
         # TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method.
         k_queries = repeat_each([query], generation_config.n)
         k_ground_truths = repeat_each([ground_truth], generation_config.n)
-        k_datasets = repeat_each([dataset], generation_config.n)
+        k_datasets = repeat_each([dataset_name], generation_config.n)
         k_raw_queries = repeat_each([raw_query], generation_config.n)
 
         scores, reward_metrics = asyncio.run(
@@ -1916,6 +1934,7 @@ def accumulate_inference_batches(
 def data_preparation_thread(
     reward_fn: Callable,
     inference_results_Q: ray_queue.Queue,  # Ray queue
+    param_prompt_Q: ray_queue.Queue,
     packed_sequences_Q: Queue,
     pending_queries_map: dict,
     args: Args,
@@ -1924,6 +1943,7 @@ def data_preparation_thread(
     generation_config,
     resume_training_step: int,
     iter_dataloader: ShufflingIterator,
+    train_dataset: Dataset,
     actor_manager=None,
     model_dims: utils.ModelDims = None,
 ):
@@ -1942,8 +1962,12 @@ def data_preparation_thread(
             actor_manager=actor_manager,
             active_sampling=args.active_sampling,
             filter_zero_std_samples=args.filter_zero_std_samples,
+            replenish_prompts=True,
             no_resampling_pass_rate=args.no_resampling_pass_rate,
             iter_dataloader=iter_dataloader,
+            prompt_dataset=train_dataset,
+            param_prompt_Q=param_prompt_Q,
+            training_step=training_step,
         )
         if isinstance(result, ShutdownSentinel):
             logger.info("[Data Preparation Thread] Received shutdown sentinel, exiting")
@@ -2366,8 +2390,9 @@ def create_generation_configs(args: Args):
     return {"train": generation_config, "eval": eval_generation_config}
 
 
-def split_and_insert_batch(
-    batch: Batch,
+def add_prompt_to_generator(
+    example: dict[str, Any],
+    example_index: int,
     epoch_number: int,
     training_step: int,
     pending_queries_map: PendingQueriesMap,
@@ -2376,20 +2401,22 @@ def split_and_insert_batch(
     is_eval: bool,
 ) -> None:
     """Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
-    for idx, query, ground_truth, dataset, raw_query in zip(
-        batch.indices, batch.queries, batch.ground_truths, batch.datasets, batch.raw_queries
-    ):
-        pending_queries_map.insert(idx, query, ground_truth, dataset, raw_query)
-        param_prompt_Q.put(
-            PromptRequest(
-                prompt=query,
-                generation_config=generation_config,
-                epoch_number=epoch_number,
-                training_step=training_step,
-                dataset_index=idx,
-                is_eval=is_eval,
-            )
+    query = example[INPUT_IDS_PROMPT_KEY]
+    ground_truth = example[GROUND_TRUTHS_KEY]
+    dataset_name = example[VERIFIER_SOURCE_KEY]
+    raw_query = example[RAW_PROMPT_KEY]
+    pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query)
+
+    param_prompt_Q.put(
+        PromptRequest(
+            prompt=query,
+            generation_config=generation_config,
+            epoch_number=epoch_number,
+            training_step=training_step,
+            dataset_index=example_index,
+            is_eval=is_eval,
         )
+    )
 
 
 def load_data_from_packing_thread(
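
add_prompt_to_generator replaces the batch-oriented split_and_insert_batch: it takes one dataset example (a dict of columns) and its index, records it in the pending map, and enqueues a PromptRequest, so the same helper serves the prefill loop, the per-result replenish path, and the eval loop. A rough sketch of that per-example flow with placeholder key names and a plain dict/list standing in for PendingQueriesMap and the Ray queue (the real code indexes the example with INPUT_IDS_PROMPT_KEY, GROUND_TRUTHS_KEY, VERIFIER_SOURCE_KEY and RAW_PROMPT_KEY):

from dataclasses import dataclass

@dataclass
class FakePromptRequest:
    """Placeholder for the repo's PromptRequest payload."""
    prompt: list[int]
    dataset_index: int
    is_eval: bool

def add_prompt(example: dict, example_index: int,
               pending: dict, prompt_q: list, is_eval: bool = False) -> None:
    # Unpack the columns of a single dataset row (placeholder key names).
    query = example["input_ids_prompt"]
    ground_truth = example["ground_truth"]
    dataset_name = example["dataset_source"]
    raw_query = example["raw_prompt"]
    # Remember the row so the finished result can be matched back up by index...
    pending[example_index] = (query, ground_truth, dataset_name, raw_query)
    # ...and hand the tokenized prompt to the generator.
    prompt_q.append(FakePromptRequest(prompt=query, dataset_index=example_index, is_eval=is_eval))

pending, prompt_q = {}, []
row = {"input_ids_prompt": [1, 2, 3], "ground_truth": "42",
       "dataset_source": "gsm8k", "raw_prompt": "What is 6 * 7?"}
add_prompt(row, example_index=0, pending=pending, prompt_q=prompt_q)
print(pending[0][1], prompt_q[0].dataset_index)  # 42 0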
@@ -2641,8 +2668,7 @@ def maybe_evaluate(
         timeout=timeout,
         active_sampling=False,
         filter_zero_std_samples=False,
-        no_resampling_pass_rate=None,
-        iter_dataloader=None,
+        replenish_prompts=False,
     )
 
     logger.info("[Main Thread] 📊 Evaluation responses received")
@@ -2896,7 +2922,7 @@ def run_training(
     args,
     tokenizer,
     train_dataset,
-    eval_batch,
+    eval_dataset,
     policy_group,
     vllm_engines,
     generation_configs,
@@ -2946,6 +2972,7 @@ def run_training(
         data_preparation_thread,
         reward_fn,
         inference_results_Q,
+        param_prompt_Q,
         packed_sequences_Q,
         pending_queries_map,
         args,
@@ -2954,6 +2981,7 @@ def run_training(
         generation_configs["train"],
         resume_training_step,
         iter_dataloader,
+        train_dataset,
         actor_manager,
         model_dims,
     )
@@ -2967,11 +2995,11 @@ def health_check_fn():
     )
 
     # Send initial data to ensure we have a N-step offset.
-    for _ in range(args.async_steps):
-        dataset_indices = next(iter_dataloader)
-        batch = next_batch(dataset_indices, train_dataset)
-        split_and_insert_batch(
-            batch,
+    for _ in range(args.async_steps * args.num_unique_prompts_rollout):
+        dataset_index = next(iter_dataloader)
+        add_prompt_to_generator(
+            train_dataset[dataset_index],
+            dataset_index,
             iter_dataloader.epoch_number,
             resume_training_step,
             pending_queries_map,
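
Because the iterator now yields one index at a time, the prefill loop enqueues args.async_steps * args.num_unique_prompts_rollout individual prompts up front, the same volume the old code produced as args.async_steps whole batches; after that, every prompt consumed inside accumulate_inference_batches is replaced one-for-one, so the number of prompts in flight stays roughly constant. A tiny back-of-the-envelope check with made-up values:

# Hypothetical settings, purely to illustrate the prefill volume and steady state.
async_steps = 4
num_unique_prompts_rollout = 64

outstanding = async_steps * num_unique_prompts_rollout  # prompts enqueued by the prefill loop
print("prefilled prompts:", outstanding)                # 256

# During training, each consumed result triggers one replenished prompt,
# so the count in flight stays flat across a step's worth of results.
for _ in range(num_unique_prompts_rollout):
    outstanding -= 1  # result consumed in accumulate_inference_batches
    outstanding += 1  # prompt replenished via add_prompt_to_generator
print("outstanding after one step:", outstanding)       # still 256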
@@ -2985,7 +3013,6 @@ def health_check_fn():
     else:
         num_total_tokens = 0
 
-    num_prompts_to_refill = 0
     training_start_time = time.perf_counter()  # Track overall training start time
     for training_step in range(resume_training_step, args.num_training_steps + 1):
         start_time = time.perf_counter()
@@ -3017,35 +3044,22 @@ def health_check_fn():
             num_filtered_prompts,
         ) = load_data_from_packing_thread(packed_sequences_Q, num_total_tokens, stop_event, health_check_fn)
 
-        num_prompts_to_refill += args.num_unique_prompts_rollout + num_filtered_prompts
-
-        while num_prompts_to_refill >= args.num_unique_prompts_rollout:
-            batch = next_batch(next(iter_dataloader), train_dataset)
-            split_and_insert_batch(
-                batch,
-                iter_dataloader.epoch_number,
-                training_step,
-                pending_queries_map,
-                param_prompt_Q,
-                generation_configs["train"],
-                is_eval=False,
-            )
-            num_prompts_to_refill -= args.num_unique_prompts_rollout
-
         if (
             training_step % args.local_eval_every == 0
-            and eval_batch is not None
+            and eval_dataset is not None
             and (args.eval_on_step_0 or training_step > 1)
         ):
-            split_and_insert_batch(
-                eval_batch,
-                iter_dataloader.epoch_number,
-                training_step,
-                eval_pending_queries_map,
-                param_prompt_Q,
-                generation_configs["eval"],
-                is_eval=True,
-            )
+            for eval_index, eval_example in enumerate(eval_dataset):
+                add_prompt_to_generator(
+                    eval_example,
+                    eval_index,
+                    iter_dataloader.epoch_number,
+                    training_step,
+                    eval_pending_queries_map,
+                    param_prompt_Q,
+                    generation_configs["eval"],
+                    is_eval=True,
+                )
         if collated_data is None:
             continue
 
@@ -3122,7 +3136,7 @@ def health_check_fn():
             eval_pending_queries_map,
             generation_configs["eval"],
             generate_metrics_Q,
-            len(eval_batch.queries) if eval_batch else 0,
+            len(eval_dataset) if eval_dataset else 0,
             model_dims,
             actor_manager,
         )
@@ -3199,7 +3213,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
         logger.info(f"Restored episode count: {episode}")
 
     train_dataset_idxs = np.arange(len(train_dataset))
-    iter_dataloader = ShufflingIterator(train_dataset_idxs, args.num_unique_prompts_rollout, seed=args.seed)
+    iter_dataloader = ShufflingIterator(train_dataset_idxs, 1, seed=args.seed)
 
     if checkpoint_state and "shuffling_iterator_state" in checkpoint_state:
         iter_dataloader.set_state(checkpoint_state["shuffling_iterator_state"])
@@ -3212,11 +3226,6 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
     generate_metrics_Q = Queue(maxsize=args.async_steps)
     weight_sync_metrics_Q = Queue(maxsize=args.async_steps)
 
-    if eval_dataset is None:
-        eval_batch = None
-    else:
-        eval_dataset_indices = list(range(len(eval_dataset)))
-        eval_batch = next_batch(eval_dataset_indices, eval_dataset)
     reward_fn = make_reward_fn(args)
 
     stop_event = threading.Event()
@@ -3227,7 +3236,7 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
         args,
         tokenizer,
         train_dataset,
-        eval_batch,
+        eval_dataset,
         policy_group,
         vllm_engines,
         generation_configs,
