Skip to content

Commit 413ef48

Browse files
rebased onto pad-out-32b
1 parent: 6e3b6c8 · commit: 413ef48

File tree

1 file changed

+35
-26
lines changed

1 file changed

+35
-26
lines changed

open_instruct/grpo_fast.py

Lines changed: 35 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -2320,6 +2320,7 @@ def weight_sync_thread(
23202320
policy_group: ModelGroup,
23212321
actor_manager: ActorManager,
23222322
weight_sync_metrics_Q: Queue,
2323+
params_lock: threading.Lock,
23232324
resume_training_step: int = 1,
23242325
):
23252326
"""Thread function that handles weight sync operations and actor manager coordination."""
@@ -2338,24 +2339,27 @@ def weight_sync_thread(
23382339
with Timer("[Weight Sync]") as timer:
23392340
logger.debug("[Weight Sync Thread] Starting weight sync")
23402341

2341-
# Set actors to stop
2342-
ray.get(actor_manager.set_should_stop.remote(True))
2343-
logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
2342+
with params_lock:
2343+
# Set actors to stop
2344+
ray.get(actor_manager.set_should_stop.remote(True))
2345+
logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
23442346

2345-
# Broadcast weights to vLLM engines
2346-
# First get the futures
2347-
weight_broadcast_futures: list[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models]
2347+
# Broadcast weights to vLLM engines
2348+
# First get the futures
2349+
weight_broadcast_futures: list[ray.ObjectRef] = [
2350+
m.broadcast_to_vllm.remote() for m in policy_group.models
2351+
]
23482352

2349-
# Wait for all weight updates to complete and collect individual timings
2350-
_, actor_sync_times = ray_get_with_progress(
2351-
weight_broadcast_futures,
2352-
desc="[Weight Sync Thread] Waiting for weight updates to complete",
2353-
enable=args.verbose,
2354-
)
2353+
# Wait for all weight updates to complete and collect individual timings
2354+
_, actor_sync_times = ray_get_with_progress(
2355+
weight_broadcast_futures,
2356+
desc="[Weight Sync Thread] Waiting for weight updates to complete",
2357+
enable=args.verbose,
2358+
)
23552359

2356-
# Allow actors to resume
2357-
ray.get(actor_manager.set_should_stop.remote(False))
2358-
logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")
2360+
# Allow actors to resume
2361+
ray.get(actor_manager.set_should_stop.remote(False))
2362+
logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")
23592363

23602364
# Calculate distribution statistics
23612365
sync_time_stats = {
@@ -2812,6 +2816,7 @@ def run_training(
28122816

28132817
logger.info("======== ✅ weight sync thread starts =========")
28142818
weight_sync_trigger_event = threading.Event()
2819+
params_lock = threading.Lock()
28152820
weight_sync_thread_future = executor.submit(
28162821
weight_sync_thread,
28172822
args,
@@ -2820,6 +2825,7 @@ def run_training(
28202825
policy_group,
28212826
actor_manager,
28222827
weight_sync_metrics_Q,
2828+
params_lock,
28232829
resume_training_step,
28242830
)
28252831

@@ -2994,17 +3000,20 @@ def health_check_fn():
29943000
max_retries = 3
29953001
for attempt in range(max_retries):
29963002
try:
2997-
ray_get_with_progress(
2998-
[
2999-
policy_group.models[i].save_checkpoint_state.remote(
3000-
args.checkpoint_state_dir, client_state
3001-
)
3002-
for i in range(args.world_size)
3003-
],
3004-
desc=f"Saving checkpoint state at step {training_step}",
3005-
timeout=600,
3006-
)
3007-
logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
3003+
with params_lock:
3004+
ray_get_with_progress(
3005+
[
3006+
policy_group.models[i].save_checkpoint_state.remote(
3007+
args.checkpoint_state_dir, client_state
3008+
)
3009+
for i in range(args.world_size)
3010+
],
3011+
desc=f"Saving checkpoint state at step {training_step}",
3012+
timeout=600,
3013+
)
3014+
logger.info(
3015+
f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}"
3016+
)
30083017
break
30093018
except Exception as e:
30103019
if attempt < max_retries - 1:

0 commit comments

Comments (0)