Skip to content

Commit e2b89e4

Browse files
authored
[Feature] Add Upernet (PaddlePaddle#2175)
1 parent 3139536 commit e2b89e4

8 files changed

+307
-1
lines changed

configs/upernet/README.md

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Unified Perceptual Parsing for Scene Understanding
2+
3+
4+
## Reference
5+
6+
> Tete Xiao, Yingcheng Liu, Bolei Zhou, Yuning Jiang, Jian Sun. "Unified Perceptual Parsing for Scene Understanding." Proceedings of the European Conference on Computer Vision (ECCV). 2018.
7+
8+
## Performance
9+
10+
### Cityscapes
11+
12+
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
13+
|-|-|-|-|-|-|-|-|
14+
|UPerNet|ResNet101_OS8|512x1024|40000|79.58%|80.11%|80.41%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/upernet_resnet101_os8_cityscapes_512x1024_40k/model.pdparams)\|[log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/upernet_resnet101_os8_cityscapes_512x1024_40k/train.log)\|[vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=c635ae2e70e148796cd58fae5273c3d6)|
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
_base_: '../_base_/cityscapes.yml'
2+
3+
batch_size: 2
4+
iters: 40000
5+
6+
model:
7+
type: UPerNet
8+
backbone:
9+
type: ResNet101_vd
10+
output_stride: 8
11+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz
12+
backbone_indices: [0, 1, 2, 3]
13+
channels: 512
14+
dropout_prob: 0.1
15+
enable_auxiliary_loss: True
16+
17+
optimizer:
18+
type: sgd
19+
weight_decay: 0.0005
20+
21+
loss:
22+
types:
23+
- type: CrossEntropyLoss
24+
types:
25+
- type: CrossEntropyLoss
26+
coef: [1, 0.4]
27+
28+
29+
lr_scheduler:
30+
type: PolynomialDecay
31+
learning_rate: 0.01
32+
end_lr: 0.0
33+
power: 0.9

paddleseg/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,4 @@
6060
from .ddrnet import DDRNet_23
6161
from .ccnet import CCNet
6262
from .mobileseg import MobileSeg
63+
from .upernet import UPerNet

paddleseg/models/upernet.py

