Skip to content

Commit 3dc98de

Browse files
authored
Add test coverage for Muon muon_lr/adam_lr overrides (#8047)
## Summary

Add coverage for separate learning rate overrides in the Muon optimizer path and fix the related Muon blog documentation.

## Background

Muon parameters and non-Muon parameters are automatically split into separate optimizer groups. The intended behavior is:

- `muon_lr` applies to Muon parameter groups
- `adam_lr` applies to Adam parameter groups
- `lr` remains the fallback for both groups when overrides are not provided

## Changes

- add a parameterized test covering:
  - legacy `lr` fallback behavior
  - separate `muon_lr` / `adam_lr` override behavior
- fix the Muon blog table header to label `muon_lr` and `adam_lr` correctly

## Validation

Ran:

`python -m pytest DeepSpeed/tests/unit/ops/muon/test_muon_partial_training.py -k learning_rate_overrides -q -rs`

Result:

- test collected successfully
- skipped locally because this distributed test requires 2 GPUs, while the local environment has 1 GPU

---------

Signed-off-by: Sowndappan S <147894621+sowndappan5@users.noreply.github.com>
1 parent 28a196f commit 3dc98de

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

tests/unit/ops/muon/test_muon_partial_training.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import torch.nn as nn
2424
import deepspeed
25+
import pytest
2526
from unit.common import DistributedTest
2627

2728

@@ -173,3 +174,43 @@ def test_muon_with_mixed_trainable_params(self):
173174

174175
# Verify the model was initialized successfully
175176
assert model_engine is not None
177+
178+
@pytest.mark.parametrize(
    "optimizer_params, expected_muon_lr, expected_adam_lr",
    [
        # Legacy configuration: only `lr` is given, so both groups are expected to use it.
        ({
            "lr": 0.02,
            "weight_decay": 0.01
        }, 0.02, 0.02),
        # Explicit overrides: `muon_lr` / `adam_lr` are expected to win over `lr`.
        ({
            "lr": 0.02,
            "muon_lr": 0.04,
            "adam_lr": 0.001,
            "weight_decay": 0.01
        }, 0.04, 0.001),
    ],
)
def test_muon_adam_learning_rate_overrides(self, optimizer_params, expected_muon_lr, expected_adam_lr):
    """Check the learning rate assigned to the Muon and non-Muon optimizer groups.

    The engine is initialized with the ``Muon`` optimizer under ZeRO stage 2, and
    the ``lr`` of every parameter group of ``model_engine.basic_optimizer`` is
    compared with the expected value for its ``use_muon`` flag.

    Args:
        optimizer_params: Value of the ``optimizer.params`` section of the
            DeepSpeed config for this case.
        expected_muon_lr: ``lr`` expected on groups with ``use_muon`` true.
        expected_adam_lr: ``lr`` expected on groups with ``use_muon`` false.
    """
    model = PartialTrainableModel()

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "optimizer": {
            "type": "Muon",
            "params": optimizer_params
        },
        "zero_optimization": {
            "stage": 2
        },
    }

    model_engine, _, _, _ = deepspeed.initialize(model=model,
                                                 model_parameters=model.parameters(),
                                                 config=ds_config)

    # Gather the lr of *every* group per flag. A plain {flag: lr} mapping would keep
    # only the last group for each flag and silently hide a mismatching earlier one.
    lrs_by_flag = {True: [], False: []}
    for param_group in model_engine.basic_optimizer.param_groups:
        lrs_by_flag[bool(param_group["use_muon"])].append(param_group["lr"])

    for use_muon, expected_lr in ((True, expected_muon_lr), (False, expected_adam_lr)):
        label = "Muon" if use_muon else "Adam"
        group_lrs = lrs_by_flag[use_muon]
        # Fail with a readable message (not a KeyError) when a group is missing.
        assert group_lrs, f"no {label} parameter group (use_muon={use_muon}) was created"
        assert all(lr == expected_lr for lr in group_lrs), (
            f"{label} parameter group lr mismatch: expected {expected_lr}, got {group_lrs}")

0 commit comments

Comments
 (0)