Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/hyrax/hyrax_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,19 @@ momentum = 0.9
# learning rate for torch.optim.SGD optimizer.
lr = 0.01

[scheduler]
# name of the learning rate scheduler
# With gamma=1, ExponentialLR will keep the learning rate constant
# https://docs.pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html
name = "torch.optim.lr_scheduler.ExponentialLR"

["torch.optim.lr_scheduler.ExponentialLR"]
# the decay multiplier applied on each epoch
gamma = 1

["torch.optim.lr_scheduler.ConstantLR"]

last_epoch = -1

[train]
# The name of the file where the model weights will be saved after training.
Expand Down
46 changes: 46 additions & 0 deletions src/hyrax/models/model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch.nn as nn
import torch.optim as optim

from hyrax.plugin_utils import (
get_or_load_class,
Expand Down Expand Up @@ -149,6 +150,39 @@ def _torch_optimizer(self: nn.Module):
return optimizer_cls(self.parameters(), **arguments)


def _torch_schedulers(self: nn.Module):
    """Build the learning rate scheduler named in the model's config.

    Reads ``config["scheduler"]["name"]``, loads that class with
    ``get_or_load_class`` and instantiates it on ``self.optimizer``. Keyword
    arguments are taken from the config section keyed by the scheduler's name,
    when such a section exists; otherwise the class defaults are used.

    Returns
    -------
    The instantiated scheduler, or ``None`` (after logging a warning) when the
    configured scheduler name is empty or ``None``.

    Raises
    ------
    RuntimeError
        If ``self.optimizer`` is not a ``torch.optim.Optimizer``.
    """
    settings = cast(dict[str, Any], self.config)

    name = settings["scheduler"]["name"]
    if not name:
        logger.warning("No scheduler specified in config or self.scheduler in model.")
        return None

    scheduler_type = get_or_load_class(name)

    # Constructor keyword arguments live in a config section named after the scheduler class.
    kwargs = settings[name] if name in settings else {}

    # Report which scheduler was selected and how it is parameterised.
    suffix = f"with arguments: {kwargs}." if kwargs else "with default arguments."
    logger.info(f"Setting model's self.scheduler from config: {name}\n" + suffix)

    # A scheduler can only be attached to a genuine torch optimizer.
    if not isinstance(self.optimizer, optim.Optimizer):
        raise RuntimeError("Model optimizer must be a torch.optim.Optimizer")

    return scheduler_type(self.optimizer, **kwargs)


def hyrax_model(cls):
"""Decorator to register a model with the model registry, and to add common interface functions

Expand Down Expand Up @@ -189,6 +223,18 @@ def wrapped_init(self, config, *args, **kwargs):
crit_name = f"{type(self.criterion).__module__}.{type(self.criterion).__qualname__}"
logger.info(f"Using self.criterion defined in model: {crit_name}")

if not hasattr(self, "scheduler"):
self.scheduler = _torch_schedulers(self)
else:
if config["scheduler"]["name"]:
logger.warning(
"Both model and config define a scheduler. "
"Hyrax will use self.scheduler defined in the model."
)

sched_name = f"{type(self.scheduler).__module__}.{type(self.scheduler).__qualname__}"
logger.info(f"Using self.scheduler defined in model: {sched_name}")

cls.__init__ = wrapped_init

def default_prepare_inputs(data_dict):
Expand Down
15 changes: 15 additions & 0 deletions src/hyrax/pytorch_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,13 +656,17 @@ def create_trainer(model: torch.nn.Module, config: dict, results_directory: Path
fixup_engine(trainer)

optimizer = extract_model_method(model, "optimizer")
scheduler = extract_model_method(model, "scheduler")

to_save = {
"model": model,
"optimizer": optimizer,
"trainer": trainer,
}

if scheduler:
to_save["scheduler"] = scheduler

#! We may want to move the checkpointing logic over to the `validator`.
#! It was created here initially because this was the only place where the
#! model training was happening.
Expand Down Expand Up @@ -725,6 +729,17 @@ def log_epoch_metrics(trainer):
)
mlflow.log_metrics({f"training/epoch/{m}": epoch_metrics[m]}, step=epoch_number)

@trainer.on(HyraxEvents.HYRAX_EPOCH_COMPLETED)
def scheduler_step(trainer):
    """Record the learning rate used for the finished epoch, then step the scheduler.

    The value returned by ``get_last_lr()`` is appended to
    ``model._learning_rates_history`` (created on first use) and logged to
    tensorboard against the zero-based epoch number. Does nothing when the
    model has no scheduler.

    Parameters
    ----------
    trainer
        The engine that fired the event; its ``state.epoch`` supplies the epoch count.
    """
    # Fetch the scheduler through the same accessor create_trainer uses for the
    # object it checkpoints, rather than reaching into the model directly.
    active_scheduler = extract_model_method(model, "scheduler")
    if not active_scheduler:
        return

    if not hasattr(model, "_learning_rates_history"):
        model._learning_rates_history = []

    # get_last_lr() returns a list with one learning rate per optimizer parameter group.
    epoch_lr = active_scheduler.get_last_lr()
    epoch_number = trainer.state.epoch - 1
    model._learning_rates_history.append(epoch_lr)

    # add_scalar takes a single scalar, so each parameter group is logged on its own.
    # A lone group keeps the plain "lr" tag.
    if len(epoch_lr) == 1:
        tensorboardx_logger.add_scalar("training/training/epoch/lr", epoch_lr[0], global_step=epoch_number)
    else:
        for group_index, group_lr in enumerate(epoch_lr):
            tensorboardx_logger.add_scalar(
                f"training/training/epoch/lr/group_{group_index}", group_lr, global_step=epoch_number
            )

    active_scheduler.step()

trainer.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, latest_checkpoint)
trainer.add_event_handler(HyraxEvents.HYRAX_EPOCH_COMPLETED, best_checkpoint)

Expand Down
85 changes: 85 additions & 0 deletions tests/hyrax/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(self, config, data_sample=None):
self.config = config
self.unused_module = nn.Linear(1, 1)
self.optimizer = "model_optimizer"
self.scheduler = None

h = Hyrax()
h.set_config("model.name", "TestModel")
Expand Down Expand Up @@ -54,6 +55,7 @@ def __init__(self, config, data_sample=None):
super().__init__()
self.config = config
self.unused_module = nn.Linear(1, 1)
self.scheduler = None

h = Hyrax()
h.set_config("model.name", "TestModel")
Expand Down Expand Up @@ -154,6 +156,7 @@ def __init__(self, config, data_sample=None):
self.config = config
self.unused_module = nn.Linear(1, 1)
self.optimizer = "model_optimizer"
self.scheduler = None

h = Hyrax()
h.set_config("model.name", "TestModel")
Expand All @@ -164,3 +167,85 @@ def __init__(self, config, data_sample=None):
assert "Both model and config define an optimizer" in caplog.text

assert model.optimizer == "model_optimizer" # Should use the model's own optimizer


def test_use_model_scheduler():
    """A scheduler assigned in the model's own ``__init__`` must not be replaced by the config's."""

    @hyrax_model
    class TestModel(nn.Module):
        def __init__(self, config, data_sample=None):
            super().__init__()
            self.config = config
            self.unused_module = nn.Linear(1, 1)
            self.scheduler = "model_scheduler"

    hyrax = Hyrax()
    overrides = (
        ("model.name", "TestModel"),
        ("scheduler.name", "torch.optim.lr_scheduler.ConstantLR"),
    )
    for key, value in overrides:
        hyrax.set_config(key, value)

    instance = TestModel(hyrax.config)

    # The placeholder assigned in __init__ wins over the scheduler named in the config.
    assert hasattr(instance, "scheduler")
    assert instance.scheduler == "model_scheduler"


def test_use_config_scheduler(caplog):
    """A model that defines no scheduler receives the one named in the config."""

    @hyrax_model
    class TestModel(nn.Module):
        def __init__(self, config, data_sample=None):
            super().__init__()
            self.config = config
            self.unused_module = nn.Linear(1, 1)

    hyrax = Hyrax()
    overrides = (
        ("model.name", "TestModel"),
        ("scheduler.name", "torch.optim.lr_scheduler.ConstantLR"),
    )
    for key, value in overrides:
        hyrax.set_config(key, value)

    instance = TestModel(hyrax.config)

    # The injected scheduler is an instance of the class requested in the config.
    assert hasattr(instance, "scheduler")
    assert type(instance.scheduler).__name__ == "ConstantLR"


def test_no_scheduler_defined_logs_warning(caplog):
    """Constructing a model warns when neither the model nor the config supply a scheduler."""

    @hyrax_model
    class TestModel(nn.Module):
        def __init__(self, config, data_sample=None):
            super().__init__()
            self.config = config
            self.unused_module = nn.Linear(1, 1)

    hyrax = Hyrax()
    hyrax.set_config("model.name", "TestModel")
    # An empty name means "no scheduler configured".
    hyrax.set_config("scheduler.name", "")

    with caplog.at_level(logging.WARNING):
        TestModel(hyrax.config)
        assert "No scheduler specified in config or" in caplog.text


def test_scheduler_defined_in_model_and_config(caplog):
    """When model and config both name a scheduler, a warning is logged and the model's is kept."""

    @hyrax_model
    class TestModel(nn.Module):
        def __init__(self, config, data_sample=None):
            super().__init__()
            self.config = config
            self.unused_module = nn.Linear(1, 1)
            self.optimizer = "model_optimizer"
            self.scheduler = "model_scheduler"

    hyrax = Hyrax()
    overrides = (
        ("model.name", "TestModel"),
        ("optimizer.name", "torch.optim.SGD"),
        ("scheduler.name", "torch.optim.lr_scheduler.ConstantLR"),
    )
    for key, value in overrides:
        hyrax.set_config(key, value)

    with caplog.at_level(logging.WARNING):
        instance = TestModel(hyrax.config)
        assert "Both model and config define a scheduler" in caplog.text

    # The model's own scheduler takes precedence over the configured one.
    assert instance.scheduler == "model_scheduler"
2 changes: 2 additions & 0 deletions tests/hyrax/test_patch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_patch_prepare_inputs(tmp_path):
"criterion": {"name": "torch.nn.MSELoss"},
"optimizer": {"name": "torch.optim.SGD"},
"torch.optim.SGD": {"lr": 0.01},
"scheduler": {"name": None},
}

