feat(training): wire FGO optimizer-state offload in train loop#4217
feat(training): wire FGO optimizer-state offload in train loop#4217dingqingy-nv wants to merge 4 commits into
Conversation
mcore's training.py invokes the OptimizerStateOffloader machinery via three call sites — Bridge's train.py forked before they landed and never picked them up. Result: cfg.optimizer.offload_optimizer_states silently no-ops in Bridge (offloader is constructed but never triggered), so master params stay on GPU through every iter. Verified on DSv4-Flash GB200: ~18 GB master params alive at iter-1 peak in Bridge that aren't alive in MLM on the same recipe (memory snapshot diff), which is the dominant contributor to Bridge OOMing on B200 (184 GiB cap) while MLM (153 GB peak) fits. This patch mirrors mcore training.py exactly: - train_step (start of forward_backward_func loop): offload_states() - train_step (after _copy_main_params_to_param_buffer): release_offloaded_gpu_states() - train (pre-loop): wrap finalize_model_grads_func with reload-then-finalize All three call sites are guarded by cfg.optimizer.offload_optimizer_states so the patch is a no-op for callers that don't opt in to FGO. Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
|
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually. Contributors can view more details about this message here. |
|
/claude review |
| if config.optimizer.offload_optimizer_states: | ||
|
|
||
| def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): | ||
| for optim_instance in optimizer.chained_optimizers: | ||
| if isinstance(optim_instance, DistributedOptimizer): | ||
| optim_instance.reload_offloaded_states() | ||
| return finalize_model_grads(*fmg_args, **fmg_kwargs) | ||
|
|
||
| model_config.finalize_model_grads_func = finalize_model_grads_with_state_reload |
There was a problem hiding this comment.
Bug: this wrapper calls the module-level finalize_model_grads directly, but model_config.finalize_model_grads_func may already be set to partial(finalize_model_grads, pg_collection=pg_collection) by setup.py:443 (or to a MIMO-specific variant by setup_megatron_mimo.py:107).
Replacing it with a wrapper that calls the bare import discards the pg_collection binding (or the MIMO override), which would break grad finalization.
Save the existing func before overwriting:
| if config.optimizer.offload_optimizer_states: | |
| def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): | |
| for optim_instance in optimizer.chained_optimizers: | |
| if isinstance(optim_instance, DistributedOptimizer): | |
| optim_instance.reload_offloaded_states() | |
| return finalize_model_grads(*fmg_args, **fmg_kwargs) | |
| model_config.finalize_model_grads_func = finalize_model_grads_with_state_reload | |
| if config.optimizer.offload_optimizer_states: | |
| _orig_finalize = model_config.finalize_model_grads_func | |
| def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs): | |
| for optim_instance in optimizer.chained_optimizers: | |
| if isinstance(optim_instance, DistributedOptimizer): | |
| optim_instance.reload_offloaded_states() | |
| return _orig_finalize(*fmg_args, **fmg_kwargs) | |
| model_config.finalize_model_grads_func = finalize_model_grads_with_state_reload |
This also makes the from megatron.core.distributed import finalize_model_grads import unnecessary (can be removed).
Light Code ReviewCritical bug
Test coverage There are no unit or functional tests covering the FGO code paths. Even a lightweight mock-based unit test that verifies Suggested test cases No perf tests impacted. |
…wrapper The FGO wrapper was calling the bare module-level finalize_model_grads, which discards bindings established earlier: - setup.py:443 binds finalize_model_grads_func to partial(finalize_model_grads, pg_collection=pg_collection) - setup_megatron_mimo.py:107 binds it to a MIMO-specific finalize_model_grads_multimodule Capture model_config.finalize_model_grads_func before overwriting and delegate to it, so both bindings (and any future ones) are preserved. Caught by claude bot review on PR NVIDIA-NeMo#4217. Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
|
Thanks for catching this — applied your suggestion in b1644eb. Captured |
|
/claude review |
| # FGO: wrap finalize_model_grads_func to async-reload offloaded optimizer | ||
| # states before grad finalize. H2D overlaps with grad all-reduce. Mirrors | ||
| # mcore training.py:3140 in intent, but DELEGATES to whatever | ||
| # finalize_model_grads_func was previously bound. setup.py:443 already | ||
| # binds it to partial(finalize_model_grads, pg_collection=pg_collection), | ||
| # and setup_megatron_mimo.py:107 binds it to a MIMO-specific | ||
| # finalize_model_grads_multimodule — calling the bare module-level | ||
| # finalize_model_grads here would drop those bindings. |
There was a problem hiding this comment.
Nit: these comments reference specific mcore line numbers (training.py:3140, training.py:2030, training.py:2073) that will drift as mcore evolves. Consider referencing the function/pattern name instead (e.g., "mirrors mcore's FGO offload pattern in train_step") so the comments don't go stale on the next mcore bump.
Light Code ReviewThe change is clean - pure glue code that correctly wires mcore's existing offload_states / reload_offloaded_states / release_offloaded_gpu_states APIs into Bridge's training loop. The three call sites follow the established chained_optimizers + isinstance(_, DistributedOptimizer) pattern already used elsewhere in train.py (lines 670, 1654). The delegation approach for finalize_model_grads_func (wrapping the existing callable rather than calling the bare module-level function) is the right call - it preserves the setup.py partial binding and the MIMO multimodule binding. Findings
No bugs or logic errors found The offload -> forward/backward (reload via finalize wrapper) -> release lifecycle is correct. The _orig_finalize closure capture is safe because setup.py:443 (or setup_megatron_mimo.py:107) always sets finalize_model_grads_func to a real callable before train() reaches the wrapping point. Suggested test cases No perf tests impacted. |
Replace 'mcore training.py:NNNN', 'setup.py:443', and 'setup_megatron_mimo.py:107' line-number references with the actual function names (_update_model_config_funcs, _update_megatron_mimo_model_config_funcs) and the mcore pattern names (FGO offload site, post-_copy_main_params_to_param_buffer release site). Line numbers drift on every refactor; function names are stable and grep-friendly. Addresses claude bot nit on PR NVIDIA-NeMo#4217. Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
|
Addressed the nit in 40bcec3 — line-number refs replaced with function names ( |
The field ships on mcore dev only (mcore PR NVIDIA-NeMo#2811 is still open against main). Bridge users tracking mcore main would AttributeError. Use getattr(..., 'offload_optimizer_states', False) at all three FGO sites so mcore-main users get a safe no-op fallback. Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
Dependency
MCore main NVIDIA/Megatron-LM#2811
MCore dev NVIDIA/Megatron-LM#2760
Summary
cfg.optimizer.offload_optimizer_states = True(fine-grained optimizer-state offload, FGO) is currently a silent no-op in Bridge. mcore'sOptimizerStateOffloaderis correctly constructed indistrib_optimizer.py:631when the flag is set, but Bridge'strain.pynever invokes the offload/reload/release methods on it — so master params and optimizer states stay on GPU through every iteration.This PR wires the three FGO call sites that mcore's
training.pyalready has, mirroring its pattern exactly. Pure glue code, no algorithm changes.What changes
3 call sites in
src/megatron/bridge/training/train.py, all guarded bycfg.optimizer.offload_optimizer_states:train_step()— start ofshould_run_forward_backwardloop bodyoptim_instance.offload_states()(async D2H)training.py:2030train_step()— after_handle_mxfp8_param_buffer_copy()optim_instance.release_offloaded_gpu_states()(frees GPU memory once D2H completes)training.py:2073train()— pre-loop setupmodel_config.finalize_model_grads_funcwithreload_offloaded_states()→ finalize (async H2D overlaps with grad all-reduce)training.py:314031 lines added, 0 removed. No change to public API or config schema.
Why MLM has this but Bridge doesn't
Bridge's
train.pyis a fork of mcore'straining.py. The FGO offload machinery (OptimizerStateOffloaderclass,offload_optimizer_statesconfig field) was added to mcore after Bridge's fork point. mcore got both the offloader and the training-loop calls; Bridge inherits the offloader (via the submodule) but the training-loop calls weren't propagated. This PR closes that gap.Test plan
Verified on DSv4-Flash on GB200 (
gb200partition, 16 nodes / 64 GPUs):pretrain_gpt.py(which has the wiring upstream): 19.55 s/iter, 153.76 GB max-allocated. Bridge with this patch is functionally equivalent.Test plan checklist
offload_optimizer_states=True— no OOM, 20 iters complete🤖 Generated with Claude Code