Skip to content

feat(training): wire FGO optimizer-state offload in train loop#4217

Open
dingqingy-nv wants to merge 4 commits into
NVIDIA-NeMo:mainfrom
dingqingy-nv:dingqingy/wire-fgo-optimizer-offload
Open

feat(training): wire FGO optimizer-state offload in train loop#4217
dingqingy-nv wants to merge 4 commits into
NVIDIA-NeMo:mainfrom
dingqingy-nv:dingqingy/wire-fgo-optimizer-offload

Conversation

@dingqingy-nv

@dingqingy-nv dingqingy-nv commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Dependency

MCore main NVIDIA/Megatron-LM#2811
MCore dev NVIDIA/Megatron-LM#2760

Summary

cfg.optimizer.offload_optimizer_states = True (fine-grained optimizer-state offload, FGO) is currently a silent no-op in Bridge. mcore's OptimizerStateOffloader is correctly constructed in distrib_optimizer.py:631 when the flag is set, but Bridge's train.py never invokes the offload/reload/release methods on it — so master params and optimizer states stay on GPU through every iteration.

This PR wires the three FGO call sites that mcore's training.py already has, mirroring its pattern exactly. Pure glue code, no algorithm changes.

What changes

3 call sites in src/megatron/bridge/training/train.py, all guarded by cfg.optimizer.offload_optimizer_states:

Call site Where What Mirrors
1 train_step() — start of should_run_forward_backward loop body optim_instance.offload_states() (async D2H) mcore training.py:2030
2 train_step() — after _handle_mxfp8_param_buffer_copy() optim_instance.release_offloaded_gpu_states() (frees GPU memory once D2H completes) mcore training.py:2073
3 train() — pre-loop setup wrap model_config.finalize_model_grads_func with reload_offloaded_states() → finalize (async H2D overlaps with grad all-reduce) mcore training.py:3140

31 lines added, 0 removed. No change to public API or config schema.

Why MLM has this but Bridge doesn't

Bridge's train.py is a fork of mcore's training.py. The FGO offload machinery (OptimizerStateOffloader class, offload_optimizer_states config field) was added to mcore after Bridge's fork point. mcore got both the offloader and the training-loop calls; Bridge inherits the offloader (via the submodule) but the training-loop calls weren't propagated. This PR closes that gap.

Test plan

Verified on DSv4-Flash on GB200 (gb200 partition, 16 nodes / 64 GPUs):

  • Before patch: OOM at iter 2. iter-1 max-allocated = 175.63 GB on 184 GiB cap. Master params (~18 GB) live on GPU throughout.
  • After patch: Runs all 20 iters cleanly. iter-1 max-allocated = 157.43 GB (-18 GB, exactly the offloaded master params). Steady-state ~20.8 s/iter.
  • Same recipe ran on mcore's pretrain_gpt.py (which has the wiring upstream): 19.55 s/iter, 153.76 GB max-allocated. Bridge with this patch is functionally equivalent.

Test plan checklist

  • Manual: DSv4-Flash 64×GB200 with offload_optimizer_states=True — no OOM, 20 iters complete
  • CI: existing tests unchanged; no new unit test added since this is a thin wrapper around mcore's existing offloader and is gated by an opt-in flag

🤖 Generated with Claude Code

mcore's training.py invokes the OptimizerStateOffloader machinery via
three call sites — Bridge's train.py forked before they landed and
never picked them up. Result: cfg.optimizer.offload_optimizer_states
silently no-ops in Bridge (offloader is constructed but never
triggered), so master params stay on GPU through every iter.

Verified on DSv4-Flash GB200: ~18 GB master params alive at iter-1
peak in Bridge that aren't alive in MLM on the same recipe (memory
snapshot diff), which is the dominant contributor to Bridge OOMing on
B200 (184 GiB cap) while MLM (153 GB peak) fits.

This patch mirrors mcore training.py exactly:
- train_step (start of forward_backward_func loop): offload_states()
- train_step (after _copy_main_params_to_param_buffer): release_offloaded_gpu_states()
- train (pre-loop): wrap finalize_model_grads_func with reload-then-finalize

