-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add adapt_checkpoint_hparams hook for customizing checkpoint hyperparameter loading #21408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 5 commits
910a712
ad1a028
998ea3e
00e7032
b3b1025
fc8cc3a
a0f0d77
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -562,6 +562,68 @@ def add_arguments_to_parser(self, parser): | |||||
| assert cli.model.layer.out_features == 4 | ||||||
|
|
||||||
|
|
||||||
| def test_adapt_checkpoint_hparams_hook(cleandir): | ||||||
| """Test that the adapt_checkpoint_hparams hook is called and modifications are applied.""" | ||||||
|
|
||||||
| class AdaptHparamsCLI(LightningCLI): | ||||||
| def add_arguments_to_parser(self, parser): | ||||||
| parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) | ||||||
|
|
||||||
|
||||||
| def add_arguments_to_parser(self, parser): | |
| parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) |
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| def add_arguments_to_parser(self, parser): | |
| parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) |
Linking of arguments is not relevant to test this hook. Better to not have it to avoid distraction.
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test fails because of BoringCkptPathModel has a module torch.nn.Linear(32, out_dim). If the out_dim is changed, then there is a tensor size mismatch.
Instead of using BoringCkptPathModel, implement a new class for these two tests, that just sets an attribute that can be asserted after instantiation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.