+173
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import paddle.nn as nn
17+
import paddle.nn.functional as F
18+
19+
from paddleseg import utils
20+
from paddleseg.cvlibs import manager
21+
from paddleseg.models import layers
22+
23+
24+
@manager.MODELS.add_component
class UPerNet(nn.Layer):
    """
    UPerNet segmentation network built on PaddlePaddle.

    Paper:
        Tete Xiao, et al. "Unified Perceptual Parsing for Scene Understanding"
        (https://arxiv.org/abs/1807.10221).

    Args:
        num_classes (int): The unique number of target classes.
        backbone (paddle.nn.Layer): Backbone network, currently support Resnet50/101.
            Its ``feat_channels`` attribute is read to size the head.
        backbone_indices (tuple): Four values in the tuple indicate which backbone
            outputs are handed to the head.
        channels (int, optional): The channels of inter layers. Default: 512.
        enable_auxiliary_loss (bool, optional): Whether to add the auxiliary loss
            branch. Default: False.
        align_corners (bool, optional): An argument of F.interpolate, applied when the
            logits are resized to the input size. It should be set to False when the
            feature size is even, e.g. 1024x512, otherwise it is True, e.g. 769x769.
            Default: False.
        dropout_prob (float, optional): Dropout ratio for upernet head. Default: 0.1.
        pretrained (str, optional): The path or url of pretrained model. Default: None.
    """

    def __init__(self,
                 num_classes,
                 backbone,
                 backbone_indices,
                 channels=512,
                 enable_auxiliary_loss=False,
                 align_corners=False,
                 dropout_prob=0.1,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.backbone_indices = backbone_indices
        self.align_corners = align_corners
        self.pretrained = pretrained
        self.enable_auxiliary_loss = enable_auxiliary_loss

        # Channel count of every backbone output that is forwarded to the head.
        self.in_channels = [
            backbone.feat_channels[idx] for idx in backbone_indices
        ]

        # NOTE(review): align_corners is not forwarded here, so the head keeps
        # its own default for its internal resizes -- confirm this is intended.
        self.head = UPerNetHead(
            num_classes=num_classes,
            fpn_inplanes=list(self.in_channels),
            channels=channels,
            dropout_prob=dropout_prob,
            enable_auxiliary_loss=enable_auxiliary_loss)
        self.init_weight()

    def forward(self, x):
        """Return the list of head logits, each resized to the spatial size of ``x``."""
        backbone_outs = self.backbone(x)
        selected = [backbone_outs[idx] for idx in self.backbone_indices]
        head_logits = self.head(selected)

        # Every logit map is brought back to the height/width of the input.
        out_size = paddle.shape(x)[2:]
        return [
            F.interpolate(
                logit,
                out_size,
                mode='bilinear',
                align_corners=self.align_corners) for logit in head_logits
        ]

    def init_weight(self):
        """Load the whole-model weights from ``self.pretrained`` when one is given."""
        if self.pretrained is None:
            return
        utils.load_entire_model(self, self.pretrained)
91+
92+
93+
class UPerNetHead(nn.Layer):
    """
    Decoder head of UPerNet.

    The last input feature map goes through a pyramid pooling module, the
    remaining ones through 1x1 lateral convolutions. The levels are merged
    top-down, resized to the size of the first level, concatenated and
    classified.

    Args:
        num_classes (int): Number of output channels of the classifiers.
        fpn_inplanes (list[int]): Channel counts of the feature maps passed to
            ``forward``, in the same order. The last entry sizes the pyramid
            pooling module; entry 2 sizes the auxiliary head when it is enabled.
        channels (int): Channel count used by every pyramid level.
        dropout_prob (float, optional): Dropout ratio handed to the auxiliary
            head. Default: 0.1.
        enable_auxiliary_loss (bool, optional): Whether to build the auxiliary
            head and append its output in ``forward``. Default: False.
        align_corners (bool, optional): ``align_corners`` used by the bilinear
            resizes of the head, including the pyramid pooling module.
            Default: True.
    """

    def __init__(self,
                 num_classes,
                 fpn_inplanes,
                 channels,
                 dropout_prob=0.1,
                 enable_auxiliary_loss=False,
                 align_corners=True):
        super(UPerNetHead, self).__init__()
        self.align_corners = align_corners
        # Follow the constructor argument instead of a hard-coded True so the
        # whole head resizes consistently; the default keeps the old behaviour.
        self.ppm = layers.PPModule(
            in_channels=fpn_inplanes[-1],
            out_channels=channels,
            bin_sizes=(1, 2, 3, 6),
            dim_reduction=True,
            align_corners=align_corners)
        self.enable_auxiliary_loss = enable_auxiliary_loss
        self.lateral_convs = nn.LayerList()
        self.fpn_convs = nn.LayerList()

        # One lateral (1x1) and one output (3x3) conv per level, except the
        # last level, which is produced by the pyramid pooling module.
        for fpn_inplane in fpn_inplanes[:-1]:
            self.lateral_convs.append(
                layers.ConvBNReLU(fpn_inplane, channels, 1))
            self.fpn_convs.append(
                layers.ConvBNReLU(
                    channels, channels, 3, bias_attr=False))

        if self.enable_auxiliary_loss:
            # The auxiliary head is fixed to input index 2, matching forward().
            self.aux_head = layers.AuxLayer(
                fpn_inplanes[2],
                fpn_inplanes[2],
                num_classes,
                dropout_prob=dropout_prob)

        self.fpn_bottleneck = layers.ConvBNReLU(
            len(fpn_inplanes) * channels, channels, 3, padding=1)

        # NOTE(review): conv_last is never called in forward(); it is kept so
        # the set of parameters and state-dict keys stays unchanged.
        self.conv_last = nn.Sequential(
            layers.ConvBNReLU(
                len(fpn_inplanes) * channels, channels, 3, bias_attr=False),
            nn.Conv2D(
                channels, num_classes, kernel_size=1))
        self.conv_seg = nn.Conv2D(channels, num_classes, kernel_size=1)

    def forward(self, inputs):
        """
        Compute the segmentation logits.

        Args:
            inputs (list): Feature maps ordered like ``fpn_inplanes``.

        Returns:
            list: ``[main_logit]``, followed by the auxiliary logit when the
                auxiliary loss is enabled.
        """
        laterals = []
        for i, lateral_conv in enumerate(self.lateral_convs):
            laterals.append(lateral_conv(inputs[i]))
        laterals.append(self.ppm(inputs[-1]))

        # Top-down pathway: add each level, resized, onto the level before it.
        fpn_levels = len(laterals)
        for i in range(fpn_levels - 1, 0, -1):
            prev_shape = paddle.shape(laterals[i - 1])
            laterals[i - 1] = laterals[i - 1] + F.interpolate(
                laterals[i],
                size=prev_shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)

        fpn_outs = []
        for i in range(fpn_levels - 1):
            fpn_outs.append(self.fpn_convs[i](laterals[i]))
        fpn_outs.append(laterals[-1])

        # Level 0 is never resized, so its spatial size is looked up once.
        target_size = paddle.shape(fpn_outs[0])[2:]
        for i in range(fpn_levels - 1, 0, -1):
            fpn_outs[i] = F.interpolate(
                fpn_outs[i],
                size=target_size,
                mode='bilinear',
                align_corners=self.align_corners)
        fuse_out = paddle.concat(fpn_outs, axis=1)
        x = self.fpn_bottleneck(fuse_out)

        logits_list = [self.conv_seg(x)]
        if self.enable_auxiliary_loss:
            logits_list.append(self.aux_head(inputs[2]))
        return logits_list

test_tipc/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
| CCNet | CCNet | 支持 | - | - | - |
4242
| PP-LiteSeg | PP-LiteSeg(STDC-1) | 支持 | - | - | - |
4343
| PP-LiteSeg | PP-LiteSeg(STDC-2) | 支持 | - | - | - |
44+
| UPerNet | UPerNet | 支持 | - | - | - |
4445

4546
## 3. 测试工具简介
4647
### 目录介绍
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
===========================train_params===========================
2+
model_name:upernet
3+
python:python3
4+
gpu_list:0|0,1
5+
Global.use_gpu:null|null
6+
--precision:null
7+
--iters:lite_train_lite_infer=20|lite_train_whole_infer=20|whole_train_whole_infer=1000
8+
--save_dir:
9+
--batch_size:lite_train_lite_infer=2|lite_train_whole_infer=2|whole_train_whole_infer=3
10+
--model_path:null
11+
train_model_name:best_model/model.pdparams
12+
train_infer_img_dir:test_tipc/data/cityscapes/cityscapes_val_5.list
13+
null:null
14+
##
15+
trainer:norm
16+
norm_train:train.py --config test_tipc/configs/upernet/upernet_resnet101_os8_cityscapes_512x1024_40k.yml --save_interval 500 --seed 100 --num_workers 8
17+
pact_train:null
18+
fpgm_train:null
19+
distill_train:null
20+
null:null
21+
null:null
22+
##
23+
===========================eval_params===========================
24+
eval:val.py --config test_tipc/configs/upernet/upernet_resnet101_os8_cityscapes_512x1024_40k.yml --num_workers 8
25+
null:null
26+
##
27+
===========================export_params===========================
28+
--save_dir:
29+
--model_path:
30+
norm_export:export.py --config test_tipc/configs/upernet/upernet_resnet101_os8_cityscapes_512x1024_40k.yml
31+
quant_export:null
32+
fpgm_export:null
33+
distill_export:null
34+
export1:null
35+
export2:null
36+
===========================infer_params===========================
37+
infer_model:./test_tipc/output/upernet/model.pdparams
38+
infer_export:export.py --config test_tipc/configs/upernet/upernet_resnet101_os8_cityscapes_512x1024_40k.yml
39+
infer_quant:False
40+
inference:deploy/python/infer.py
41+
--device:cpu|gpu
42+
--enable_mkldnn:True|False
43+
--cpu_threads:6
44+
--batch_size:1
45+
--use_trt:False
46+
--precision:fp32
47+
--config:
48+
--image_path:./test_tipc/data/cityscapes/cityscapes_val_5.list
49+
--save_log_path:null
50+
--benchmark:True
51+
--save_dir:
52+
--model_name:upernet
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
_base_: '../_base_/cityscapes.yml'
2+
3+
batch_size: 2
4+
iters: 40000
5+
6+
model:
7+
type: UPerNet
8+
backbone:
9+
type: ResNet101_vd
10+
output_stride: 8
11+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet101_vd_ssld.tar.gz
12+
backbone_indices: [0, 1, 2, 3]
13+
channels: 512
14+
dropout_prob: 0.1
15+
enable_auxiliary_loss: True
16+
17+
optimizer:
18+
type: sgd
19+
weight_decay: 0.0005
20+
21+
loss:
22+
types:
23+
- type: CrossEntropyLoss
24+
types:
25+
- type: CrossEntropyLoss
26+
coef: [1, 0.4]
27+
28+
lr_scheduler:
29+
type: PolynomialDecay
30+
learning_rate: 0.01
31+
end_lr: 0.0
32+
power: 0.9

test_tipc/prepare.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ if [ ${MODE} = "cpp_infer" ];then
205205
bash build.sh
206206
else
207207
models=("enet" "bisenetv2" "ocrnet_hrnetw18" "ocrnet_hrnetw48" "deeplabv3p_resnet50_cityscapes" \
208-
"fastscnn" "fcn_hrnetw18" "pp_liteseg_stdc1" "pp_liteseg_stdc2" "ccnet")
208+
"fastscnn" "fcn_hrnetw18" "pp_liteseg_stdc1" "pp_liteseg_stdc2" "ccnet" "upernet")
209209
if [ $(contains "${models[@]}" "${model_name}") == "y" ]; then
210210
cp ./test_tipc/data/cityscapes_val_5.list ./test_tipc/data/cityscapes
211211
fi

0 commit comments

Comments
 (0)