diff --git a/segmentation_models_pytorch/losses/jaccard.py b/segmentation_models_pytorch/losses/jaccard.py index b250cacf..35727f95 100644 --- a/segmentation_models_pytorch/losses/jaccard.py +++ b/segmentation_models_pytorch/losses/jaccard.py @@ -17,6 +17,7 @@ def __init__( log_loss: bool = False, from_logits: bool = True, smooth: float = 0.0, + ignore_index: Optional[int] = None, eps: float = 1e-7, ): """Jaccard loss for image segmentation task. @@ -51,6 +52,7 @@ def __init__( self.classes = classes self.from_logits = from_logits self.smooth = smooth + self.ignore_index = ignore_index self.eps = eps self.log_loss = log_loss @@ -74,17 +76,36 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: y_true = y_true.view(bs, 1, -1) y_pred = y_pred.view(bs, 1, -1) + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask + y_true = y_true * mask + if self.mode == MULTICLASS_MODE: y_true = y_true.view(bs, -1) y_pred = y_pred.view(bs, num_classes, -1) - y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C - y_true = y_true.permute(0, 2, 1) # H, C, H*W + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask.unsqueeze(1) + + y_true = F.one_hot( + (y_true * mask).to(torch.long), num_classes + ) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W + else: + y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C + y_true = y_true.permute(0, 2, 1) # N, C, H*W if self.mode == MULTILABEL_MODE: y_true = y_true.view(bs, num_classes, -1) y_pred = y_pred.view(bs, num_classes, -1) + if self.ignore_index is not None: + mask = y_true != self.ignore_index + y_pred = y_pred * mask + y_true = y_true * mask + scores = soft_jaccard_score( y_pred, y_true.type(y_pred.dtype),