Skip to content

Commit 3b90cc0

Browse files
authored
[Feature] Add FastFCN (PaddlePaddle#1669)
1 parent 6d417cf commit 3b90cc0

File tree

4 files changed

+298
-0
lines changed

4 files changed

+298
-0
lines changed

configs/fastfcn/README.md

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic Segmentation
2+
3+
## Reference
4+
> Wu, Huikai, Junge Zhang, Kaiqi Huang, Kongming Liang, and Yizhou Yu. "Fastfcn: Rethinking dilated convolution in the backbone for semantic segmentation." arXiv preprint arXiv:1903.11816 (2019).
5+
6+
## Performance
7+
8+
### ADE20K
9+
10+
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
11+
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
12+
|FastFCN|ResNet50_vd|480x480|120000|43.76%|44.11%|44.48%|[model](https://bj.bcebos.com/paddleseg/dygraph/ade20k/fastfcn_resnet50_os8_ade20k_480x480_120k/model.pdparams) \|[log](https://bj.bcebos.com/paddleseg/dygraph/ade20k/fastfcn_resnet50_os8_ade20k_480x480_120k/train.log)\|[vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=e159d5be3860b8d08762c0416ab54acc)|
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
_base_: '../_base_/ade20k.yml'
2+
batch_size: 4
3+
iters: 120000
4+
5+
train_dataset:
6+
transforms:
7+
- type: ResizeStepScaling
8+
min_scale_factor: 0.5
9+
max_scale_factor: 2.0
10+
scale_step_size: 0.25
11+
- type: RandomPaddingCrop
12+
crop_size: [480, 480]
13+
im_padding_value: [0, 0, 0]
14+
- type: RandomHorizontalFlip
15+
- type: RandomDistort
16+
brightness_range: 0.4
17+
contrast_range: 0.4
18+
saturation_range: 0.4
19+
- type: Normalize
20+
21+
model:
22+
type: FastFCN
23+
backbone:
24+
type: ResNet50_vd
25+
output_stride: 8
26+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
27+
num_codes: 32
28+
mid_channels: 512
29+
use_jpu: True
30+
aux_loss: True
31+
use_se_loss: True
32+
add_lateral: True
33+
34+
loss:
35+
types:
36+
- type: CrossEntropyLoss
37+
- type: CrossEntropyLoss
38+
- type: SECrossEntropyLoss
39+
coef: [1, 0.4, 0.2]
40+
41+
lr_scheduler:
42+
type: PolynomialDecay
43+
learning_rate: 0.01
44+
end_lr: 0
45+
power: 0.9

paddleseg/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,5 @@
5151
from .espnet import ESPNetV2
5252
from .dmnet import DMNet
5353
from .espnetv1 import ESPNetV1
54+
from .fastfcn import FastFCN
5455
from .pfpnnet import PFPNNet

paddleseg/models/fastfcn.py

+240
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import paddle.nn as nn
17+
import paddle.nn.functional as F
18+
19+
from paddleseg.cvlibs import manager
20+
from paddleseg.models import layers
21+
from paddleseg.utils import utils
22+
23+
24+
@manager.MODELS.add_component
25+
class FastFCN(nn.Layer):
26+
"""
27+
The FastFCN implementation based on PaddlePaddle.
28+
29+
The original article refers to
30+
Huikai Wu, Junge Zhang, Kaiqi Huang. "FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic Segmentation".
31+
32+
Args:
33+
num_classes (int): The unique number of target classes.
34+
backbone (Paddle.nn.Layer): A backbone network.
35+
backbone_indices (tuple): The values in the tuple indicate the indices of
36+
output of backbone.
37+
num_codes (int): The number of encoded words. Default: 32.
38+
mid_channels (int): The channels of middle layers. Default: 512.
39+
use_jpu (bool): Whether use jpu module. Default: True.
40+
aux_loss (bool): Whether use auxiliary head loss. Default: True.
41+
use_se_loss (int): Whether use semantic encoding loss. Default: True.
42+
add_lateral (int): Whether use lateral convolution layers. Default: False.
43+
pretrained (str, optional): The path or url of pretrained model. Default: None.
44+
"""
45+
def __init__(self,
46+
num_classes,
47+
backbone,
48+
num_codes=32,
49+
mid_channels=512,
50+
use_jpu=True,
51+
aux_loss=True,
52+
use_se_loss=True,
53+
add_lateral=False,
54+
pretrained=None):
55+
super().__init__()
56+
self.add_lateral = add_lateral
57+
self.num_codes = num_codes
58+
self.backbone = backbone
59+
self.use_jpu = use_jpu
60+
in_channels = self.backbone.feat_channels
61+
62+
if use_jpu:
63+
self.jpu_layer = layers.JPU(in_channels, mid_channels)
64+
in_channels[-1] = mid_channels * 4
65+
self.bottleneck = layers.ConvBNReLU(
66+
in_channels[-1],
67+
mid_channels,
68+
1,
69+
padding=0,
70+
bias_attr=False,
71+
)
72+
else:
73+
self.bottleneck = layers.ConvBNReLU(
74+
in_channels[-1],
75+
mid_channels,
76+
3,
77+
padding=1,
78+
bias_attr=False,
79+
)
80+
if self.add_lateral:
81+
self.lateral_convs = nn.LayerList([
82+
layers.ConvBNReLU(in_channels[0],
83+
mid_channels,
84+
1,
85+
bias_attr=False),
86+
layers.ConvBNReLU(in_channels[1],
87+
mid_channels,
88+
1,
89+
bias_attr=False),
90+
])
91+
92+
self.fusion = layers.ConvBNReLU(
93+
3 * mid_channels,
94+
mid_channels,
95+
3,
96+
padding=1,
97+
bias_attr=False,
98+
)
99+
100+
self.enc_module = EncModule(mid_channels, num_codes)
101+
self.cls_seg = nn.Conv2D(mid_channels, num_classes, 1)
102+
103+
self.aux_loss = aux_loss
104+
if self.aux_loss:
105+
self.fcn_head = layers.AuxLayer(in_channels[-2], mid_channels,
106+
num_classes)
107+
108+
self.use_se_loss = use_se_loss
109+
if use_se_loss:
110+
self.se_layer = nn.Linear(mid_channels, num_classes)
111+
112+
self.pretrained = pretrained
113+
self.init_weight()
114+
115+
def init_weight(self):
116+
if self.pretrained is not None:
117+
utils.load_entire_model(self, self.pretrained)
118+
119+
def forward(self, inputs):
120+
imsize = paddle.shape(inputs)[2:]
121+
feats = self.backbone(inputs)
122+
if self.use_jpu:
123+
feats = self.jpu_layer(*feats)
124+
125+
fcn_feat = feats[2]
126+
127+
feat = self.bottleneck(feats[-1])
128+
if self.add_lateral:
129+
laterals = []
130+
for i, lateral_conv in enumerate(self.lateral_convs):
131+
laterals.append(
132+
F.interpolate(lateral_conv(feats[i]),
133+
size=paddle.shape(feat)[2:],
134+
mode='bilinear',
135+
align_corners=False))
136+
feat = self.fusion(paddle.concat([feat, *laterals], 1))
137+
encode_feat, feat = self.enc_module(feat)
138+
out = self.cls_seg(feat)
139+
out = F.interpolate(out,
140+
size=imsize,
141+
mode='bilinear',
142+
align_corners=False)
143+
output = [out]
144+
145+
if self.training:
146+
fcn_out = self.fcn_head(fcn_feat)
147+
fcn_out = F.interpolate(fcn_out,
148+
size=imsize,
149+
mode='bilinear',
150+
align_corners=False)
151+
output.append(fcn_out)
152+
if self.use_se_loss:
153+
se_out = self.se_layer(encode_feat)
154+
output.append(se_out)
155+
return output
156+
return output
157+
158+
159+
class Encoding(nn.Layer):
160+
def __init__(self, channels, num_codes):
161+
super().__init__()
162+
self.channels, self.num_codes = channels, num_codes
163+
164+
std = 1 / ((channels * num_codes)**0.5)
165+
self.codewords = self.create_parameter(
166+
shape=(num_codes, channels),
167+
default_initializer=nn.initializer.Uniform(-std, std),
168+
)
169+
self.scale = self.create_parameter(
170+
shape=(num_codes, ),
171+
default_initializer=nn.initializer.Uniform(-1, 0),
172+
)
173+
174+
def scaled_l2(self, x, codewords, scale):
175+
num_codes, channels = paddle.shape(codewords)
176+
reshaped_scale = scale.reshape([1, 1, num_codes])
177+
expanded_x = paddle.tile(x.unsqueeze(2), [1, 1, num_codes, 1])
178+
reshaped_codewords = codewords.reshape([1, 1, num_codes, channels])
179+
180+
scaled_l2_norm = reshaped_scale * (
181+
expanded_x - reshaped_codewords).pow(2).sum(axis=3)
182+
return scaled_l2_norm
183+
184+
def aggregate(self, assignment_weights, x, codewords):
185+
num_codes, channels = paddle.shape(codewords)
186+
reshaped_codewords = codewords.reshape([1, 1, num_codes, channels])
187+
expanded_x = paddle.tile(
188+
x.unsqueeze(2),
189+
[1, 1, num_codes, 1],
190+
)
191+
encoded_feat = (assignment_weights.unsqueeze(3) *
192+
(expanded_x - reshaped_codewords)).sum(axis=1)
193+
return encoded_feat
194+
195+
def forward(self, x):
196+
x_dims = x.ndim
197+
assert x_dims == 4, "The dimension of input tensor must equal 4, but got {}.".format(
198+
x_dims)
199+
assert paddle.shape(
200+
x
201+
)[1] == self.channels, "Encoding channels error, excepted {} but got {}.".format(
202+
self.channels,
203+
paddle.shape(x)[1])
204+
batch_size = paddle.shape(x)[0]
205+
x = x.reshape([batch_size, self.channels, -1]).transpose([0, 2, 1])
206+
assignment_weights = F.softmax(self.scaled_l2(x, self.codewords,
207+
self.scale),
208+
axis=2)
209+
210+
encoded_feat = self.aggregate(assignment_weights, x, self.codewords)
211+
encoded_feat = encoded_feat.reshape([batch_size, self.num_codes, -1])
212+
return encoded_feat
213+
214+
215+
class EncModule(nn.Layer):
216+
def __init__(self, in_channels, num_codes):
217+
super().__init__()
218+
self.encoding_project = layers.ConvBNReLU(
219+
in_channels,
220+
in_channels,
221+
1,
222+
)
223+
self.encoding = nn.Sequential(
224+
Encoding(channels=in_channels, num_codes=num_codes),
225+
nn.BatchNorm1D(num_codes),
226+
nn.ReLU(),
227+
)
228+
self.fc = nn.Sequential(
229+
nn.Linear(in_channels, in_channels),
230+
nn.Sigmoid(),
231+
)
232+
233+
def forward(self, x):
234+
encoding_projection = self.encoding_project(x)
235+
encoding_feat = self.encoding(encoding_projection).mean(axis=1)
236+
batch_size, channels, _, _ = paddle.shape(x)
237+
gamma = self.fc(encoding_feat)
238+
y = gamma.reshape([batch_size, channels, 1, 1])
239+
output = F.relu(x + x * y)
240+
return encoding_feat, output

0 commit comments

Comments
 (0)