
Commit 874f6b6

[Feature] Add ENCNet and SECrossEntropyLoss (PaddlePaddle#1648)
1 parent bfca53f commit 874f6b6

6 files changed · +318 -0 lines changed

configs/encnet/README.md

+12
@@ -0,0 +1,12 @@
# ENCNet: Context Encoding for Semantic Segmentation

## Reference

> Hang Zhang, Kristin Dana, et al. "Context Encoding for Semantic Segmentation". In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pp. 7151-7160. 2018.

## Performance

### Cityscapes

| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
|ENCNet|ResNet101_vd|1024x512|80000|79.42%|80.02%|-|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/encnet_resnet101_os8_cityscapes_1024x512_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/encnet_resnet101_os8_cityscapes_1024x512_80k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=c2b819e6b666e4e50bba4b525f515d41)|
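
For reference, a minimal inference sketch (not part of this commit): it builds the model registered in paddleseg/models/encnet.py below and loads the released weights with utils.load_entire_model, the same helper the model itself uses in init_weight. The ResNet101_vd import path and its keyword arguments are assumptions based on the YAML config in this commit.

# Minimal inference sketch (assumes paddleseg.models.backbones exposes
# ResNet101_vd and that it accepts output_stride, per the YAML config below).
import paddle
from paddleseg.models import ENCNet
from paddleseg.models.backbones import ResNet101_vd
from paddleseg.utils import utils

backbone = ResNet101_vd(output_stride=8)
model = ENCNet(num_classes=19,  # Cityscapes
               backbone=backbone,
               backbone_indices=[1, 2, 3],
               num_codes=32,
               mid_channels=512,
               use_se_loss=True,
               add_lateral=True)
# Load the weights linked in the table above.
utils.load_entire_model(
    model, 'https://bj.bcebos.com/paddleseg/dygraph/cityscapes/'
    'encnet_resnet101_os8_cityscapes_1024x512_80k/model.pdparams')
model.eval()
logits = model(paddle.rand([1, 3, 512, 1024]))[0]  # [1, 19, 512, 1024]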
configs/encnet/encnet_resnet101_os8_cityscapes_1024x512_80k.yml

+33

@@ -0,0 +1,33 @@
_base_: '../_base_/cityscapes.yml'

batch_size: 2
iters: 80000

model:
  type: ENCNet
  backbone:
    type: ResNet101_vd
    output_stride: 8
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz
  num_codes: 32
  mid_channels: 512
  backbone_indices: [1, 2, 3]
  use_se_loss: True
  add_lateral: True

optimizer:
  type: sgd
  weight_decay: 0.0005

loss:
  types:
    - type: CrossEntropyLoss
    - type: CrossEntropyLoss
    - type: SECrossEntropyLoss
  coef: [1, 0.4, 0.2]

lr_scheduler:
  type: PolynomialDecay
  learning_rate: 0.01
  end_lr: 0.0
  power: 0.9
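
The three entries under loss.types pair one-to-one with the outputs ENCNet.forward returns in training mode: the main head prediction, the auxiliary FCN head prediction, and the semantic-encoding logits (see paddleseg/models/encnet.py below). A sketch of the weighted sum this coef list expresses, assuming PaddleSeg's usual convention that the i-th loss is applied to the i-th model output:

# Sketch of the combined training objective implied by this config
# (assumption: loss_i pairs with output_i, as in PaddleSeg's training loop).
from paddleseg.models.losses import CrossEntropyLoss, SECrossEntropyLoss

loss_fns = [CrossEntropyLoss(), CrossEntropyLoss(), SECrossEntropyLoss()]
coef = [1, 0.4, 0.2]

def combined_loss(outputs, label):
    # outputs == [out, fcn_out, se_out] from ENCNet.forward in training mode.
    return sum(c * fn(o, label) for c, fn, o in zip(coef, loss_fns, outputs))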

paddleseg/models/__init__.py

+1
@@ -46,6 +46,7 @@
 from .ginet import GINet
 from .segmenter import *
 from .segnet import SegNet
+from .encnet import ENCNet
 from .hrnet_contrast import HRNetW48Contrast
 from .espnet import ESPNetV2
 from .dmnet import DMNet

paddleseg/models/encnet.py

+224
@@ -0,0 +1,224 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager
from paddleseg.models import layers
from paddleseg.utils import utils


@manager.MODELS.add_component
class ENCNet(nn.Layer):
    """
    The ENCNet implementation based on PaddlePaddle.

    The original article refers to
    Hang Zhang, Kristin Dana, et al. "Context Encoding for Semantic Segmentation".

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): A backbone network.
        backbone_indices (tuple): The values in the tuple indicate the indices
            of the backbone outputs to use. Default: [1, 2, 3].
        num_codes (int): The number of codewords. Default: 32.
        mid_channels (int): The channels of the middle layers. Default: 512.
        use_se_loss (bool): Whether to use the semantic encoding loss. Default: True.
        add_lateral (bool): Whether to use lateral convolution layers. Default: False.
        pretrained (str, optional): The path or url of the pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices=[1, 2, 3],
                 num_codes=32,
                 mid_channels=512,
                 use_se_loss=True,
                 add_lateral=False,
                 pretrained=None):
        super().__init__()
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.backbone = backbone
        self.backbone_indices = backbone_indices
        in_channels = [
            self.backbone.feat_channels[index] for index in backbone_indices
        ]

        self.bottleneck = layers.ConvBNReLU(
            in_channels[-1],
            mid_channels,
            3,
            padding=1,
        )
        if self.add_lateral:
            self.lateral_convs = nn.LayerList()
            for in_ch in in_channels[:-1]:
                self.lateral_convs.append(
                    layers.ConvBNReLU(
                        in_ch,
                        mid_channels,
                        1,
                    ))
            self.fusion = layers.ConvBNReLU(
                len(in_channels) * mid_channels,
                mid_channels,
                3,
                padding=1,
            )

        self.enc_module = EncModule(mid_channels, num_codes)
        self.head = nn.Conv2D(mid_channels, num_classes, 1)

        self.fcn_head = layers.AuxLayer(self.backbone.feat_channels[2],
                                        mid_channels, num_classes)

        self.use_se_loss = use_se_loss
        if use_se_loss:
            self.se_layer = nn.Linear(mid_channels, num_classes)

        self.pretrained = pretrained
        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def forward(self, inputs):
        N, C, H, W = paddle.shape(inputs)
        feats = self.backbone(inputs)
        fcn_feat = feats[2]

        feats = [feats[i] for i in self.backbone_indices]
        feat = self.bottleneck(feats[-1])

        if self.add_lateral:
            laterals = []
            for i, lateral_conv in enumerate(self.lateral_convs):
                laterals.append(
                    F.interpolate(lateral_conv(feats[i]),
                                  size=paddle.shape(feat)[2:],
                                  mode='bilinear',
                                  align_corners=False))
            feat = self.fusion(paddle.concat([feat, *laterals], 1))
        encode_feat, feat = self.enc_module(feat)
        out = self.head(feat)
        out = F.interpolate(out,
                            size=[H, W],
                            mode='bilinear',
                            align_corners=False)
        output = [out]
        if self.training:
            fcn_out = self.fcn_head(fcn_feat)
            fcn_out = F.interpolate(fcn_out,
                                    size=[H, W],
                                    mode='bilinear',
                                    align_corners=False)
            output.append(fcn_out)
            if self.use_se_loss:
                se_out = self.se_layer(encode_feat)
                output.append(se_out)
            # Training output order is [out, fcn_out, se_out]; it must match
            # the loss `types`/`coef` lists in the config.
            return output
        return output
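

# Encoding implements the context-encoding dictionary of Zhang et al.: it
# learns `num_codes` codewords d_k with per-codeword scales s_k (initialized
# from Uniform(-1, 0), i.e. negative). For each flattened pixel feature x_i
# it computes soft-assignment weights a_ik = softmax_k(s_k * ||x_i - d_k||^2)
# and aggregates residuals e_k = sum_i a_ik * (x_i - d_k), producing one
# [num_codes, channels] descriptor per image.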
class Encoding(nn.Layer):
    def __init__(self, channels, num_codes):
        super().__init__()
        self.channels, self.num_codes = channels, num_codes

        std = 1 / ((channels * num_codes)**0.5)
        self.codewords = self.create_parameter(
            shape=(num_codes, channels),
            default_initializer=nn.initializer.Uniform(-std, std),
        )
        self.scale = self.create_parameter(
            shape=(num_codes, ),
            default_initializer=nn.initializer.Uniform(-1, 0),
        )

    def scaled_l2(self, x, codewords, scale):
        num_codes, channels = paddle.shape(codewords)
        reshaped_scale = scale.reshape([1, 1, num_codes])
        expanded_x = paddle.tile(x.unsqueeze(2), [1, 1, num_codes, 1])
        reshaped_codewords = codewords.reshape([1, 1, num_codes, channels])

        scaled_l2_norm = paddle.multiply(
            reshaped_scale,
            (expanded_x - reshaped_codewords).pow(2).sum(axis=3))
        return scaled_l2_norm

    def aggregate(self, assignment_weights, x, codewords):
        num_codes, channels = paddle.shape(codewords)
        reshaped_codewords = codewords.reshape([1, 1, num_codes, channels])
        expanded_x = paddle.tile(x.unsqueeze(2), [1, 1, num_codes, 1])

        encoded_feat = paddle.multiply(
            assignment_weights.unsqueeze(3),
            (expanded_x - reshaped_codewords)).sum(axis=1)
        encoded_feat = paddle.reshape(encoded_feat,
                                      [-1, self.num_codes, self.channels])
        return encoded_feat

    def forward(self, x):
        x_dims = x.ndim
        assert x_dims == 4, "The dimension of the input tensor must be 4, but got {}.".format(
            x_dims)
        assert paddle.shape(
            x
        )[1] == self.channels, "Encoding channels error, expected {} but got {}.".format(
            self.channels,
            paddle.shape(x)[1])
        batch_size = paddle.shape(x)[0]
        # Flatten the spatial dims: [N, C, H, W] -> [N, H*W, C].
        x = x.reshape([batch_size, self.channels, -1]).transpose([0, 2, 1])
        assignment_weights = F.softmax(self.scaled_l2(x, self.codewords,
                                                      self.scale),
                                       axis=2)
        encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
        return encoded_feat


class EncModule(nn.Layer):
    def __init__(self, in_channels, num_codes):
        super().__init__()
        self.encoding_project = layers.ConvBNReLU(
            in_channels,
            in_channels,
            1,
        )
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            nn.BatchNorm1D(num_codes),
            nn.ReLU(),
        )
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels),
            nn.Sigmoid(),
        )
        self.in_channels = in_channels

    def forward(self, x):
        encoding_projection = self.encoding_project(x)
        encoding_feat = self.encoding(encoding_projection)

        # Average over the codewords: [N, num_codes, C] -> [N, C].
        encoding_feat = encoding_feat.mean(axis=1)
        batch_size, _, _, _ = paddle.shape(x)

        gamma = self.fc(encoding_feat)
        y = gamma.reshape([batch_size, self.in_channels, 1, 1])
        # Channel-wise reweighting, as in SE blocks.
        output = F.relu(x + x * y)
        return encoding_feat, output
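
For intuition, a quick shape check of EncModule (a sketch with illustrative sizes, assuming the class is imported straight from the new module):

# Shape check for EncModule (illustrative sizes).
import paddle
from paddleseg.models.encnet import EncModule

m = EncModule(in_channels=512, num_codes=32)
x = paddle.rand([2, 512, 64, 128])
encoding_feat, out = m(x)
print(encoding_feat.shape)  # [2, 512] -> fed to the se_layer Linear head
print(out.shape)            # [2, 512, 64, 128] -> channel-reweighted features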

paddleseg/models/losses/__init__.py

+1
@@ -32,3 +32,4 @@
 from .detail_aggregate_loss import DetailAggregateLoss
 from .point_cross_entropy_loss import PointCrossEntropyLoss
 from .pixel_contrast_cross_entropy_loss import PixelContrastCrossEntropyLoss
+from .semantic_encode_cross_entropy_loss import SECrossEntropyLoss
paddleseg/models/losses/semantic_encode_cross_entropy_loss.py

+47

@@ -0,0 +1,47 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddleseg.cvlibs import manager


@manager.LOSSES.add_component
class SECrossEntropyLoss(nn.Layer):
    """
    The Semantic Encoding Loss implementation based on PaddlePaddle.

    It supervises the semantic-encoding logits with a multi-hot vector that
    marks which classes appear in each label map, using binary cross-entropy.
    """

    def __init__(self, *args, **kwargs):
        super(SECrossEntropyLoss, self).__init__()

    def forward(self, logit, label):
        if logit.ndim == 4:
            # Squeeze both trailing singleton axes at once; squeezing axis 3
            # after axis 2 would index past the end of a 3-D tensor.
            logit = logit.squeeze([2, 3])
        assert logit.ndim == 2, \
            "The shape of logit should be [N, C, 1, 1] or [N, C], but the logit dim is {}.".format(
                logit.ndim)

        batch_size, num_classes = paddle.shape(logit)
        se_label = paddle.zeros([batch_size, num_classes])
        for i in range(batch_size):
            hist = paddle.histogram(label[i],
                                    bins=num_classes,
                                    min=0,
                                    max=num_classes - 1)
            hist = hist.astype('float32') / hist.sum().astype('float32')
            # Only presence matters: the target is 1 for every class that
            # occurs in the label map, 0 otherwise.
            se_label[i] = (hist > 0).astype('float32')
        loss = F.binary_cross_entropy_with_logits(logit, se_label)
        return loss
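
To make the target construction concrete, a toy example (illustrative numbers): a 4-class label map containing only classes 0 and 2 yields the multi-hot target [1, 0, 1, 0].

# Toy check of the SE target construction used above.
import paddle

label = paddle.to_tensor([[0, 2], [2, 0]], dtype='int64')
hist = paddle.histogram(label, bins=4, min=0, max=3)  # [2, 0, 2, 0]
print((hist > 0).astype('float32'))                   # [1., 0., 1., 0.]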