# create an instance of the dummy model
Expand Down Expand Up @@ -87,6 +88,7 @@ def test_patch_prepare_inputs_over_default(tmp_path):
"criterion": {"name": "torch.nn.MSELoss"},
"optimizer": {"name": "torch.optim.SGD"},
"torch.optim.SGD": {"lr": 0.01},
"scheduler": {"name": None},
}

# create an instance of the dummy model
Expand Down
3 changes: 3 additions & 0 deletions tests/hyrax/test_plugin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def train_batch(self, batch):
"criterion": {"name": "torch.nn.MSELoss"},
"optimizer": {"name": "torch.optim.SGD"},
"torch.optim.SGD": {"lr": 0.01},
"scheduler": {"name": None},
}

# Get the expected device from idist (same as _torch_load uses)
Expand Down Expand Up @@ -240,6 +241,7 @@ def infer_batch(self, batch):
"criterion": {"name": "torch.nn.MSELoss"},
"optimizer": {"name": "torch.optim.SGD"},
"torch.optim.SGD": {"lr": 0.01},
"scheduler": {"name": None},
}

# Create and save a model (this will create prepare_inputs.py)
Expand Down Expand Up @@ -301,6 +303,7 @@ def infer_batch(self, batch):
"criterion": {"name": "torch.nn.MSELoss"},
"optimizer": {"name": "torch.optim.SGD"},
"torch.optim.SGD": {"lr": 0.01},
"scheduler": {"name": None},
}

