Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/hyrax/pytorch_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,10 @@ def create_trainer(model: torch.nn.Module, config: dict, results_directory: Path

optimizer = extract_model_method(model, "optimizer")
scheduler = extract_model_method(model, "scheduler")
# Extract unwrapped model for attribute access (scheduler attributes like _learning_rates_history)
unwrapped_model = (
model.module if (type(model) is DataParallel or type(model) is DistributedDataParallel) else model
)

to_save = {
"model": model,
Expand Down Expand Up @@ -732,11 +736,11 @@ def log_epoch_metrics(trainer):
# NOTE(review): this is a unified-diff excerpt, not runnable code — the
# `model._learning_rates_history` lines are the removed (pre-fix) version and
# the `unwrapped_model._learning_rates_history` lines are their replacements;
# both appear here because the paste kept removed and added diff lines.
@trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
def scheduler_step(trainer):
# Per-epoch hook: record the LR used for the completed epoch, log it to
# TensorBoard, then advance the scheduler. No-op when no scheduler is set.
if scheduler:
# Removed lines: stored the history on `model`, which in distributed runs
# is a DataParallel/DistributedDataParallel wrapper, so the attribute
# landed on the wrapper instead of the underlying module.
if not hasattr(model, "_learning_rates_history"):
model._learning_rates_history = []
# Added lines: store the history on `unwrapped_model` (model.module when
# wrapped — see the `unwrapped_model` extraction earlier in create_trainer).
if not hasattr(unwrapped_model, "_learning_rates_history"):
unwrapped_model._learning_rates_history = []
epoch_lr = scheduler.get_last_lr()
# Shift to a zero-based epoch index for logging — presumably
# trainer.state.epoch is 1-based here; confirm against ignite docs.
epoch_number = trainer.state.epoch - 1
model._learning_rates_history.append(epoch_lr)
unwrapped_model._learning_rates_history.append(epoch_lr)
tensorboardx_logger.add_scalar("training/training/epoch/lr", epoch_lr, global_step=epoch_number)
scheduler.step()

Expand Down
37 changes: 37 additions & 0 deletions tests/hyrax/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,40 @@ def test_constant_scheduler_checkpointing(loopback_hyrax, tmp_path):

assert hasattr(model, "_learning_rates_history")
assert model._learning_rates_history == [[initial_lr * factor]] + [[initial_lr]] * 2


def test_scheduler_with_data_parallel(loopback_hyrax):
    """Exercise the LR scheduler while the model is wrapped in DataParallel.

    Regression test for the PR #652 AttributeError: the learning-rate history
    must still be recorded (on the unwrapped model) when ``idist.auto_model``
    hands back a ``DataParallel`` wrapper, as happens in distributed runs.
    """
    from unittest.mock import patch

    from torch.nn.parallel import DataParallel

    h, _ = loopback_hyrax

    # Configure an exponential LR schedule with a known decay and starting LR
    # so the expected per-epoch history can be computed in closed form.
    decay = 0.5
    base_lr = 64
    n_epochs = 3
    h.config["scheduler"]["name"] = "torch.optim.lr_scheduler.ExponentialLR"
    h.config["torch.optim.lr_scheduler.ExponentialLR"] = {"gamma": decay}
    h.config["train"]["epochs"] = n_epochs
    h.config[h.config["optimizer"]["name"]]["lr"] = base_lr

    def wrap_model(candidate):
        # Simulate a distributed environment: only the real model (the one
        # carrying a `scheduler` attribute) gets wrapped in DataParallel.
        return DataParallel(candidate) if hasattr(candidate, "scheduler") else candidate

    # The buggy code raised AttributeError: 'DataParallel' object has no
    # attribute 'scheduler' — training must now complete cleanly.
    with patch("hyrax.pytorch_ignite.idist.auto_model", side_effect=wrap_model):
        trained = h.train()

    # The scheduler must have recorded one LR entry per epoch, decaying
    # geometrically from the configured starting rate.
    assert hasattr(trained, "_learning_rates_history")
    expected = [[base_lr * decay**epoch] for epoch in range(n_epochs)]
    assert trained._learning_rates_history == expected