Skip to content

Commit 3fd33d5

Browse files
authored
[Feature] Add GloRe (PaddlePaddle#1951)
1 parent e667705 commit 3fd33d5

File tree

5 files changed

+261
-0
lines changed

5 files changed

+261
-0
lines changed

configs/glore/README.md

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Graph-Based Global Reasoning Networks
2+
3+
## Reference
4+
5+
> Chen, Yunpeng, Marcus Rohrbach, Zhicheng Yan, Yan Shuicheng, Jiashi Feng, and Yannis Kalantidis. "Graph-based global reasoning networks." In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition, pp. 433-442. 2019.
6+
7+
8+
## Performance
9+
10+
### Cityscapes
11+
12+
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
13+
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
14+
|GloRe|ResNet50_OS8|1024x512|80000|78.26%|78.61%|78.72%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/glore_resnet50_os8_cityscapes_1024x512_80k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/glore_resnet50_os8_cityscapes_1024x512_80k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=de754e39ac9de4d2e951915c2334d6ec) |
15+
16+
17+
### Pascal VOC 2012 + Aug
18+
19+
| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
20+
|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
21+
|GloRe|ResNet50_OS8|512x512|40000|80.16%|80.35%|80.40%|[model](https://bj.bcebos.com/paddleseg/dygraph/pascal_voc12/glore_resnet50_os8_voc12aug_512x512_40k/model.pdparams) \| [log](https://bj.bcebos.com/paddleseg/dygraph/pascal_voc12/glore_resnet50_os8_voc12aug_512x512_40k/train.log) \| [vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=e40c1dd8d4fcbf2dcda01242dec9d9b5) |
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
_base_: '../_base_/cityscapes.yml'
2+
3+
batch_size: 2
4+
iters: 80000
5+
6+
learning_rate:
7+
decay:
8+
end_lr: 1.0e-5
9+
10+
loss:
11+
types:
12+
- type: CrossEntropyLoss
13+
coef: [1, 0.4]
14+
15+
model:
16+
type: GloRe
17+
backbone:
18+
type: ResNet50_vd
19+
output_stride: 8
20+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
21+
enable_auxiliary_loss: True
22+
align_corners: False
23+
pretrained: null
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
_base_: '../_base_/pascal_voc12aug.yml'
2+
3+
4+
model:
5+
type: GloRe
6+
backbone:
7+
type: ResNet50_vd
8+
output_stride: 8
9+
pretrained: https://bj.bcebos.com/paddleseg/dygraph/resnet50_vd_ssld_v2.tar.gz
10+
enable_auxiliary_loss: True
11+
align_corners: False
12+
pretrained: null
13+
14+
loss:
15+
types:
16+
- type: CrossEntropyLoss
17+
coef: [1, 0.4]

paddleseg/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@
5656
from .bisenetv1 import BiseNetV1
5757
from .fastfcn import FastFCN
5858
from .pfpnnet import PFPNNet
59+
from .glore import GloRe

paddleseg/models/glore.py

+199
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import paddle
16+
import paddle.nn as nn
17+
import paddle.nn.functional as F
18+
19+
from paddleseg.cvlibs import manager
20+
from paddleseg.models import layers
21+
from paddleseg.utils import utils
22+
23+
24+
@manager.MODELS.add_component
25+
class GloRe(nn.Layer):
26+
"""
27+
The GloRe implementation based on PaddlePaddle.
28+
29+
The original article refers to:
30+
Chen, Yunpeng, et al. "Graph-Based Global Reasoning Networks"
31+
(https://arxiv.org/pdf/1811.12814.pdf)
32+
33+
Args:
34+
num_classes (int): The unique number of target classes.
35+
backbone (Paddle.nn.Layer): Backbone network, currently support Resnet50/101.
36+
backbone_indices (tuple, optional): Two values in the tuple indicate the indices of output of backbone.
37+
gru_channels (int, optional): The number of input channels in GloRe Unit. Default: 512.
38+
gru_num_state (int, optional): The number of states in GloRe Unit. Default: 128.
39+
gru_num_node (tuple, optional): The number of nodes in GloRe Unit. Default: Default: 128.
40+
enable_auxiliary_loss (bool, optional): A bool value indicates whether adding auxiliary loss. Default: True.
41+
align_corners (bool, optional): An argument of F.interpolate. It should be set to False when the feature size is even,
42+
e.g. 1024x512, otherwise it is True, e.g. 769x769. Default: False.
43+
pretrained (str, optional): The path or url of pretrained model. Default: None.
44+
"""
45+
46+
def __init__(self,
47+
num_classes,
48+
backbone,
49+
backbone_indices=(2, 3),
50+
gru_channels=512,
51+
gru_num_state=128,
52+
gru_num_node=64,
53+
enable_auxiliary_loss=True,
54+
align_corners=False,
55+
pretrained=None):
56+
super().__init__()
57+
58+
self.backbone = backbone
59+
backbone_channels = [
60+
backbone.feat_channels[i] for i in backbone_indices
61+
]
62+
63+
self.head = GloReHead(num_classes, backbone_indices, backbone_channels,
64+
gru_channels, gru_num_state, gru_num_node,
65+
enable_auxiliary_loss)
66+
self.align_corners = align_corners
67+
self.pretrained = pretrained
68+
self.init_weight()
69+
70+
def forward(self, x):
71+
feat_list = self.backbone(x)
72+
logit_list = self.head(feat_list)
73+
return [
74+
F.interpolate(
75+
logit,
76+
x.shape[2:],
77+
mode='bilinear',
78+
align_corners=self.align_corners) for logit in logit_list
79+
]
80+
81+
def init_weight(self):
82+
if self.pretrained is not None:
83+
utils.load_entire_model(self, self.pretrained)
84+
85+
86+
class GloReHead(nn.Layer):
87+
88+
def __init__(self,
89+
num_classes,
90+
backbone_indices,
91+
backbone_channels,
92+
gru_channels=512,
93+
gru_num_state=128,
94+
gru_num_node=64,
95+
enable_auxiliary_loss=True):
96+
super().__init__()
97+
98+
in_channels = backbone_channels[1]
99+
self.conv_bn_relu = layers.ConvBNReLU(
100+
in_channels, gru_channels, 1, bias_attr=False)
101+
self.gru_module = GruModule(
102+
num_input=gru_channels,
103+
num_state=gru_num_state,
104+
num_node=gru_num_node)
105+
106+
self.dropout = nn.Dropout(0.1)
107+
self.classifier = nn.Conv2D(512, num_classes, kernel_size=1)
108+
self.auxlayer = layers.AuxLayer(
109+
in_channels=backbone_channels[0],
110+
inter_channels=backbone_channels[0] // 4,
111+
out_channels=num_classes)
112+
113+
self.backbone_indices = backbone_indices
114+
self.enable_auxiliary_loss = enable_auxiliary_loss
115+
116+
def forward(self, feat_list):
117+
118+
logit_list = []
119+
x = feat_list[self.backbone_indices[1]]
120+
121+
feature = self.conv_bn_relu(x)
122+
gru_output = self.gru_module(feature)
123+
output = self.dropout(gru_output)
124+
logit = self.classifier(output)
125+
logit_list.append(logit)
126+
127+
if self.enable_auxiliary_loss:
128+
low_level_feat = feat_list[self.backbone_indices[0]]
129+
auxiliary_logit = self.auxlayer(low_level_feat)
130+
logit_list.append(auxiliary_logit)
131+
132+
return logit_list
133+
134+
135+
class GCN(nn.Layer):
136+
def __init__(self, num_state, num_node, bias=False):
137+
super(GCN, self).__init__()
138+
self.conv1 = nn.Conv1D(num_node, num_node, kernel_size=1)
139+
self.relu = nn.ReLU()
140+
self.conv2 = nn.Conv1D(
141+
num_state, num_state, kernel_size=1, bias_attr=bias)
142+
143+
def forward(self, x):
144+
h = self.conv1(paddle.transpose(x, perm=(0, 2, 1)))
145+
h = paddle.transpose(h, perm=(0, 2, 1))
146+
h = h + x
147+
h = self.relu(self.conv2(h))
148+
return h
149+
150+
151+
class GruModule(nn.Layer):
152+
def __init__(self,
153+
num_input=512,
154+
num_state=128,
155+
num_node=64,
156+
normalize=False):
157+
super(GruModule, self).__init__()
158+
self.normalize = normalize
159+
self.num_state = num_state
160+
self.num_node = num_node
161+
self.reduction_dim = nn.Conv2D(num_input, num_state, kernel_size=1)
162+
self.projection_mat = nn.Conv2D(num_input, num_node, kernel_size=1)
163+
self.gcn = GCN(num_state=self.num_state, num_node=self.num_node)
164+
self.extend_dim = nn.Conv2D(
165+
self.num_state, num_input, kernel_size=1, bias_attr=False)
166+
self.extend_bn = nn.SyncBatchNorm(num_input, epsilon=1e-4)
167+
168+
def forward(self, input):
169+
n, c, h, w = input.shape
170+
# B, C, H, W
171+
reduction_dim = self.reduction_dim(input)
172+
# B, N, H, W
173+
mat_B = self.projection_mat(input)
174+
# B, C, H*W
175+
reshaped_reduction = paddle.reshape(
176+
reduction_dim, shape=[n, self.num_state, h * w])
177+
# B, N, H*W
178+
reshaped_B = paddle.reshape(mat_B, shape=[n, self.num_node, h * w])
179+
# B, N, H*W
180+
reproject = reshaped_B
181+
# B, C, N
182+
node_state_V = paddle.matmul(
183+
reshaped_reduction, paddle.transpose(
184+
reshaped_B, perm=[0, 2, 1]))
185+
186+
if self.normalize:
187+
node_state_V = node_state_V * (1. / reshaped_reduction.shape[2])
188+
189+
# B, C, N
190+
gcn_out = self.gcn(node_state_V)
191+
# B, C, H*W
192+
Y = paddle.matmul(gcn_out, reproject)
193+
# B, C, H, W
194+
Y = paddle.reshape(Y, shape=[n, self.num_state, h, w])
195+
Y_extend = self.extend_dim(Y)
196+
Y_extend = self.extend_bn(Y_extend)
197+
198+
out = input + Y_extend
199+
return out

0 commit comments

Comments
 (0)