
Commit 3be6ae3

Added a lock
1 parent 8fcf9c6 commit 3be6ae3

File tree: 1 file changed (+33 -24 lines)


open_instruct/grpo_fast.py

Lines changed: 33 additions & 24 deletions
@@ -2465,6 +2465,7 @@ def weight_sync_thread(
     policy_group: ModelGroup,
     actor_manager: ActorManager,
     weight_sync_metrics_Q: Queue,
+    params_lock: threading.Lock,
     resume_training_step: int = 1,
 ):
     """Thread function that handles weight sync operations and actor manager coordination."""
@@ -2484,23 +2485,26 @@ def weight_sync_thread(
         logger.debug("[Weight Sync Thread] Starting weight sync")

         # Set actors to stop
-        ray.get(actor_manager.set_should_stop.remote(True))
-        logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
-
-        # Broadcast weights to vLLM engines
-        # First get the futures
-        weight_broadcast_futures: list[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models]
-
-        # Wait for all weight updates to complete and collect individual timings
-        _, actor_sync_times = ray_get_with_progress(
-            weight_broadcast_futures,
-            desc="[Weight Sync Thread] Waiting for weight updates to complete",
-            enable=args.verbose,
-        )
+        with params_lock:
+            ray.get(actor_manager.set_should_stop.remote(True))
+            logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
+
+            # Broadcast weights to vLLM engines
+            # First get the futures
+            weight_broadcast_futures: list[ray.ObjectRef] = [
+                m.broadcast_to_vllm.remote() for m in policy_group.models
+            ]
+
+            # Wait for all weight updates to complete and collect individual timings
+            _, actor_sync_times = ray_get_with_progress(
+                weight_broadcast_futures,
+                desc="[Weight Sync Thread] Waiting for weight updates to complete",
+                enable=args.verbose,
+            )

-        # Allow actors to resume
-        ray.get(actor_manager.set_should_stop.remote(False))
-        logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")
+            # Allow actors to resume
+            ray.get(actor_manager.set_should_stop.remote(False))
+            logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")

         # Calculate distribution statistics
         sync_time_stats = {
@@ -2946,6 +2950,7 @@ def run_training(
     model_dims: utils.ModelDims,
     checkpoint_state=None,
 ):
+    params_lock = threading.Lock()
     if resume_training_step > 1:
         logger.info(f"[Main Thread] Resuming training from step {resume_training_step}")

@@ -2959,6 +2964,7 @@ def run_training(
         policy_group,
         actor_manager,
         weight_sync_metrics_Q,
+        params_lock,
         resume_training_step,
     )

@@ -3117,14 +3123,17 @@ def health_check_fn():
         if iter_dataloader is not None:
             client_state["shuffling_iterator_state"] = iter_dataloader.get_state()

-        ray_get_with_progress(
-            [
-                policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
-                for i in range(args.world_size)
-            ],
-            desc=f"Saving checkpoint state at step {training_step}",
-        )
-        logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+        with params_lock:
+            ray_get_with_progress(
+                [
+                    policy_group.models[i].save_checkpoint_state.remote(
+                        args.checkpoint_state_dir, client_state
+                    )
+                    for i in range(args.world_size)
+                ],
+                desc=f"Saving checkpoint state at step {training_step}",
+            )
+            logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")

         maybe_evaluate(
             args,
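
For context, below is a minimal sketch of the pattern this commit introduces: one threading.Lock shared between a background weight-sync thread and the main training loop, so a weight broadcast and a checkpoint save of the same parameters never overlap. This is not code from open_instruct/grpo_fast.py; broadcast_weights, save_checkpoint, and weight_sync_loop are hypothetical stand-ins for the Ray calls in the real file.

# Sketch only: illustrates the shared-lock pattern from this commit with
# placeholder functions instead of the actual Ray/vLLM calls.
import threading
import time

params_lock = threading.Lock()


def broadcast_weights() -> None:
    # Stand-in for broadcasting updated policy weights to the inference engines.
    time.sleep(0.1)


def save_checkpoint(step: int) -> None:
    # Stand-in for the save_checkpoint_state calls in run_training.
    time.sleep(0.1)


def weight_sync_loop(stop: threading.Event) -> None:
    # Background thread: weight syncs run only while holding params_lock,
    # so they cannot interleave with a checkpoint save of the same parameters.
    while not stop.is_set():
        with params_lock:
            broadcast_weights()
        time.sleep(0.05)


def main() -> None:
    stop = threading.Event()
    syncer = threading.Thread(target=weight_sync_loop, args=(stop,), daemon=True)
    syncer.start()

    for step in range(1, 4):
        # Main thread: take the same lock before saving checkpoint state,
        # mirroring the `with params_lock:` block added in this commit.
        with params_lock:
            save_checkpoint(step)

    stop.set()
    syncer.join()


if __name__ == "__main__":
    main()

With this arrangement a checkpoint never captures parameters mid-broadcast, at the cost of one thread briefly blocking while the other holds the lock.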
