
Commit 0ad180b

Merge pull request #110 from IBM/fix/remove_confusion_default_optimizer_config
set default optimizer to None
2 parents cbc98f1 + 2752882

File tree

3 files changed: +18 -7 lines

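All three tasks get the same two-part change: the `optimizer` constructor argument now defaults to `None` instead of an Adam string, and `configure_optimizers` substitutes `"Adam"` only when nothing was supplied through the config or CLI. A minimal sketch of that fallback logic, detached from Lightning (the helper name below and everything beyond the two `optimizer_factory` arguments visible in the diffs are assumptions):

```python
# Sketch of the "None means Adam" fallback introduced by this commit.
# `hparams` stands in for the task's stored hyperparameters.
def pick_optimizer_name(hparams: dict) -> str:
    optimizer = hparams.get("optimizer")
    if optimizer is None:      # nothing supplied via config / CLI
        optimizer = "Adam"     # fall back to Adam, as the tasks now do
    return optimizer

assert pick_optimizer_name({"optimizer": None}) == "Adam"
assert pick_optimizer_name({"optimizer": "SGD"}) == "SGD"
```

Moving the default out of the signature makes it unambiguous whether a value actually came from the user, which appears to be the confusion the branch name (`fix/remove_confusion_default_optimizer_config`) refers to.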

terratorch/tasks/classification_tasks.py (+5 -2)

@@ -50,7 +50,7 @@ def __init__(
         ignore_index: int | None = None,
         lr: float = 0.001,
         # the following are optional so CLI doesnt need to pass them
-        optimizer: str | None = "torch.optim.Adam",
+        optimizer: str | None = None,
         optimizer_hparams: dict | None = None,
         scheduler: str | None = None,
         scheduler_hparams: dict | None = None,
@@ -80,7 +80,7 @@ def __init__(
             ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
             lr (float, optional): Learning rate to be used. Defaults to 0.001.
             optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
-                Defaults to "Adam". Overriden by config / cli specification through LightningCLI.
+                If None, will use Adam. Defaults to None. Overriden by config / cli specification through LightningCLI.
             optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                 Overriden by config / cli specification through LightningCLI.
             scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
@@ -118,6 +118,9 @@ def configure_models(self) -> None:
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
+        optimizer = self.hparams["optimizer"]
+        if optimizer is None:
+            optimizer = "Adam"
         return optimizer_factory(
             self.hparams["optimizer"],
             self.hparams["lr"],

terratorch/tasks/regression_tasks.py (+5 -2)

@@ -139,7 +139,7 @@ def __init__(
         ignore_index: int | None = None,
         lr: float = 0.001,
         # the following are optional so CLI doesnt need to pass them
-        optimizer: str | None = "Adam",
+        optimizer: str | None = None,
         optimizer_hparams: dict | None = None,
         scheduler: str | None = None,
         scheduler_hparams: dict | None = None,
@@ -166,7 +166,7 @@ def __init__(
             ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
             lr (float, optional): Learning rate to be used. Defaults to 0.001.
             optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
-                Defaults to "Adam". Overriden by config / cli specification through LightningCLI.
+                If None, will use Adam. Defaults to None. Overriden by config / cli specification through LightningCLI.
             optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                 Overriden by config / cli specification through LightningCLI.
             scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
@@ -208,6 +208,9 @@ def configure_models(self) -> None:
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
+        optimizer = self.hparams["optimizer"]
+        if optimizer is None:
+            optimizer = "Adam"
         return optimizer_factory(
             self.hparams["optimizer"],
             self.hparams["lr"],

terratorch/tasks/segmentation_tasks.py (+8 -3)

@@ -46,7 +46,7 @@ def __init__(
         ignore_index: int | None = None,
         lr: float = 0.001,
         # the following are optional so CLI doesnt need to pass them
-        optimizer: str | None = "Adam",
+        optimizer: str | None = None,
         optimizer_hparams: dict | None = None,
         scheduler: str | None = None,
         scheduler_hparams: dict | None = None,
@@ -77,7 +77,7 @@ def __init__(
             ignore_index (int | None, optional): Label to ignore in the loss computation. Defaults to None.
             lr (float, optional): Learning rate to be used. Defaults to 0.001.
             optimizer (str | None, optional): Name of optimizer class from torch.optim to be used.
-                Defaults to "Adam". Overriden by config / cli specification through LightningCLI.
+                If None, will use Adam. Defaults to None. Overriden by config / cli specification through LightningCLI.
             optimizer_hparams (dict | None): Parameters to be passed for instantiation of the optimizer.
                 Overriden by config / cli specification through LightningCLI.
             scheduler (str, optional): Name of Torch scheduler class from torch.optim.lr_scheduler
@@ -121,6 +121,9 @@ def configure_models(self) -> None:
     def configure_optimizers(
         self,
     ) -> "lightning.pytorch.utilities.types.OptimizerLRSchedulerConfig":
+        optimizer = self.hparams["optimizer"]
+        if optimizer is None:
+            optimizer = "Adam"
         return optimizer_factory(
             self.hparams["optimizer"],
             self.hparams["lr"],
@@ -158,7 +161,9 @@ def configure_losses(self) -> None:
         elif loss == "dice":
             self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
         else:
-            exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
+            exception_message = (
+                f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
+            )
             raise ValueError(exception_message)

     def configure_metrics(self) -> None:
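The last hunk also shows the overall shape of `configure_losses` in the segmentation task: a string selects one of four criteria, and anything else raises the (now parenthesised) error. A self-contained sketch of that dispatch; only the `DiceLoss` call and the error message are taken from the diff, the remaining constructor arguments are assumptions:

```python
import torch.nn as nn
import segmentation_models_pytorch as smp

def build_criterion(loss: str, ignore_index: int | None = None):
    # Dispatch over the four loss names listed in the error message.
    if loss == "ce":
        return nn.CrossEntropyLoss(ignore_index=-100 if ignore_index is None else ignore_index)
    if loss == "jaccard":
        return smp.losses.JaccardLoss("multiclass")
    if loss == "focal":
        return smp.losses.FocalLoss("multiclass", ignore_index=ignore_index)
    if loss == "dice":
        return smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
    exception_message = (
        f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
    )
    raise ValueError(exception_message)
```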
