
Commit 55d3b8c

Add pafpn (open-mmlab#2392)

* add PAFPN
* add doc
* rename cfg, inherit from fpn
* reformat doc string
* standard doc string
* fix doc of fpn
* rename lateral_dconv to downsample_convs

1 parent 3174d69 commit 55d3b8c

File tree

4 files changed: +165 -15 lines changed

@@ -0,0 +1,8 @@
_base_ = '../faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'

model = dict(
    neck=dict(
        type='PAFPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5))
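Since this config only overrides model.neck, everything else is inherited from the _base_ Faster R-CNN config when mmcv resolves it. A minimal sketch of that resolution; the on-disk filename below is an assumption for illustration, as the diff header does not show where the new config lives:

from mmcv import Config

# Assumed filename for illustration only; the diff does not name the file.
cfg = Config.fromfile('configs/pafpn/faster_rcnn_r50_pafpn_1x_coco.py')

# Only the neck is overridden; backbone, heads, and training settings
# all come from the base faster_rcnn_r50_fpn_1x_coco.py unchanged.
print(cfg.model.neck.type)  # PAFPN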

mmdet/models/necks/__init__.py (+2 -1)

@@ -3,5 +3,6 @@
 from .fpn_carafe import FPN_CARAFE
 from .hrfpn import HRFPN
 from .nas_fpn import NASFPN
+from .pafpn import PAFPN

-__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE']
+__all__ = ['FPN', 'BFP', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN']
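Together with the @NECKS.register_module decorator in pafpn.py, this export makes type='PAFPN' resolvable by config-driven construction. A minimal sketch of building the neck through mmdet's builder, assuming build_neck is exposed from mmdet.models as in this generation of the codebase:

from mmdet.models import build_neck

# Builds a PAFPN through the NECKS registry, exactly as config-driven
# construction does; the dict mirrors the new config file above.
neck = build_neck(
    dict(
        type='PAFPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        num_outs=5))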

mmdet/models/necks/fpn.py (+19 -14)

@@ -16,20 +16,25 @@ class FPN(nn.Module):
     Detection (https://arxiv.org/abs/1612.03144)

     Args:
-        in_channels (List[int]):
-            number of input channels per scale
-
-        out_channels (int):
-            number of output channels (used at each scale)
-
-        num_outs (int):
-            number of output scales
-
-        start_level (int):
-            index of the first input scale to use as an output scale
-
-        end_level (int, default=-1):
-            index of the last input scale to use as an output scale
+        in_channels (List[int]): Number of input channels per scale.
+        out_channels (int): Number of output channels (used at each scale).
+        num_outs (int): Number of output scales.
+        start_level (int): Index of the start input backbone level used to
+            build the feature pyramid. Default: 0.
+        end_level (int): Index of the end input backbone level (exclusive) to
+            build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool): Whether to add conv layers on top of the
+            original feature maps. Default: False.
+        extra_convs_on_inputs (bool): Whether to apply extra conv on
+            the original feature from the backbone. Default: True.
+        relu_before_extra_convs (bool): Whether to apply relu before the extra
+            conv. Default: False.
+        no_norm_on_lateral (bool): Whether to apply norm on lateral.
+            Default: False.
+        conv_cfg (dict): Config dict for convolution layer. Default: None.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer in ConvModule.
+            Default: None.

     Example:
         >>> import torch
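The newly documented arguments are easiest to see with a quick shape check. A minimal sketch, assuming a four-stage ResNet-style input pyramid; with add_extra_convs=True the levels beyond the four backbone stages come from stride-2 3x3 convs rather than max pooling, and extra_convs_on_inputs=True feeds the first of them the raw top backbone feature:

import torch
from mmdet.models.necks import FPN

in_channels = [256, 512, 1024, 2048]
feats = [
    torch.rand(1, c, 56 // 2**i, 56 // 2**i)
    for i, c in enumerate(in_channels)
]
fpn = FPN(
    in_channels,
    out_channels=256,
    num_outs=6,
    add_extra_convs=True,
    extra_convs_on_inputs=True,
    relu_before_extra_convs=True).eval()
outs = fpn(feats)
print([tuple(o.shape[2:]) for o in outs])
# [(56, 56), (28, 28), (14, 14), (7, 7), (4, 4), (2, 2)]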

mmdet/models/necks/pafpn.py (+136)

@@ -0,0 +1,136 @@
import torch.nn as nn
import torch.nn.functional as F

from mmdet.core import auto_fp16
from mmdet.ops import ConvModule
from ..registry import NECKS
from .fpn import FPN


@NECKS.register_module
class PAFPN(FPN):
    """Path Aggregation Network for Instance Segmentation.

    This is an implementation of the PAFPN in Path Aggregation Network
    (https://arxiv.org/abs/1803.01534).

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        num_outs (int): Number of output scales.
        start_level (int): Index of the start input backbone level used to
            build the feature pyramid. Default: 0.
        end_level (int): Index of the end input backbone level (exclusive) to
            build the feature pyramid. Default: -1, which means the last level.
        add_extra_convs (bool): Whether to add conv layers on top of the
            original feature maps. Default: False.
        extra_convs_on_inputs (bool): Whether to apply extra conv on
            the original feature from the backbone. Default: True.
        relu_before_extra_convs (bool): Whether to apply relu before the extra
            conv. Default: False.
        no_norm_on_lateral (bool): Whether to apply norm on lateral.
            Default: False.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 extra_convs_on_inputs=True,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None):
        super(PAFPN,
              self).__init__(in_channels, out_channels, num_outs, start_level,
                             end_level, add_extra_convs, extra_convs_on_inputs,
                             relu_before_extra_convs, no_norm_on_lateral,
                             conv_cfg, norm_cfg, act_cfg)
        # add extra bottom up pathway
        self.downsample_convs = nn.ModuleList()
        self.pafpn_convs = nn.ModuleList()
        for i in range(self.start_level + 1, self.backbone_end_level):
            d_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                stride=2,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            pafpn_conv = ConvModule(
                out_channels,
                out_channels,
                3,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                inplace=False)
            self.downsample_convs.append(d_conv)
            self.pafpn_convs.append(pafpn_conv)

    @auto_fp16()
    def forward(self, inputs):
        assert len(inputs) == len(self.in_channels)

        # build laterals
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # build top-down path
        used_backbone_levels = len(laterals)
        for i in range(used_backbone_levels - 1, 0, -1):
            prev_shape = laterals[i - 1].shape[2:]
            laterals[i - 1] += F.interpolate(
                laterals[i], size=prev_shape, mode='nearest')

        # build outputs
        # part 1: from original levels
        inter_outs = [
            self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
        ]

        # part 2: add bottom-up path
        for i in range(0, used_backbone_levels - 1):
            inter_outs[i + 1] += self.downsample_convs[i](inter_outs[i])

        outs = []
        outs.append(inter_outs[0])
        outs.extend([
            self.pafpn_convs[i - 1](inter_outs[i])
            for i in range(1, used_backbone_levels)
        ])

        # part 3: add extra levels
        if self.num_outs > len(outs):
            # use max pool to get more levels on top of outputs
            # (e.g., Faster R-CNN, Mask R-CNN)
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(F.max_pool2d(outs[-1], 1, stride=2))
            # add conv layers on top of original feature maps (RetinaNet)
            else:
                if self.extra_convs_on_inputs:
                    orig = inputs[self.backbone_end_level - 1]
                    outs.append(self.fpn_convs[used_backbone_levels](orig))
                else:
                    outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
                for i in range(used_backbone_levels + 1, self.num_outs):
                    if self.relu_before_extra_convs:
                        outs.append(self.fpn_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.fpn_convs[i](outs[-1]))
        return tuple(outs)
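A minimal forward-pass sketch of the new neck, using the ResNet-50 channel layout from the config above; the output sizes follow from the stride-2 downsample convs and, since add_extra_convs defaults to False here, a parameter-free max-pool extra level on top:

import torch
from mmdet.models.necks import PAFPN

# Simulated C2-C5 backbone features for a 224x224 image (strides 4-32).
in_channels = [256, 512, 1024, 2048]
scales = [56, 28, 14, 7]
inputs = [torch.rand(1, c, s, s) for c, s in zip(in_channels, scales)]

neck = PAFPN(in_channels, out_channels=256, num_outs=5).eval()
outputs = neck(inputs)
for i, out in enumerate(outputs):
    print(f'P{i + 2}: {tuple(out.shape)}')
# P2: (1, 256, 56, 56) ... P5: (1, 256, 7, 7), plus
# P6: (1, 256, 4, 4) from the max-pool extra level.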
