Skip to content

Commit 1d05225

Browse files
authored
Fix RMI Loss (PaddlePaddle#1192)
1 parent 527ca16 commit 1d05225

File tree

5 files changed

+269
-113
lines changed

5 files changed

+269
-113
lines changed

.gitignore

Lines changed: 0 additions & 113 deletions
This file was deleted.

configs/deeplabv3p/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
1212
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
1313
|DeepLabV3P|ResNet50_OS8|1024x512|80000|80.36%|80.57%|80.81%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet50_os8_cityscapes_1024x512_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet50_os8_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=860bd0049ba5495d629a96d5aaf1bf75)|
14+
|DeepLabV3P*|ResNet50_OS8|1024x512|80000|81.18%| 81.42% | 81.48% |[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet50_os8_cityscapes_1024x512_80k_rmiloss/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet50_os8_cityscapes_1024x512_80k_rmiloss/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=ce094fb8a42c056b6edb92f975cfa0e3)|
1415
|DeepLabV3P|ResNet101_OS8|1024x512|80000|81.10%|81.38%|81.24%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet101_os8_cityscapes_1024x512_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet101_os8_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=8b11e75b8977a0fd74180145350c27de)|
1516
|DeepLabV3P|ResNet101_OS8|769x769|80000|81.53%|81.88%|82.12%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet101_os8_cityscapes_769x769_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/deeplabv3p_resnet101_os8_cityscapes_769x769_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=420039406361cbc3cf7ec14c1084d886)|
1617

