@@ -2320,6 +2320,7 @@ def weight_sync_thread(
     policy_group: ModelGroup,
     actor_manager: ActorManager,
     weight_sync_metrics_Q: Queue,
+    params_lock: threading.Lock,
     resume_training_step: int = 1,
 ):
     """Thread function that handles weight sync operations and actor manager coordination."""
@@ -2338,24 +2339,27 @@ def weight_sync_thread(
         with Timer("[Weight Sync]") as timer:
             logger.debug("[Weight Sync Thread] Starting weight sync")

-            # Set actors to stop
-            ray.get(actor_manager.set_should_stop.remote(True))
-            logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
+            with params_lock:
+                # Set actors to stop
+                ray.get(actor_manager.set_should_stop.remote(True))
+                logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")

-            # Broadcast weights to vLLM engines
-            # First get the futures
-            weight_broadcast_futures: list[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models]
+                # Broadcast weights to vLLM engines
+                # First get the futures
+                weight_broadcast_futures: list[ray.ObjectRef] = [
+                    m.broadcast_to_vllm.remote() for m in policy_group.models
+                ]

-            # Wait for all weight updates to complete and collect individual timings
-            _, actor_sync_times = ray_get_with_progress(
-                weight_broadcast_futures,
-                desc="[Weight Sync Thread] Waiting for weight updates to complete",
-                enable=args.verbose,
-            )
+                # Wait for all weight updates to complete and collect individual timings
+                _, actor_sync_times = ray_get_with_progress(
+                    weight_broadcast_futures,
+                    desc="[Weight Sync Thread] Waiting for weight updates to complete",
+                    enable=args.verbose,
+                )

-            # Allow actors to resume
-            ray.get(actor_manager.set_should_stop.remote(False))
-            logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")
+                # Allow actors to resume
+                ray.get(actor_manager.set_should_stop.remote(False))
+                logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")

         # Calculate distribution statistics
         sync_time_stats = {
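For readers skimming the diff: the change above makes the weight-sync thread hold params_lock for the entire pause → broadcast → resume sequence, and the checkpoint-save path (further down in this PR) takes the same lock. A minimal, self-contained sketch of that mutual-exclusion pattern, with hypothetical stand-ins (pause_actors, broadcast_weights, resume_actors, save_checkpoint) for the Ray calls in the real code:

import threading
import time

params_lock = threading.Lock()  # shared by the weight-sync thread and the checkpoint path

def pause_actors() -> None:
    # Hypothetical stand-in for actor_manager.set_should_stop.remote(True)
    print("actors paused")

def broadcast_weights() -> None:
    # Hypothetical stand-in for broadcast_to_vllm on each policy model
    time.sleep(0.5)
    print("weights broadcast")

def resume_actors() -> None:
    # Hypothetical stand-in for actor_manager.set_should_stop.remote(False)
    print("actors resumed")

def weight_sync_once() -> None:
    # Hold the lock across the whole pause -> broadcast -> resume sequence so a
    # concurrent checkpoint save never observes half-updated parameters.
    with params_lock:
        pause_actors()
        broadcast_weights()
        resume_actors()

def save_checkpoint() -> None:
    # The checkpoint path takes the same lock, so it waits out any in-flight sync.
    with params_lock:
        print("checkpoint saved")

if __name__ == "__main__":
    t = threading.Thread(target=weight_sync_once)
    t.start()
    time.sleep(0.1)    # let the sync thread grab the lock first
    save_checkpoint()  # blocks until weight_sync_once releases params_lock
    t.join()

Because both sides serialize on the same Lock object, a checkpoint can never capture weights midway through a broadcast, and a broadcast never starts while a checkpoint is being written.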
@@ -2812,6 +2816,7 @@ def run_training(

     logger.info("======== ✅ weight sync thread starts =========")
     weight_sync_trigger_event = threading.Event()
+    params_lock = threading.Lock()
     weight_sync_thread_future = executor.submit(
         weight_sync_thread,
         args,
@@ -2820,6 +2825,7 @@
         policy_group,
         actor_manager,
         weight_sync_metrics_Q,
+        params_lock,
         resume_training_step,
     )

@@ -2994,17 +3000,20 @@ def health_check_fn():
                 max_retries = 3
                 for attempt in range(max_retries):
                     try:
-                        ray_get_with_progress(
-                            [
-                                policy_group.models[i].save_checkpoint_state.remote(
-                                    args.checkpoint_state_dir, client_state
-                                )
-                                for i in range(args.world_size)
-                            ],
-                            desc=f"Saving checkpoint state at step {training_step}",
-                            timeout=600,
-                        )
-                        logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+                        with params_lock:
+                            ray_get_with_progress(
+                                [
+                                    policy_group.models[i].save_checkpoint_state.remote(
+                                        args.checkpoint_state_dir, client_state
+                                    )
+                                    for i in range(args.world_size)
+                                ],
+                                desc=f"Saving checkpoint state at step {training_step}",
+                                timeout=600,
+                            )
+                            logger.info(
+                                f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}"
+                            )
                         break
                     except Exception as e:
                         if attempt < max_retries - 1:
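The lock only works because both paths share the same object: it is created once in run_training, handed to the background thread through executor.submit, and acquired here inside the retry loop, so an in-flight weight sync simply delays the save rather than racing it. A rough sketch of that wiring, under the assumption that the hypothetical stand-ins sync_weights_forever and write_checkpoint play the roles of weight_sync_thread and the Ray save_checkpoint_state calls:

import threading
from concurrent.futures import ThreadPoolExecutor

def sync_weights_forever(params_lock: threading.Lock, stop: threading.Event) -> None:
    # Hypothetical stand-in for weight_sync_thread: take the lock for each sync.
    while not stop.is_set():
        with params_lock:
            pass  # broadcast weights to the inference engines here
        stop.wait(0.1)

def write_checkpoint(step: int) -> None:
    # Hypothetical stand-in for the Ray save_checkpoint_state calls.
    print(f"checkpoint written at step {step}")

def main() -> None:
    params_lock = threading.Lock()  # created once; both paths must share this exact object
    stop = threading.Event()
    with ThreadPoolExecutor(max_workers=1) as executor:
        executor.submit(sync_weights_forever, params_lock, stop)
        max_retries = 3
        for attempt in range(max_retries):
            try:
                with params_lock:  # waits until any in-flight sync releases the lock
                    write_checkpoint(step=1)
                break
            except Exception:
                if attempt == max_retries - 1:
                    raise
        stop.set()  # let the background thread exit so the executor can shut down

if __name__ == "__main__":
    main()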