diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index cac12986e..382700d38 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -1514,44 +1514,36 @@ class PendingQueriesMap:
     """Thread-safe map for tracking pending queries with reference counting."""
 
     def __init__(self):
-        self._map = {}  # dataset_idx -> (query, ground_truth, dataset, count)
+        # dataset_idx -> [data, count]
+        self._map: dict[int, list[Any]] = {}
         self._lock = threading.Lock()
 
-    def insert(self, dataset_idx, query, ground_truth, dataset, raw_query):
+    def insert(self, dataset_idx: int, data: dict[str, Any]) -> None:
         """Insert or increment count for a dataset index."""
         with self._lock:
             if dataset_idx in self._map:
                 # Already exists - just increment count
-                existing_query, existing_ground_truth, existing_dataset, existing_raw_query, count = self._map[
-                    dataset_idx
-                ]
-                self._map[dataset_idx] = (
-                    existing_query,
-                    existing_ground_truth,
-                    existing_dataset,
-                    existing_raw_query,
-                    count + 1,
-                )
+                self._map[dataset_idx][1] += 1
             else:
                 # New entry - count starts at 1
-                self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, 1)
+                self._map[dataset_idx] = [data.copy(), 1]
 
-    def pop(self, dataset_idx):
+    def pop(self, dataset_idx: int) -> dict[str, Any]:
         """Retrieve data and decrement count. Removes entry when count reaches 0."""
         with self._lock:
             if dataset_idx not in self._map:
                 raise RuntimeError(f"Dataset index {dataset_idx} not found in pending_queries_map")
 
-            query, ground_truth, dataset, raw_query, count = self._map[dataset_idx]
+            data, count = self._map[dataset_idx]
 
             if count > 1:
                 # More results expected - just decrement
-                self._map[dataset_idx] = (query, ground_truth, dataset, raw_query, count - 1)
+                self._map[dataset_idx][1] -= 1
             else:
                 # Last result - remove entry
                 del self._map[dataset_idx]
 
-            return query, ground_truth, dataset, raw_query
+            return data.copy()
 
     def __len__(self):
         """Return the number of entries in the map."""
@@ -1730,7 +1722,7 @@ def accumulate_inference_batches(
                 f"Dataset index: {result.dataset_index}, Epoch: {result.epoch_number}"
             )
 
-        query, ground_truth, dataset_name, raw_query = pending_queries_map.pop(result.dataset_index)
+        pending_data = pending_queries_map.pop(result.dataset_index)
 
         # Replenish generation queue with new prompt
         if replenish_prompts:
@@ -1756,10 +1748,10 @@
 
         decoded_responses = tokenizer.batch_decode(result.responses, skip_special_tokens=True)
         # TODO(finbarrtimbers): Make PendingQueriesMap.pop return a Batch, and add a Batch.repeat method.
-        k_queries = repeat_each([query], generation_config.n)
-        k_ground_truths = repeat_each([ground_truth], generation_config.n)
-        k_datasets = repeat_each([dataset_name], generation_config.n)
-        k_raw_queries = repeat_each([raw_query], generation_config.n)
+        k_queries = repeat_each([pending_data["query"]], generation_config.n)
+        k_ground_truths = repeat_each([pending_data["ground_truth"]], generation_config.n)
+        k_datasets = repeat_each([pending_data["dataset"]], generation_config.n)
+        k_raw_queries = repeat_each([pending_data["raw_query"]], generation_config.n)
 
         scores, reward_metrics = asyncio.run(
             reward_fn(
@@ -1931,7 +1923,7 @@ def data_preparation_thread(
     inference_results_Q: ray_queue.Queue,  # Ray queue
     param_prompt_Q: ray_queue.Queue,
     packed_sequences_Q: Queue,
-    pending_queries_map: dict,
+    pending_queries_map: PendingQueriesMap,
     args: Args,
     tokenizer: PreTrainedTokenizer,
     num_training_steps: int,
@@ -2401,15 +2393,19 @@ def add_prompt_to_generator(
     is_eval: bool,
 ) -> None:
     """Split a batch into multiple inference batches and insert individual prompts into queues and mapping."""
-    query = example[INPUT_IDS_PROMPT_KEY]
-    ground_truth = example[GROUND_TRUTHS_KEY]
-    dataset_name = example[VERIFIER_SOURCE_KEY]
-    raw_query = example[RAW_PROMPT_KEY]
-    pending_queries_map.insert(example_index, query, ground_truth, dataset_name, raw_query)
+    pending_queries_map.insert(
+        example_index,
+        {
+            "query": example[INPUT_IDS_PROMPT_KEY],
+            "ground_truth": example[GROUND_TRUTHS_KEY],
+            "dataset": example[VERIFIER_SOURCE_KEY],
+            "raw_query": example[RAW_PROMPT_KEY],
+        },
+    )
 
     param_prompt_Q.put(
         PromptRequest(
-            prompt=query,
+            prompt=example[INPUT_IDS_PROMPT_KEY],
            generation_config=generation_config,
             epoch_number=epoch_number,
             training_step=training_step,
diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py
index 17268253f..1afee52e9 100644
--- a/open_instruct/test_grpo_fast.py
+++ b/open_instruct/test_grpo_fast.py
@@ -262,6 +262,14 @@ def setup_and_add_prompts_to_generator(
 
         return param_prompt_Q, inference_results_Q, pending_queries_map
 
+    def create_pending_data(self, query, ground_truth, dataset, raw_query):
+        return {
+            "query": query,
+            "ground_truth": ground_truth,
+            "dataset": dataset,
+            "raw_query": raw_query,
+        }
+
 
 class TestGrpoFastVLLM(TestGrpoFastBase):
     def test_vllm_queue_system_single_prompt(self):
@@ -379,13 +387,13 @@ def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int,
             dataset_index = result.dataset_index
 
             # Get query from pending_queries_map
-            q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
+            pending_data = pending_queries_map.pop(dataset_index)
 
             combined_responses.extend(result.responses)
-            combined_queries.append(q)
-            combined_raw_queries.append(raw_q)
-            combined_ground_truths.append(gt)
-            combined_datasets.append(d)
+            combined_queries.append(pending_data["query"])
+            combined_raw_queries.append(pending_data["raw_query"])
+            combined_ground_truths.append(pending_data["ground_truth"])
+            combined_datasets.append(pending_data["dataset"])
 
         combined_result = GenerationResult(
             responses=combined_responses,
@@ -452,11 +460,11 @@ def test_dataset_index_preservation_through_pipeline(self):
             result = inference_results_Q.get()
             dataset_index = result.dataset_index
 
-            q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
-            combined_queries.append(q)
-            combined_raw_queries.append(raw_q)
-            combined_ground_truths.append(gt)
-            combined_datasets.append(d)
+            pending_data = pending_queries_map.pop(dataset_index)
+            combined_queries.append(pending_data["query"])
+            combined_raw_queries.append(pending_data["raw_query"])
+            combined_ground_truths.append(pending_data["ground_truth"])
+            combined_datasets.append(pending_data["dataset"])
 
         # Verify results
         self.assertEqual(combined_queries, queries_next)
@@ -485,7 +493,10 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
         for idx, query, ground_truth, dataset, raw_query in zip(
             dataset_indices, queries_next, ground_truths_next, datasets_next, raw_queries_next
         ):
-            pending_queries_map.insert(idx, query, ground_truth, dataset, raw_query)
+            pending_queries_map.insert(
+                idx,
+                self.create_pending_data(query, ground_truth, dataset, raw_query),
+            )
 
         # Simulate vLLM processing with multiple samples
         batch_idx = 0
@@ -507,16 +518,16 @@ def test_multiple_samples_per_prompt(self, vllm_num_engines: int, num_samples_pe
             dataset_index = result.dataset_index
 
             # Pop the query data for this specific result - pop multiple times for multiple samples
-            q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
+            pending_data = pending_queries_map.pop(dataset_index)
             # Pop additional times to handle multiple samples per prompt
             for _ in range(num_samples_per_prompt - 1):
                 pending_queries_map.pop(dataset_index)
 
             combined_responses.extend(result.responses)
-            combined_queries.append(q)
-            combined_raw_queries.append(raw_q)
-            combined_ground_truths.append(gt)
-            combined_datasets.append(d)
+            combined_queries.append(pending_data["query"])
+            combined_raw_queries.append(pending_data["raw_query"])
+            combined_ground_truths.append(pending_data["ground_truth"])
+            combined_datasets.append(pending_data["dataset"])
 
         combined_result = GenerationResult(
             responses=combined_responses,
@@ -647,10 +658,12 @@ def add_and_remove_entries(thread_id):
             for i in range(start_idx, start_idx + entries_per_thread):
                 pending_queries_map.insert(
                     i,
-                    f"query_{thread_id}_{i}",
-                    f"truth_{thread_id}_{i}",
-                    f"dataset_{thread_id}_{i}",
-                    f"query_{thread_id}_{i}",
+                    self.create_pending_data(
+                        f"query_{thread_id}_{i}",
+                        f"truth_{thread_id}_{i}",
+                        f"dataset_{thread_id}_{i}",
+                        f"query_{thread_id}_{i}",
+                    ),
                 )
                 time.sleep(0.0001)
 
@@ -696,7 +709,10 @@ def test_accumulate_waits_for_all_engines(self):
 
         # Add entries to map
         for i in range(num_prompts):
-            pending_queries_map.insert(i, f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}")
+            pending_queries_map.insert(
+                i,
+                self.create_pending_data(f"q_{i}", f"t_{i}", f"d_{i}", f"q_{i}"),
+            )
 
         # Add results from only 3 engines (missing one)
         # With individual prompts, we add individual results
@@ -863,7 +879,10 @@ def test_streaming_accumulation_basic(self):
 
         # Insert data into pending_queries_map
         for i in range(num_prompts):
-            pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i])
+            pending_queries_map.insert(
+                i,
+                self.create_pending_data(queries[i], ground_truths[i], datasets[i], raw_queries[i]),
+            )
 
         # Create mock results - one per prompt
         for i in range(num_prompts):
@@ -882,8 +901,8 @@ def test_streaming_accumulation_basic(self):
 
             # Get query for this prompt
             dataset_index = result.dataset_index
-            q, gt, d, raw_q = pending_queries_map.pop(dataset_index)
-            queries_list.append((q, gt, d, raw_q))
+            pending_data = pending_queries_map.pop(dataset_index)
+            queries_list.append(pending_data)
 
         # Verify all results processed
         self.assertEqual(len(results_list), expected_results)
@@ -892,8 +911,7 @@ def test_streaming_accumulation_basic(self):
         # Combine in order
         combined_queries = []
         for i in range(num_prompts):
-            q, _, _, _ = queries_list[i]
-            combined_queries.append(q)
+            combined_queries.append(queries_list[i]["query"])
 
         # Verify order is preserved
         self.assertEqual(combined_queries, queries)
@@ -916,7 +934,12 @@ def test_streaming_with_multiple_samples(self):
 
         # Insert data with reference counting for multiple samples
         for i in range(num_prompts):
            for _ in range(num_samples):
-                pending_queries_map.insert(i, queries[i], ground_truths[i], datasets[i], raw_queries[i])
+                pending_queries_map.insert(
+                    i,
+                    self.create_pending_data(
+                        queries[i], ground_truths[i], datasets[i], raw_queries[i]
+                    ),
+                )
 
         # Create results - one per prompt with multiple samples
         for i in range(num_prompts):
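
Reviewer note (not part of the diff): a minimal usage sketch of the refactored PendingQueriesMap API introduced above, assuming the class is imported from open_instruct.grpo_fast as in this PR. The payload values (the token ids, "math_verifier", the question text) are made-up placeholders, not taken from the repo.

from open_instruct.grpo_fast import PendingQueriesMap

pending_queries_map = PendingQueriesMap()

# One reference per expected sample: inserting the same dataset index twice means
# two generation results are expected for that prompt (e.g. generation_config.n == 2).
data = {
    "query": [101, 2054, 2003],     # tokenized prompt (placeholder token ids)
    "ground_truth": "42",
    "dataset": "math_verifier",     # placeholder verifier source
    "raw_query": "What is 6 * 7?",
}
pending_queries_map.insert(0, data)
pending_queries_map.insert(0, data)

first = pending_queries_map.pop(0)    # count 2 -> 1, entry kept
second = pending_queries_map.pop(0)   # count 1 -> 0, entry removed
assert first == second == data        # pop returns a copy of the stored dict
assert len(pending_queries_map) == 0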