18+
DeepLabV3P* is DeepLabV3P with [RMI Loss](https://arxiv.org/abs/1910.12037), which requires paddlepaddle=2.2.
19+
1720
### Pascal VOC 2012 + Aug
1821

1922
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
_base_: 'deeplabv3p_resnet50_os8_cityscapes_1024x512_80k.yml'
2+
3+
loss:
4+
types:
5+
- type: MixedLoss
6+
losses:
7+
- type: CrossEntropyLoss
8+
- type: RMILoss
9+
coef: [0.5, 0.5]

paddleseg/models/losses/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
from .ohem_edge_attention_loss import OhemEdgeAttentionLoss
2626
from .l1_loss import L1Loss
2727
from .mean_square_error_loss import MSELoss
28+
from .rmi_loss import RMILoss

paddleseg/models/losses/rmi_loss.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""rmi loss in PaddlePaddle"""
15+
import numpy
16+
import paddle
17+
import paddle.nn as nn
18+
import paddle.nn.functional as F
19+
20+
from paddleseg.cvlibs import manager
21+
22+
_euler_num = 2.718281828
23+
_pi = 3.14159265
24+
_ln_2_pi = 1.837877
25+
_CLIP_MIN = 1e-6
26+
_CLIP_MAX = 1.0
27+
_POS_ALPHA = 5e-4
28+
_IS_SUM = 1
29+
30+
31+
@manager.LOSSES.add_component
32+
class RMILoss(nn.Layer):
33+
"""
34+
Implements the Region Mutual Information(RMI) Loss(https://arxiv.org/abs/1910.12037) for Semantic Segmentation.
35+
Unlike vanilla rmi loss which contains Cross Entropy Loss, we disband them and only
36+
left the RMI-related parts.
37+
The motivation is to allow for a more flexible combination of losses during training.
38+
For example, by employing mixed loss to merge RMI Loss with Boostrap Cross Entropy Loss,
39+
we can achieve the online mining of hard examples together with attention to region information.
40+
Args:
41+
weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
42+
given to each class. Its length must be equal to the number of classes.
43+
Default ``None``.
44+
ignore_index (int64, optional): Specifies a target value that is ignored
45+
and does not contribute to the input gradient. Default ``255``.
46+
"""
47+
48+
def __init__(self,
49+
num_classes=19,
50+
rmi_radius=3,
51+
rmi_pool_way=0,
52+
rmi_pool_size=3,
53+
rmi_pool_stride=3,
54+
loss_weight_lambda=0.5,
55+
ignore_index=255):
56+
super(RMILoss, self).__init__()
57+
58+
self.num_classes = num_classes
59+
assert rmi_radius in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
60+
self.rmi_radius = rmi_radius
61+
assert rmi_pool_way in [0, 1, 2, 3]
62+
self.rmi_pool_way = rmi_pool_way
63+
assert rmi_pool_size == rmi_pool_stride
64+
self.rmi_pool_size = rmi_pool_size
65+
self.rmi_pool_stride = rmi_pool_stride
66+
self.weight_lambda = loss_weight_lambda
67+
self.half_d = self.rmi_radius * self.rmi_radius
68+
self.d = 2 * self.half_d
69+
self.kernel_padding = self.rmi_pool_size // 2
70+
self.ignore_index = ignore_index
71+
72+
def forward(self, logits_4D, labels_4D, do_rmi=True):
73+
"""
74+
Forward computation.
75+
Args:
76+
logits (Tensor): Shape is [N, C, H, W], logits at each prediction (between -\infty and +\infty).
77+
labels (Tensor): Shape is [N, H, W], ground truth labels (between 0 and C - 1).
78+
"""
79+
logits_4D = paddle.cast(logits_4D, dtype='float32')
80+
labels_4D = paddle.cast(labels_4D, dtype='float32')
81+
82+
loss = self.forward_sigmoid(logits_4D, labels_4D, do_rmi=do_rmi)
83+
return loss
84+
85+
def forward_sigmoid(self, logits_4D, labels_4D, do_rmi=False):
86+
"""
87+
Using the sigmiod operation both.
88+
Args:
89+
logits_4D : [N, C, H, W], dtype=float32
90+
labels_4D : [N, H, W], dtype=long
91+
do_rmi : bool
92+
"""
93+
label_mask_3D = labels_4D != self.ignore_index
94+
valid_onehot_labels_4D = paddle.cast(
95+
F.one_hot(
96+
paddle.cast(labels_4D, dtype='int64') * paddle.cast(
97+
label_mask_3D, dtype='int64'),
98+
num_classes=self.num_classes),
99+
dtype='float32')
100+
# label_mask_flat = paddle.cast(
101+
# paddle.reshape(label_mask_3D, [-1]), dtype='float32')
102+
103+
valid_onehot_labels_4D = valid_onehot_labels_4D * paddle.unsqueeze(
104+
label_mask_3D, axis=3)
105+
valid_onehot_labels_4D.stop_gradient = True
106+
probs_4D = F.sigmoid(logits_4D) * paddle.unsqueeze(
107+
label_mask_3D, axis=1) + _CLIP_MIN
108+
109+
valid_onehot_labels_4D = paddle.transpose(valid_onehot_labels_4D,
110+
[0, 3, 1, 2])
111+
valid_onehot_labels_4D.stop_gradient = True
112+
rmi_loss = self.rmi_lower_bound(valid_onehot_labels_4D, probs_4D)
113+
114+
return rmi_loss
115+
116+
def inverse(self, x):
117+
return paddle.inverse(x)
118+
119+
def rmi_lower_bound(self, labels_4D, probs_4D):
120+
"""
121+
calculate the lower bound of the region mutual information.
122+
Args:
123+
labels_4D : [N, C, H, W], dtype=float32
124+
probs_4D : [N, C, H, W], dtype=float32
125+
"""
126+
assert labels_4D.shape == probs_4D.shape, print(
127+
'shapes', labels_4D.shape, probs_4D.shape)
128+
129+
p, s = self.rmi_pool_size, self.rmi_pool_stride
130+
if self.rmi_pool_stride > 1:
131+
if self.rmi_pool_way == 0:
132+
labels_4D = F.max_pool2d(
133+
labels_4D,
134+
kernel_size=p,
135+
stride=s,
136+
padding=self.kernel_padding)
137+
probs_4D = F.max_pool2d(
138+
probs_4D,
139+
kernel_size=p,
140+
stride=s,
141+
padding=self.kernel_padding)
142+
elif self.rmi_pool_way == 1:
143+
labels_4D = F.avg_pool2d(
144+
labels_4D,
145+
kernel_size=p,
146+
stride=s,
147+
padding=self.kernel_padding)
148+
probs_4D = F.avg_pool2d(
149+
probs_4D,
150+
kernel_size=p,
151+
stride=s,
152+
padding=self.kernel_padding)
153+
elif self.rmi_pool_way == 2:
154+
shape = labels_4D.shape
155+
new_h, new_w = shape[2] // s, shape[3] // s
156+
labels_4D = F.interpolate(
157+
labels_4D, size=(new_h, new_w), mode='nearest')
158+
probs_4D = F.interpolate(
159+
probs_4D,
160+
size=(new_h, new_w),
161+
mode='bilinear',
162+
align_corners=True)
163+
else:
164+
raise NotImplementedError("Pool way of RMI is not defined!")
165+
166+
label_shape = labels_4D.shape
167+
n, c = label_shape[0], label_shape[1]
168+
169+
la_vectors, pr_vectors = self.map_get_pairs(
170+
labels_4D, probs_4D, radius=self.rmi_radius, is_combine=0)
171+
172+
la_vectors = paddle.reshape(la_vectors, [n, c, self.half_d, -1])
173+
la_vectors = paddle.cast(la_vectors, dtype='float64')
174+
la_vectors.stop_gradient = True
175+
176+
pr_vectors = paddle.reshape(pr_vectors, [n, c, self.half_d, -1])
177+
pr_vectors = paddle.cast(pr_vectors, dtype='float64')
178+
179+
diag_matrix = paddle.unsqueeze(
180+
paddle.unsqueeze(paddle.eye(self.half_d), axis=0), axis=0)
181+
la_vectors = la_vectors - paddle.mean(la_vectors, axis=3, keepdim=True)
182+
183+
la_cov = paddle.matmul(la_vectors,
184+
paddle.transpose(la_vectors, [0, 1, 3, 2]))
185+
pr_vectors = pr_vectors - paddle.mean(pr_vectors, axis=3, keepdim=True)
186+
pr_cov = paddle.matmul(pr_vectors,
187+
paddle.transpose(pr_vectors, [0, 1, 3, 2]))
188+
189+
pr_cov_inv = self.inverse(
190+
pr_cov + paddle.cast(diag_matrix, dtype='float64') * _POS_ALPHA)
191+
192+
la_pr_cov = paddle.matmul(la_vectors,
193+
paddle.transpose(pr_vectors, [0, 1, 3, 2]))
194+
195+
appro_var = la_cov - paddle.matmul(
196+
paddle.matmul(la_pr_cov, pr_cov_inv),
197+
paddle.transpose(la_pr_cov, [0, 1, 3, 2]))
198+
199+
rmi_now = 0.5 * self.log_det_by_cholesky(
200+
appro_var + paddle.cast(diag_matrix, dtype='float64') * _POS_ALPHA)
201+
202+
rmi_per_class = paddle.cast(
203+
paddle.mean(
204+
paddle.reshape(rmi_now, [-1, self.num_classes]), axis=0),
205+
dtype='float32')
206+
rmi_per_class = paddle.divide(rmi_per_class,
207+
paddle.to_tensor(float(self.half_d)))
208+
209+
rmi_loss = paddle.sum(rmi_per_class) if _IS_SUM else paddle.mean(
210+
rmi_per_class)
211+
212+
return rmi_loss
213+
214+
def log_det_by_cholesky(self, matrix):
215+
"""
216+
Args:
217+
matrix: matrix must be a positive define matrix.
218+
shape [N, C, D, D].
219+
"""
220+
221+
chol = paddle.cholesky(matrix)
222+
diag = paddle.diagonal(chol, offset=0, axis1=-2, axis2=-1)
223+
chol = paddle.log(diag + 1e-8)
224+
225+
return 2.0 * paddle.sum(chol, axis=-1)
226+
227+
def map_get_pairs(self, labels_4D, probs_4D, radius=3, is_combine=True):
228+
"""
229+
Args:
230+
labels_4D : labels, shape [N, C, H, W]
231+
probs_4D : probabilities, shape [N, C, H, W]
232+
radius : the square radius
233+
Return:
234+
tensor with shape [N, C, radius * radius, H - (radius - 1), W - (radius - 1)]
235+
"""
236+
237+
label_shape = labels_4D.shape
238+
h, w = label_shape[2], label_shape[3]
239+
new_h, new_w = h - (radius - 1), w - (radius - 1)
240+
la_ns = []
241+
pr_ns = []
242+
for y in range(0, radius, 1):
243+
for x in range(0, radius, 1):
244+
la_now = labels_4D[:, :, y:y + new_h, x:x + new_w]
245+
pr_now = probs_4D[:, :, y:y + new_h, x:x + new_w]
246+
la_ns.append(la_now)
247+
pr_ns.append(pr_now)
248+
249+
if is_combine:
250+
pair_ns = la_ns + pr_ns
251+
p_vectors = paddle.stack(pair_ns, axis=2)
252+
return p_vectors
253+
else:
254+
la_vectors = paddle.stack(la_ns, axis=2)
255+
pr_vectors = paddle.stack(pr_ns, axis=2)
256+
return la_vectors, pr_vectors

0 commit comments

Comments
 (0)