@@ -2921,6 +2921,44 @@ def cleanup_training_resources(
     logger.info("✅ Process group destroyed")


+@Timer("[Main Thread] 🗡️ Saving checkpoint state")
+def maybe_save_checkpoint(
+    args: Args,
+    policy_group: ModelGroup,
+    training_step: int,
+    episode: int,
+    num_total_tokens: int,
+    iter_dataloader: ShufflingIterator | None,
+) -> None:
+    """Save checkpoint state if checkpoint frequency conditions are met.
+
+    Args:
+        args: Training configuration arguments.
+        policy_group: Group of policy models to checkpoint.
+        training_step: Current training step number.
+        episode: Current episode count.
+        num_total_tokens: Total number of tokens processed.
+        iter_dataloader: Data iterator to save state from, or None.
+    """
+    if not (
+        args.checkpoint_state_freq > 0
+        and training_step % args.checkpoint_state_freq == 0
+        and args.checkpoint_state_dir is not None
+    ):
+        return
+
+    client_state = {"training_step": training_step, "episode": episode, "num_total_tokens": num_total_tokens}
+
+    if iter_dataloader is not None:
+        client_state["shuffling_iterator_state"] = iter_dataloader.get_state()
+
+    ray_get_with_progress(
+        [model.save_checkpoint_state.remote(args.checkpoint_state_dir, client_state) for model in policy_group.models],
+        desc=f"Saving checkpoint state at step {training_step}",
+    )
+    logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+
+
 def run_training(
     args,
     tokenizer,
@@ -3098,52 +3136,37 @@ def health_check_fn():
             iter_dataloader,
         )

-        logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
-        weight_sync_trigger_event.set()
+        async_futures = []

-        # Checkpoint after one_training_step (or even if it was skipped)
-        # This ensures we checkpoint progress even if the exact checkpoint step has no data
-        if (
-            args.checkpoint_state_freq > 0
-            and training_step % args.checkpoint_state_freq == 0
-            and args.checkpoint_state_dir is not None
-        ):
-            with Timer("[Main Thread] 🗡️ Saving checkpoint state"):
-                # Save comprehensive client state including ShufflingIterator state
-                client_state = {
-                    "training_step": training_step,
-                    "episode": episode,
-                    "num_total_tokens": num_total_tokens,
-                }
-
-                # Save ShufflingIterator state
-                if iter_dataloader is not None:
-                    client_state["shuffling_iterator_state"] = iter_dataloader.get_state()
-
-                ray_get_with_progress(
-                    [
-                        policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
-                        for i in range(args.world_size)
-                    ],
-                    desc=f"Saving checkpoint state at step {training_step}",
-                )
-                logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
+        async_futures.append(
+            executor.submit(
+                maybe_save_checkpoint, args, policy_group, training_step, episode, num_total_tokens, iter_dataloader
+            )
+        )

-        maybe_evaluate(
-            args,
-            training_step,
-            evaluation_inference_results_Q,
-            tokenizer,
-            reward_fn,
-            episode,
-            eval_pending_queries_map,
-            generation_configs["eval"],
-            generate_metrics_Q,
-            len(eval_dataset) if eval_dataset else 0,
-            model_dims,
-            actor_manager,
+        async_futures.append(
+            executor.submit(
+                maybe_evaluate,
+                args,
+                training_step,
+                evaluation_inference_results_Q,
+                tokenizer,
+                reward_fn,
+                episode,
+                eval_pending_queries_map,
+                generation_configs["eval"],
+                generate_metrics_Q,
+                len(eval_dataset) if eval_dataset else 0,
+                model_dims,
+                actor_manager,
+            )
         )

+        futures.wait(async_futures)
+
+        logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
+        weight_sync_trigger_event.set()
+
     if resume_training_step > args.num_training_steps:
         raise ValueError(f"Training didn't run since {resume_training_step=} > {args.num_training_steps=}")

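The net effect of the second hunk is that checkpoint saving and evaluation no longer run sequentially on the main thread: both are submitted to `executor` and joined with `futures.wait` just before `weight_sync_trigger_event.set()`. Below is a minimal, self-contained sketch of that submit-then-wait pattern, assuming `executor` is a `concurrent.futures.ThreadPoolExecutor` and using stand-in task functions in place of the real `maybe_save_checkpoint` / `maybe_evaluate` signatures from the diff:

```python
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor


def maybe_save_checkpoint(step: int) -> None:
    # Stand-in for the real checkpointing call; sleep simulates I/O.
    time.sleep(0.2)
    print(f"saved checkpoint state at step {step}")


def maybe_evaluate(step: int) -> None:
    # Stand-in for the real evaluation call.
    time.sleep(0.2)
    print(f"ran evaluation at step {step}")


executor = ThreadPoolExecutor(max_workers=2)
for training_step in range(1, 4):
    # Submit both tasks so they can overlap instead of running back to back.
    async_futures = [
        executor.submit(maybe_save_checkpoint, training_step),
        executor.submit(maybe_evaluate, training_step),
    ]
    # Block until both complete before triggering the weight sync for this step.
    futures.wait(async_futures)
    print(f"weight sync triggered for step {training_step}")
executor.shutdown()
```

One caveat of this pattern: `futures.wait` does not re-raise exceptions from the submitted callables, so a caller that wants checkpoint or evaluation failures to surface would need to call `.result()` on each future (or handle errors elsewhere).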