# Create and save a model (no prepare_inputs defined in class, so won't create prepare_inputs.py)
Expand Down
118 changes: 118 additions & 0 deletions tests/hyrax/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,121 @@ def test_train_percent_split(tmp_path):
# Finally, run full training to exercise `train.py` end-to-end and ensure
# the training verb functions correctly with percent-based splits.
h.train()


def test_constant_scheduler(loopback_hyrax):
    """Ensure that a configured ConstantLR scheduler drives the learning rate as expected.

    The learning rate is expected to be ``initial_lr * factor`` for the first
    ``total_iters`` epochs and ``initial_lr`` for every epoch after that.
    """
    h, _ = loopback_hyrax
    factor = 0.5
    total_iters = 4
    epochs = 6
    initial_lr = 128

    h.config["scheduler"]["name"] = "torch.optim.lr_scheduler.ConstantLR"
    h.config["torch.optim.lr_scheduler.ConstantLR"] = {"total_iters": total_iters, "factor": factor}
    h.config["train"]["epochs"] = epochs
    # Use the same variable the expectation below is built from, not a repeated literal.
    h.config[h.config["optimizer"]["name"]]["lr"] = initial_lr
    model = h.train()

    # One single-element list is recorded per epoch.
    expected_history = [[initial_lr * factor]] * total_iters + [[initial_lr]] * (epochs - total_iters)

    assert hasattr(model, "_learning_rates_history")
    assert model._learning_rates_history == expected_history


