Skip to content

Commit b699344

Browse files
Merge pull request #15 from IBM/feature/dice_loss
Feature/dice loss
2 parents c4afa03 + 1ab3d1f commit b699344

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

src/terratorch/tasks/segmentation_tasks.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -152,11 +152,13 @@ def configure_losses(self) -> None:
152152
f"Jaccard loss does not support ignore_index, but found non-None value of {ignore_index}."
153153
)
154154
raise RuntimeError(exception_message)
155-
self.criterion = smp.losses.JaccardLoss(mode="multiclass", classes=self.hparams["num_classes"])
155+
self.criterion = smp.losses.JaccardLoss(mode="multiclass")
156156
elif loss == "focal":
157157
self.criterion = smp.losses.FocalLoss("multiclass", ignore_index=ignore_index, normalized=True)
158+
elif loss == "dice":
159+
self.criterion = smp.losses.DiceLoss("multiclass", ignore_index=ignore_index)
158160
else:
159-
exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard' or 'focal' loss."
161+
exception_message = f"Loss type '{loss}' is not valid. Currently, supports 'ce', 'jaccard', 'dice' or 'focal' loss."
160162
raise ValueError(exception_message)
161163

162164
def configure_metrics(self) -> None:

tests/test_prithvi_tasks.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def model_input() -> torch.Tensor:
2222

2323
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
2424
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
25-
def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModelFactory):
25+
@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
26+
def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
2627
SemanticSegmentationTask(
2728
{
2829
"backbone": backbone,
@@ -33,12 +34,14 @@ def test_create_segmentation_task(backbone, decoder, model_factory: PrithviModel
3334
"num_classes": NUM_CLASSES,
3435
},
3536
model_factory,
37+
loss=loss
3638
)
3739

3840

3941
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
4042
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
41-
def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFactory):
43+
@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
44+
def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
4245
PixelwiseRegressionTask(
4346
{
4447
"backbone": backbone,
@@ -48,12 +51,14 @@ def test_create_regression_task(backbone, decoder, model_factory: PrithviModelFa
4851
"pretrained": False,
4952
},
5053
model_factory,
54+
loss=loss
5155
)
5256

5357

5458
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
5559
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
56-
def test_create_classification_task(backbone, decoder, model_factory: PrithviModelFactory):
60+
@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
61+
def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
5762
ClassificationTask(
5863
{
5964
"backbone": backbone,
@@ -64,4 +69,5 @@ def test_create_classification_task(backbone, decoder, model_factory: PrithviMod
6469
"num_classes": NUM_CLASSES,
6570
},
6671
model_factory,
72+
loss=loss
6773
)

0 commit comments

Comments
 (0)