Skip to content

Commit 74ff4e5

Browse files
Refactor checkpoint saving into `maybe_save_checkpoint` and submit checkpointing and evaluation to an executor so they run concurrently before the weight-sync trigger
1 parent 626284f commit 74ff4e5

File tree

2 files changed

+69
-42
lines changed

2 files changed

+69
-42
lines changed

CLAUDE.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,7 @@
88
- To run the `./scripts/train/build_image_and_launch.sh` script, you must commit the current changes.
99
- Launch tool use experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/tool_grpo_fast.sh`.
1010
- Launch multi-node non-tool experiments by running `./scripts/train/build_image_and_launch.sh scripts/train/debug/large_test_script.sh`.
11+
12+
# Code Style
13+
- Always add type annotations to function signatures.
14+
- Always add docstrings to functions, using Google-style docstring format.

open_instruct/grpo_fast.py

Lines changed: 65 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2921,6 +2921,44 @@ def cleanup_training_resources(
29212921
logger.info("✅ Process group destroyed")
29222922

29232923

2924+
@Timer("[Main Thread] 🗡️ Saving checkpoint state")
def maybe_save_checkpoint(
    args: Args,
    policy_group: ModelGroup,
    training_step: int,
    episode: int,
    num_total_tokens: int,
    iter_dataloader: ShufflingIterator | None,
) -> None:
    """Save checkpoint state if checkpoint frequency conditions are met.

    Args:
        args: Training configuration arguments.
        policy_group: Group of policy models to checkpoint.
        training_step: Current training step number.
        episode: Current episode count.
        num_total_tokens: Total number of tokens processed.
        iter_dataloader: Data iterator to save state from, or None.
    """
    # Guard clauses: checkpointing must be enabled, due on this step, and
    # have a destination directory configured. (Checking freq first also
    # avoids a modulo-by-zero when checkpointing is disabled.)
    if args.checkpoint_state_freq <= 0:
        return
    if training_step % args.checkpoint_state_freq != 0:
        return
    if args.checkpoint_state_dir is None:
        return

    client_state = {
        "training_step": training_step,
        "episode": episode,
        "num_total_tokens": num_total_tokens,
    }
    if iter_dataloader is not None:
        # Persist the shuffling iterator so a resumed run replays the same data order.
        client_state["shuffling_iterator_state"] = iter_dataloader.get_state()

    save_refs = [
        model.save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
        for model in policy_group.models
    ]
    ray_get_with_progress(save_refs, desc=f"Saving checkpoint state at step {training_step}")
    logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
2960+
2961+
29242962
def run_training(
29252963
args,
29262964
tokenizer,
@@ -3098,52 +3136,37 @@ def health_check_fn():
30983136
iter_dataloader,
30993137
)
31003138

3101-
logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
3102-
weight_sync_trigger_event.set()
3139+
async_futures = []
31033140

3104-
# Checkpoint after one_training_step (or even if it was skipped)
3105-
# This ensures we checkpoint progress even if the exact checkpoint step has no data
3106-
if (
3107-
args.checkpoint_state_freq > 0
3108-
and training_step % args.checkpoint_state_freq == 0
3109-
and args.checkpoint_state_dir is not None
3110-
):
3111-
with Timer("[Main Thread] 🗡️ Saving checkpoint state"):
3112-
# Save comprehensive client state including ShufflingIterator state
3113-
client_state = {
3114-
"training_step": training_step,
3115-
"episode": episode,
3116-
"num_total_tokens": num_total_tokens,
3117-
}
3118-
3119-
# Save ShufflingIterator state
3120-
if iter_dataloader is not None:
3121-
client_state["shuffling_iterator_state"] = iter_dataloader.get_state()
3122-
3123-
ray_get_with_progress(
3124-
[
3125-
policy_group.models[i].save_checkpoint_state.remote(args.checkpoint_state_dir, client_state)
3126-
for i in range(args.world_size)
3127-
],
3128-
desc=f"Saving checkpoint state at step {training_step}",
3129-
)
3130-
logger.info(f"Saved checkpoint state at step {training_step} to {args.checkpoint_state_dir}")
3141+
async_futures.append(
3142+
executor.submit(
3143+
maybe_save_checkpoint, args, policy_group, training_step, episode, num_total_tokens, iter_dataloader
3144+
)
3145+
)
31313146

3132-
maybe_evaluate(
3133-
args,
3134-
training_step,
3135-
evaluation_inference_results_Q,
3136-
tokenizer,
3137-
reward_fn,
3138-
episode,
3139-
eval_pending_queries_map,
3140-
generation_configs["eval"],
3141-
generate_metrics_Q,
3142-
len(eval_dataset) if eval_dataset else 0,
3143-
model_dims,
3144-
actor_manager,
3147+
async_futures.append(
3148+
executor.submit(
3149+
maybe_evaluate,
3150+
args,
3151+
training_step,
3152+
evaluation_inference_results_Q,
3153+
tokenizer,
3154+
reward_fn,
3155+
episode,
3156+
eval_pending_queries_map,
3157+
generation_configs["eval"],
3158+
generate_metrics_Q,
3159+
len(eval_dataset) if eval_dataset else 0,
3160+
model_dims,
3161+
actor_manager,
3162+
)
31453163
)
31463164

3165+
futures.wait(async_futures)
3166+
3167+
logger.debug(f"[Main Thread] Triggered weight sync for step {training_step}")
3168+
weight_sync_trigger_event.set()
3169+
31473170
if resume_training_step > args.num_training_steps:
31483171
raise ValueError(f"Training didn't run since {resume_training_step=} > {args.num_training_steps=}")
31493172

0 commit comments

Comments
 (0)