All three call sites are guarded by cfg.optimizer.offload_optimizer_states
so the patch is a no-op for callers that don't opt in to FGO.

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
@copy-pr-bot

copy-pr-bot Bot commented Jun 8, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@dingqingy-nv dingqingy-nv marked this pull request as ready for review June 8, 2026 23:38
@dingqingy-nv

Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread src/megatron/bridge/training/train.py Outdated
Comment on lines +320 to +328
if config.optimizer.offload_optimizer_states:

def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs):
for optim_instance in optimizer.chained_optimizers:
if isinstance(optim_instance, DistributedOptimizer):
optim_instance.reload_offloaded_states()
return finalize_model_grads(*fmg_args, **fmg_kwargs)

model_config.finalize_model_grads_func = finalize_model_grads_with_state_reload

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: this wrapper calls the module-level finalize_model_grads directly, but model_config.finalize_model_grads_func may already be set to partial(finalize_model_grads, pg_collection=pg_collection) by setup.py:443 (or to a MIMO-specific variant by setup_megatron_mimo.py:107).

Replacing it with a wrapper that calls the bare import discards the pg_collection binding (or the MIMO override), which would break grad finalization.

Save the existing func before overwriting:

Suggested change
if config.optimizer.offload_optimizer_states:
def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs):
for optim_instance in optimizer.chained_optimizers:
if isinstance(optim_instance, DistributedOptimizer):
optim_instance.reload_offloaded_states()
return finalize_model_grads(*fmg_args, **fmg_kwargs)
model_config.finalize_model_grads_func = finalize_model_grads_with_state_reload
if config.optimizer.offload_optimizer_states:
_orig_finalize = model_config.finalize_model_grads_func
def finalize_model_grads_with_state_reload(*fmg_args, **fmg_kwargs):
for optim_instance in optimizer.chained_optimizers:
if isinstance(optim_instance, DistributedOptimizer):
optim_instance.reload_offloaded_states()
return _orig_finalize(*fmg_args, **fmg_kwargs)
model_config.finalize_model_grads_func = finalize_model_grads_with_state_reload

This also makes the from megatron.core.distributed import finalize_model_grads import unnecessary (can be removed).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@claude

claude Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Light Code Review

Critical bug

finalize_model_grads_func wrapper discards prior bindingsetup.py:443 sets model_config.finalize_model_grads_func = partial(finalize_model_grads, pg_collection=pg_collection) before train() is called. The new wrapper at lines 322-326 replaces it but calls the bare module-level finalize_model_grads (no pg_collection kwarg), discarding the partial binding. Similarly, setup_megatron_mimo.py:107 sets a MIMO-specific override that would also be lost. Fix: capture _orig_finalize = model_config.finalize_model_grads_func before overwriting, and delegate to it instead of the import. See inline comment for the suggested fix.

Test coverage

There are no unit or functional tests covering the FGO code paths. Even a lightweight mock-based unit test that verifies offload_states(), reload_offloaded_states(), and release_offloaded_gpu_states() are called when offload_optimizer_states=True (and not called when False) would catch regressions and validate the wiring. Consider adding one before merge.

Suggested test cases

No perf tests impacted.

…wrapper

The FGO wrapper was calling the bare module-level finalize_model_grads,
which discards bindings established earlier:
- setup.py:443 binds finalize_model_grads_func to
  partial(finalize_model_grads, pg_collection=pg_collection)
- setup_megatron_mimo.py:107 binds it to a MIMO-specific
  finalize_model_grads_multimodule

Capture model_config.finalize_model_grads_func before overwriting and
delegate to it, so both bindings (and any future ones) are preserved.

Caught by claude bot review on PR NVIDIA-NeMo#4217.

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
@dingqingy-nv

Copy link
Copy Markdown
Contributor Author

Thanks for catching this — applied your suggestion in b1644eb. Captured _orig_finalize = model_config.finalize_model_grads_func before overwriting and the wrapper now delegates to it, so the setup.py:443 partial(..., pg_collection=...) binding and the setup_megatron_mimo.py:107 MIMO override both survive. Dropped the now-unused from megatron.core.distributed import finalize_model_grads import.

