src/lightning/pytorch/cli.py (41 additions, 0 deletions)
@@ -560,6 +560,41 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
        else:
            self.config = parser.parse_args(args)

    def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
        """Adapt checkpoint hyperparameters before instantiating the model class.

        This method allows for customization of hyperparameters loaded from a checkpoint when
        using a different model class than the one used for training. For example, when loading
        a checkpoint from a TrainingModule to use with an InferenceModule that has different
        ``__init__`` parameters, you can remove or modify incompatible hyperparameters.

        Args:
            subcommand: The subcommand being executed (e.g., 'fit', 'validate', 'test', 'predict').
                This allows you to apply different hyperparameter adaptations depending on the context.
            checkpoint_hparams: Dictionary of hyperparameters loaded from the checkpoint.

        Returns:
            Dictionary of adapted hyperparameters to be used for model instantiation.

        Example::

            class MyCLI(LightningCLI):
                def adapt_checkpoint_hparams(
                    self, subcommand: str, checkpoint_hparams: dict[str, Any]
                ) -> dict[str, Any]:
                    # Only remove training-specific hyperparameters for non-fit subcommands
                    if subcommand != "fit":
                        checkpoint_hparams.pop("lr", None)
                        checkpoint_hparams.pop("weight_decay", None)
                    return checkpoint_hparams

        Note:
            If subclass module mode is enabled and ``_class_path`` is present in the checkpoint
            hyperparameters, you may need to modify it as well to point to your new module class.

        """
        return checkpoint_hparams
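For the subclass-mode case mentioned in the Note, a minimal sketch of rewriting ``_class_path`` in this hook could look as follows (the module path and class name are hypothetical, not part of this PR):

    from typing import Any

    from lightning.pytorch.cli import LightningCLI


    class InferenceCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand: str, checkpoint_hparams: dict[str, Any]) -> dict[str, Any]:
            if subcommand == "predict" and "_class_path" in checkpoint_hparams:
                # Re-point the stored class path from the training module to the inference module.
                checkpoint_hparams["_class_path"] = "my_project.modules.InferenceModule"
            return checkpoint_hparams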

    def _parse_ckpt_path(self) -> None:
        """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config."""
        if not self.config.get("subcommand"):

@@ -571,6 +606,12 @@ def _parse_ckpt_path(self) -> None:
        hparams.pop("_instantiator", None)
        if not hparams:
            return

        # Allow customization of checkpoint hyperparameters via adapt_checkpoint_hparams hook
        hparams = self.adapt_checkpoint_hparams(self.config.subcommand, hparams)
        if not hparams:
            return

        if "_class_path" in hparams:
            hparams = {
                "class_path": hparams.pop("_class_path"),
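The second ``if not hparams: return`` means a hook that returns an empty dict opts out of checkpoint hyperparameters entirely. A minimal user-side sketch of that behavior (the class name is illustrative):

    from lightning.pytorch.cli import LightningCLI


    class IgnoreCkptHparamsCLI(LightningCLI):
        def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
            # An empty dict makes _parse_ckpt_path return early, so only
            # CLI/config values are used for model instantiation.
            return {}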
tests/tests_pytorch/test_cli.py (62 additions, 0 deletions)
@@ -562,6 +562,68 @@ def add_arguments_to_parser(self, parser):
    assert cli.model.layer.out_features == 4


def test_adapt_checkpoint_hparams_hook(cleandir):

Contributor suggested change:
-def test_adapt_checkpoint_hparams_hook(cleandir):
+def test_adapt_checkpoint_hparams_hook_pop_keys(cleandir):

"""Test that the adapt_checkpoint_hparams hook is called and modifications are applied."""

    class AdaptHparamsCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Contributor suggested change (drop these lines):
-        def add_arguments_to_parser(self, parser):
-            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Linking of arguments is not relevant to testing this hook. Better not to have it, to avoid distraction.

        def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
            """Remove out_dim and hidden_dim for non-fit subcommands."""
            if subcommand != "fit":
                checkpoint_hparams.pop("out_dim", None)
                checkpoint_hparams.pop("hidden_dim", None)
            return checkpoint_hparams

    # First, create a checkpoint by running fit
    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptHparamsCLI(BoringCkptPathModel)

    assert cli.config.fit.model.out_dim == 3
    assert cli.config.fit.model.hidden_dim == 6

    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

    # Test that predict uses adapted hparams (without out_dim and hidden_dim)
    cli_args = ["predict", f"--ckpt_path={checkpoint_path}", "--model.out_dim=5", "--model.hidden_dim=10"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptHparamsCLI(BoringCkptPathModel)

    # Since we removed out_dim and hidden_dim for predict, the CLI values should be used
    assert cli.config.predict.model.out_dim == 5
    assert cli.config.predict.model.hidden_dim == 10


def test_adapt_checkpoint_hparams_hook_empty_dict(cleandir):
    """Test that returning an empty dict from adapt_checkpoint_hparams disables checkpoint hyperparameter loading."""

    class AdaptHparamsEmptyCLI(LightningCLI):
        def add_arguments_to_parser(self, parser):
            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Contributor suggested change (drop these lines):
-        def add_arguments_to_parser(self, parser):
-            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)

Linking of arguments is not relevant to testing this hook. Better not to have it, to avoid distraction.

        def adapt_checkpoint_hparams(self, subcommand, checkpoint_hparams):
            """Disable checkpoint hyperparameter loading."""
            return {}

    # First, create a checkpoint
    cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)
Contributor comment:
The test fails because BoringCkptPathModel has a module torch.nn.Linear(32, out_dim); if out_dim is changed, there is a tensor size mismatch.

Instead of using BoringCkptPathModel, implement a new class for these two tests that just sets an attribute that can be asserted after instantiation.
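A minimal sketch of the kind of model the reviewer is asking for (the class name is illustrative, and the defaults are chosen to match the assertions below):

    from lightning.pytorch.demos.boring_classes import BoringModel


    class AttrOnlyModel(BoringModel):
        """Only records its init arguments; no layer shapes depend on them."""

        def __init__(self, out_dim: int = 8, hidden_dim: int = 16):
            super().__init__()
            self.save_hyperparameters()
            self.out_dim = out_dim
            self.hidden_dim = hidden_dim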


    checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))

    # Test that predict uses default values when the hook returns an empty dict
    cli_args = ["predict", f"--ckpt_path={checkpoint_path}"]
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        cli = AdaptHparamsEmptyCLI(BoringCkptPathModel)

    # Model should use default values (out_dim=8, hidden_dim=16)
    assert cli.config_init.predict.model.out_dim == 8
    assert cli.config_init.predict.model.hidden_dim == 16


def test_lightning_cli_submodules(cleandir):
    class MainModule(BoringModel):
        def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):