def test_exponential_scheduler(loopback_hyrax):
    """
    Check that an ExponentialLR scheduler multiplies the learning rate by gamma after every epoch
    """
    hyrax, _ = loopback_hyrax
    decay = 0.5
    base_lr = 128
    scheduler_name = "torch.optim.lr_scheduler.ExponentialLR"

    hyrax.config["scheduler"]["name"] = scheduler_name
    hyrax.config[scheduler_name] = {"gamma": decay}
    hyrax.config["train"]["epochs"] = 5
    optimizer_name = hyrax.config["optimizer"]["name"]
    hyrax.config[optimizer_name]["lr"] = base_lr

    trained = hyrax.train()

    assert hasattr(trained, "_learning_rates_history")

    # Epoch i is expected to have run with base_lr * decay**i.
    epoch_count = hyrax.config["train"]["epochs"]
    expected = [[base_lr * decay**epoch] for epoch in range(epoch_count)]
    assert trained._learning_rates_history == expected


def test_exponential_scheduler_checkpointing(loopback_hyrax, tmp_path):
    """
    Check that an ExponentialLR scheduler carries on from its saved state when training resumes
    """
    hyrax, _ = loopback_hyrax
    decay = 0.5
    base_lr = 128
    scheduler_name = "torch.optim.lr_scheduler.ExponentialLR"

    def expected_history(epoch_range):
        # One single-element list per epoch: base_lr decayed once for each preceding epoch.
        return [[base_lr * decay**epoch] for epoch in epoch_range]

    # Keep every result file inside the temporary directory.
    hyrax.config["general"]["results_dir"] = str(tmp_path)

    # Configure the scheduler and a short first run.
    hyrax.config["scheduler"]["name"] = scheduler_name
    hyrax.config[scheduler_name] = {"gamma": decay}
    hyrax.config["train"]["epochs"] = 3
    hyrax.config[hyrax.config["optimizer"]["name"]]["lr"] = base_lr

    # First run: three epochs, which also writes the checkpoint used below.
    first_run = hyrax.train()
    assert hasattr(first_run, "_learning_rates_history")
    assert first_run._learning_rates_history == expected_history(range(3))

    # Point the resume option at the checkpoint from the most recent results directory.
    checkpoint = find_most_recent_results_dir(hyrax.config, "train") / "checkpoint_epoch_3.pt"
    hyrax.config["train"]["resume"] = str(checkpoint)

    # Second run: continue up to epoch five, i.e. two further epochs.
    hyrax.config["train"]["epochs"] = 5
    resumed = hyrax.train()

    assert hasattr(resumed, "_learning_rates_history")
    assert resumed._learning_rates_history == expected_history(range(3, 5))


def test_constant_scheduler_checkpointing(loopback_hyrax, tmp_path):
    """
    Check that a ConstantLR scheduler carries on from its saved state when training resumes
    """
    hyrax, _ = loopback_hyrax
    scale = 0.5
    base_lr = 128
    scheduler_name = "torch.optim.lr_scheduler.ConstantLR"

    # Keep every result file inside the temporary directory.
    hyrax.config["general"]["results_dir"] = str(tmp_path)

    # Configure the scheduler and a short first run.
    hyrax.config["scheduler"]["name"] = scheduler_name
    hyrax.config[scheduler_name] = {"total_iters": 3, "factor": scale}
    hyrax.config["train"]["epochs"] = 2
    hyrax.config[hyrax.config["optimizer"]["name"]]["lr"] = base_lr

    # First run: two epochs, both still at the scaled learning rate.
    first_run = hyrax.train()
    assert hasattr(first_run, "_learning_rates_history")
    scaled_epoch = [base_lr * scale]
    assert first_run._learning_rates_history == [scaled_epoch] * hyrax.config["train"]["epochs"]

    # Point the resume option at the checkpoint from the most recent results directory.
    checkpoint = find_most_recent_results_dir(hyrax.config, "train") / "checkpoint_epoch_2.pt"
    hyrax.config["train"]["resume"] = str(checkpoint)

    # Second run: continue up to epoch five, i.e. three further epochs.
    hyrax.config["train"]["epochs"] = 5
    resumed = hyrax.train()

    # One remaining scaled epoch, then the unscaled learning rate for the last two.
    assert hasattr(resumed, "_learning_rates_history")
    assert resumed._learning_rates_history == [scaled_epoch] + [[base_lr]] * 2