@@ -2465,6 +2465,7 @@ def weight_sync_thread(
     policy_group: ModelGroup,
     actor_manager: ActorManager,
     weight_sync_metrics_Q: Queue,
+    params_lock: threading.Lock,
     resume_training_step: int = 1,
 ):
     """Thread function that handles weight sync operations and actor manager coordination."""
@@ -2484,23 +2485,26 @@ def weight_sync_thread(
         logger.debug("[Weight Sync Thread] Starting weight sync")

         # Set actors to stop
-        ray.get(actor_manager.set_should_stop.remote(True))
-        logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
-
-        # Broadcast weights to vLLM engines
-        # First get the futures
-        weight_broadcast_futures: list[ray.ObjectRef] = [m.broadcast_to_vllm.remote() for m in policy_group.models]
-
-        # Wait for all weight updates to complete and collect individual timings
-        _, actor_sync_times = ray_get_with_progress(
-            weight_broadcast_futures,
-            desc="[Weight Sync Thread] Waiting for weight updates to complete",
-            enable=args.verbose,
-        )
+        with params_lock:
+            ray.get(actor_manager.set_should_stop.remote(True))
+            logger.debug("[Weight Sync Thread] Set should_stop to True for weight sync")
+
+            # Broadcast weights to vLLM engines
+            # First get the futures
+            weight_broadcast_futures: list[ray.ObjectRef] = [
+                m.broadcast_to_vllm.remote() for m in policy_group.models
+            ]
+
+            # Wait for all weight updates to complete and collect individual timings
+            _, actor_sync_times = ray_get_with_progress(
+                weight_broadcast_futures,
+                desc="[Weight Sync Thread] Waiting for weight updates to complete",
+                enable=args.verbose,
+            )

-        # Allow actors to resume
-        ray.get(actor_manager.set_should_stop.remote(False))
-        logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")
+            # Allow actors to resume
+            ray.get(actor_manager.set_should_stop.remote(False))
+            logger.debug("[Weight Sync Thread] Set should_stop to False after weight sync")

         # Calculate distribution statistics
         sync_time_stats = {
@@ -2946,6 +2950,7 @@ def run_training(
     model_dims: utils.ModelDims,
     checkpoint_state=None,
 ):
+    params_lock = threading.Lock()
     if resume_training_step > 1:
         logger.info(f"[Main Thread] Resuming training from step {resume_training_step}")

@@ -2959,6 +2964,7 @@ def run_training(
         policy_group,
         actor_manager,
         weight_sync_metrics_Q,
+        params_lock,
         resume_training_step,
     )

@@ -3117,14 +3123,17 @@ def health_check_fn():
             if iter_dataloader is not None:
                 client_state["shuffling_iterator_state"] = iter_dataloader.get_state()

-            ray_get_with_progress(
-                [
-                    policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
-                    for i in range(args.world_size)
-                ],
-                desc=f"Saving checkpoint state at step {training_step}",
-            )
-            logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+            with params_lock:
+                ray_get_with_progress(
+                    [
+                        policy_group.models[i].save_checkpoint_state.remote(
+                            args.checkpoint_state_dir, client_state
+                        )
+                        for i in range(args.world_size)
+                    ],
+                    desc=f"Saving checkpoint state at step {training_step}",
+                )
+                logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")

         maybe_evaluate(
             args,
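For reference, a minimal, self-contained sketch of the coordination pattern these hunks introduce: one threading.Lock shared by the background weight-sync thread and the checkpoint-saving path on the main thread, so a weight broadcast and a checkpoint save never run concurrently. The worker names, sleeps, and step counts below are illustrative stand-ins, not the repository's API.

import threading
import time

# Shared lock: whoever holds it has exclusive access to the model parameters.
params_lock = threading.Lock()


def weight_sync_worker(stop_event: threading.Event) -> None:
    # Background thread: repeatedly "broadcasts" weights while holding the lock,
    # standing in for the ray.get + broadcast_to_vllm calls in the diff.
    while not stop_event.is_set():
        with params_lock:
            time.sleep(0.1)  # placeholder for the actual weight broadcast
        time.sleep(0.2)  # placeholder for waiting on the next sync trigger


def save_checkpoint(step: int) -> None:
    # Main thread: taking the same lock makes the save wait for any in-flight sync,
    # mirroring the `with params_lock:` wrapper around save_checkpoint_state.
    with params_lock:
        time.sleep(0.05)  # placeholder for saving checkpoint state on every rank
        print(f"saved checkpoint at step {step}")


if __name__ == "__main__":
    stop = threading.Event()
    sync_thread = threading.Thread(target=weight_sync_worker, args=(stop,), daemon=True)
    sync_thread.start()
    for step in range(1, 4):
        save_checkpoint(step)
    stop.set()
    sync_thread.join()

Because both sides acquire the same lock, either the broadcast or the save fully completes before the other starts; the order in which they interleave is otherwise unconstrained.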