@@ -46,7 +46,7 @@ def __init__(
46
46
ignore_index : int | None = None ,
47
47
lr : float = 0.001 ,
48
48
# the following are optional so CLI doesnt need to pass them
49
- optimizer : str | None = "Adam" ,
49
+ optimizer : str | None = None ,
50
50
optimizer_hparams : dict | None = None ,
51
51
scheduler : str | None = None ,
52
52
scheduler_hparams : dict | None = None ,
@@ -77,7 +77,7 @@ def __init__(
77
77
ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
78
78
lr (float, optional): Learning rate to be used. Defaults to 0.001.
79
79
optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
80
- Defaults to "Adam" . Overriden by config / cli specification through LightningCLI.
80
+ If None, will use Adam. Defaults to None . Overriden by config / cli specification through LightningCLI.
81
81
optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
82
82
Overriden by config / cli specification through LightningCLI.
83
83
scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
@@ -121,6 +121,9 @@ def configure_models(self) -> None:
121
121
def configure_optimizers (
122
122
self ,
123
123
) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig" :
124
+ optimizer = self .hparams ["optimizer" ]
125
+ if optimizer is None :
126
+ optimizer = "Adam"
124
127
return optimizer_factory (
125
128
self .hparams ["optimizer" ],
126
129
self .hparams ["lr" ],
@@ -158,7 +161,9 @@ def configure_losses(self) -> None:
158
161
elif loss == "dice" :
159
162
self .criterion = smp .losses .DiceLoss ("multiclass" , ignore_index = ignore_index )
160
163
else :
161
- exception_message = f"Loss type '{ loss } ' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
164
+ exception_message = (
165
+ f"Loss type '{ loss } ' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
166
+ )
162
167
raise ValueError (exception_message )
163
168
164
169
def configure_metrics (self ) -> None :
0 commit comments