Skip to content

Commit bf5ae48

Browse files
authored
Replaced view with replace to prevent fails on non-contiguous tensors (#1174)
1 parent 4c4ad41 commit bf5ae48

File tree

5 files changed

+21
-21
lines changed

5 files changed

+21
-21
lines changed

segmentation_models_pytorch/losses/dice.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7373
dims = (0, 2)
7474

7575
if self.mode == BINARY_MODE:
76-
y_true = y_true.view(bs, 1, -1)
77-
y_pred = y_pred.view(bs, 1, -1)
76+
y_true = y_true.reshape(bs, 1, -1)
77+
y_pred = y_pred.reshape(bs, 1, -1)
7878

7979
if self.ignore_index is not None:
8080
mask = y_true != self.ignore_index
8181
y_pred = y_pred * mask
8282
y_true = y_true * mask
8383

8484
if self.mode == MULTICLASS_MODE:
85-
y_true = y_true.view(bs, -1)
86-
y_pred = y_pred.view(bs, num_classes, -1)
85+
y_true = y_true.reshape(bs, -1)
86+
y_pred = y_pred.reshape(bs, num_classes, -1)
8787

8888
if self.ignore_index is not None:
8989
mask = y_true != self.ignore_index
@@ -98,8 +98,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
9898
y_true = y_true.permute(0, 2, 1) # N, C, H*W
9999

100100
if self.mode == MULTILABEL_MODE:
101-
y_true = y_true.view(bs, num_classes, -1)
102-
y_pred = y_pred.view(bs, num_classes, -1)
101+
y_true = y_true.reshape(bs, num_classes, -1)
102+
y_pred = y_pred.reshape(bs, num_classes, -1)
103103

104104
if self.ignore_index is not None:
105105
mask = y_true != self.ignore_index

segmentation_models_pytorch/losses/focal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def __init__(
5757

5858
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
5959
if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
60-
y_true = y_true.view(-1)
61-
y_pred = y_pred.view(-1)
60+
y_true = y_true.reshape(-1)
61+
y_pred = y_pred.reshape(-1)
6262

6363
if self.ignore_index is not None:
6464
# Filter predictions with ignore label from loss computation

segmentation_models_pytorch/losses/jaccard.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
7373
dims = (0, 2)
7474

7575
if self.mode == BINARY_MODE:
76-
y_true = y_true.view(bs, 1, -1)
77-
y_pred = y_pred.view(bs, 1, -1)
76+
y_true = y_true.reshape(bs, 1, -1)
77+
y_pred = y_pred.reshape(bs, 1, -1)
7878

7979
if self.ignore_index is not None:
8080
mask = y_true != self.ignore_index
8181
y_pred = y_pred * mask
8282
y_true = y_true * mask
8383

8484
if self.mode == MULTICLASS_MODE:
85-
y_true = y_true.view(bs, -1)
86-
y_pred = y_pred.view(bs, num_classes, -1)
85+
y_true = y_true.reshape(bs, -1)
86+
y_pred = y_pred.reshape(bs, num_classes, -1)
8787

8888
if self.ignore_index is not None:
8989
mask = y_true != self.ignore_index
@@ -98,8 +98,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
9898
y_true = y_true.permute(0, 2, 1) # N, C, H*W
9999

100100
if self.mode == MULTILABEL_MODE:
101-
y_true = y_true.view(bs, num_classes, -1)
102-
y_pred = y_pred.view(bs, num_classes, -1)
101+
y_true = y_true.reshape(bs, num_classes, -1)
102+
y_pred = y_pred.reshape(bs, num_classes, -1)
103103

104104
if self.ignore_index is not None:
105105
mask = y_true != self.ignore_index

segmentation_models_pytorch/losses/lovasz.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def _flatten_binary_scores(scores, labels, ignore=None):
7777
"""Flattens predictions in the batch (binary case)
7878
Remove labels equal to 'ignore'
7979
"""
80-
scores = scores.view(-1)
81-
labels = labels.view(-1)
80+
scores = scores.reshape(-1)
81+
labels = labels.reshape(-1)
8282
if ignore is None:
8383
return scores, labels
8484
valid = labels != ignore
@@ -151,13 +151,13 @@ def _flatten_probas(probas, labels, ignore=None):
151151
if probas.dim() == 3:
152152
# assumes output of a sigmoid layer
153153
B, H, W = probas.size()
154-
probas = probas.view(B, 1, H, W)
154+
probas = probas.reshape(B, 1, H, W)
155155

156156
C = probas.size(1)
157157
probas = torch.movedim(probas, 1, -1) # [B, C, Di, Dj, ...] -> [B, Di, Dj, ..., C]
158-
probas = probas.contiguous().view(-1, C) # [P, C]
158+
probas = probas.reshape(-1, C) # [P, C]
159159

160-
labels = labels.view(-1)
160+
labels = labels.reshape(-1)
161161
if ignore is None:
162162
return probas, labels
163163
valid = labels != ignore

segmentation_models_pytorch/losses/mcc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
2929

3030
bs = y_true.shape[0]
3131

32-
y_true = y_true.view(bs, 1, -1)
33-
y_pred = y_pred.view(bs, 1, -1)
32+
y_true = y_true.reshape(bs, 1, -1)
33+
y_pred = y_pred.reshape(bs, 1, -1)
3434

3535
tp = torch.sum(torch.mul(y_pred, y_true)) + self.eps
3636
tn = torch.sum(torch.mul((1 - y_pred), (1 - y_true))) + self.eps

0 commit comments

Comments
 (0)