diff --git a/src/hyrax/pytorch_ignite.py b/src/hyrax/pytorch_ignite.py
index 4b33a726..62be4f28 100644
--- a/src/hyrax/pytorch_ignite.py
+++ b/src/hyrax/pytorch_ignite.py
@@ -657,6 +657,10 @@ def create_trainer(model: torch.nn.Module, config: dict, results_directory: Path
 
     optimizer = extract_model_method(model, "optimizer")
     scheduler = extract_model_method(model, "scheduler")
+    # Extract unwrapped model for attribute access (scheduler attributes like _learning_rates_history)
+    unwrapped_model = (
+        model.module if (type(model) is DataParallel or type(model) is DistributedDataParallel) else model
+    )
 
     to_save = {
         "model": model,
@@ -732,11 +736,11 @@ def log_epoch_metrics(trainer):
     @trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
     def scheduler_step(trainer):
         if scheduler:
-            if not hasattr(model, "_learning_rates_history"):
-                model._learning_rates_history = []
+            if not hasattr(unwrapped_model, "_learning_rates_history"):
+                unwrapped_model._learning_rates_history = []
             epoch_lr = scheduler.get_last_lr()
             epoch_number = trainer.state.epoch - 1
-            model._learning_rates_history.append(epoch_lr)
+            unwrapped_model._learning_rates_history.append(epoch_lr)
             tensorboardx_logger.add_scalar("training/training/epoch/lr", epoch_lr, global_step=epoch_number)
             scheduler.step()
 
diff --git a/tests/hyrax/test_train.py b/tests/hyrax/test_train.py
index d31b7222..d2f13514 100644
--- a/tests/hyrax/test_train.py
+++ b/tests/hyrax/test_train.py
@@ -222,3 +222,40 @@ def test_constant_scheduler_checkpointing(loopback_hyrax, tmp_path):
 
     assert hasattr(model, "_learning_rates_history")
     assert model._learning_rates_history == [[initial_lr * factor]] + [[initial_lr]] * 2
+
+
+def test_scheduler_with_data_parallel(loopback_hyrax):
+    """
+    Test that scheduler works correctly when model is wrapped in DataParallel.
+    This test validates the fix for PR #652 AttributeError bug.
+    """
+    from unittest.mock import patch
+
+    from torch.nn.parallel import DataParallel
+
+    h, _ = loopback_hyrax
+    gamma = 0.5
+    h.config["scheduler"]["name"] = "torch.optim.lr_scheduler.ExponentialLR"
+    h.config["torch.optim.lr_scheduler.ExponentialLR"] = {"gamma": gamma}
+    h.config["train"]["epochs"] = 3
+    initial_lr = 64
+    h.config[h.config["optimizer"]["name"]]["lr"] = initial_lr
+
+    # Mock idist.auto_model to wrap the model in DataParallel
+    # This simulates what happens in distributed training environments
+
+    def mock_auto_model(model):
+        # Wrap the model in DataParallel to test the fix
+        if hasattr(model, "scheduler"):
+            return DataParallel(model)
+        return model
+
+    # Patch idist.auto_model in the pytorch_ignite module
+    with patch("hyrax.pytorch_ignite.idist.auto_model", side_effect=mock_auto_model):
+        # This should not raise AttributeError: 'DataParallel' object has no attribute 'scheduler'
+        model = h.train()
+
+    # Verify the scheduler worked correctly
+    assert hasattr(model, "_learning_rates_history")
+    expected_history = [[initial_lr * gamma**i] for i in range(h.config["train"]["epochs"])]
+    assert model._learning_rates_history == expected_history