@dingqingy-nv

Copy link
Copy Markdown
Contributor Author

/claude review

Comment thread src/megatron/bridge/training/train.py Outdated
Comment on lines +315 to +322
# FGO: wrap finalize_model_grads_func to async-reload offloaded optimizer
# states before grad finalize. H2D overlaps with grad all-reduce. Mirrors
# mcore training.py:3140 in intent, but DELEGATES to whatever
# finalize_model_grads_func was previously bound. setup.py:443 already
# binds it to partial(finalize_model_grads, pg_collection=pg_collection),
# and setup_megatron_mimo.py:107 binds it to a MIMO-specific
# finalize_model_grads_multimodule — calling the bare module-level
# finalize_model_grads here would drop those bindings.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: these comments reference specific mcore line numbers (training.py:3140, training.py:2030, training.py:2073) that will drift as mcore evolves. Consider referencing the function/pattern name instead (e.g., "mirrors mcore's FGO offload pattern in train_step") so the comments don't go stale on the next mcore bump.

@claude

claude Bot commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

Light Code Review

The change is clean - pure glue code that correctly wires mcore's existing offload_states / reload_offloaded_states / release_offloaded_gpu_states APIs into Bridge's training loop. The three call sites follow the established chained_optimizers + isinstance(_, DistributedOptimizer) pattern already used elsewhere in train.py (lines 670, 1654).

The delegation approach for finalize_model_grads_func (wrapping the existing callable rather than calling the bare module-level function) is the right call - it preserves the setup.py partial binding and the MIMO multimodule binding.

Findings

  • Stale line-number references in comments (minor): The comments reference specific mcore line numbers (training.py:3140, training.py:2030, training.py:2073) that will drift on the next mcore bump. Consider referencing the pattern/function name instead.

  • No unit test coverage: The PR acknowledges this. A lightweight unit test that mocks optimizer.chained_optimizers with a DistributedOptimizer stub and verifies the offload/release/reload call sequence under offload_optimizer_states=True would be valuable and cheap to write. It would catch regressions if this code is refactored.

No bugs or logic errors found

The offload -> forward/backward (reload via finalize wrapper) -> release lifecycle is correct. The _orig_finalize closure capture is safe because setup.py:443 (or setup_megatron_mimo.py:107) always sets finalize_model_grads_func to a real callable before train() reaches the wrapping point.

Suggested test cases

No perf tests impacted.

Replace 'mcore training.py:NNNN', 'setup.py:443', and
'setup_megatron_mimo.py:107' line-number references with the actual
function names (_update_model_config_funcs,
_update_megatron_mimo_model_config_funcs) and the mcore pattern names
(FGO offload site, post-_copy_main_params_to_param_buffer release
site). Line numbers drift on every refactor; function names are stable
and grep-friendly.

Addresses claude bot nit on PR NVIDIA-NeMo#4217.

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>
@dingqingy-nv

Copy link
Copy Markdown
Contributor Author

Addressed the nit in 40bcec3 — line-number refs replaced with function names (_update_model_config_funcs, _update_megatron_mimo_model_config_funcs) and mcore pattern descriptions instead. Comments won't drift on next mcore bump.

The field ships on mcore dev only (mcore PR NVIDIA-NeMo#2811 is still open against
main). Bridge users tracking mcore main would AttributeError. Use
getattr(..., 'offload_optimizer_states', False) at all three FGO sites
so mcore-main users get a safe no-op fallback.

Signed-off-by: Dingqing Yang <dingqingy@nvidia.com>

@cuichenx cuichenx left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@yaoyu-33 yaoyu-33 added area:training Training loop, callbacks, and runtime integration blocked Work cannot move forward until an external dependency is cleared feature New capabilities, enhancements, or enablement work labels Jun 9, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:training Training loop, callbacks, and runtime integration blocked Work cannot move forward until an external dependency is cleared feature New capabilities, enhancements, or enablement work

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants