Skip to content

Commit 30cad5f

Browse files
committed
Update Oriented form DoubleHead
1 parent 4fd26a7 commit 30cad5f

File tree

7 files changed

+269
-0
lines changed

7 files changed

+269
-0
lines changed

configs/obb/atss_obb/README.md

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection
2+
3+
4+
## Introduction
5+
6+
Oriented form of ATSS model.
7+
8+
```
9+
@article{zhang2019bridging,
10+
title = {Bridging the Gap Between Anchor-based and Anchor-free Detection via Adaptive Training Sample Selection},
11+
author = {Zhang, Shifeng and Chi, Cheng and Yao, Yongqiang and Lei, Zhen and Li, Stan Z.},
12+
journal = {arXiv preprint arXiv:1912.02424},
13+
year = {2019}
14+
}
15+
```
16+
17+
to be continue!!
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Rethinking Classification and Localization for Object Detection
2+
3+
## Introduction
4+
5+
Oriented form of double head model
6+
7+
```
8+
@article{wu2019rethinking,
9+
title={Rethinking Classification and Localization for Object Detection},
10+
author={Yue Wu and Yinpeng Chen and Lu Yuan and Zicheng Liu and Lijuan Wang and Hongzhi Li and Yun Fu},
11+
year={2019},
12+
eprint={1904.06493},
13+
archivePrefix={arXiv},
14+
primaryClass={cs.CV}
15+
}
16+
```
17+
18+
## Results and models
19+
to be continue !
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
_base_ = '../faster_rcnn_obb/faster_rcnn_obb_r50_fpn_1x_dota10.py'
2+
model = dict(
3+
roi_head=dict(
4+
type='OBBDoubleHeadRoIHead',
5+
reg_roi_scale_factor=1.2,
6+
bbox_head=dict(
7+
_delete_=True,
8+
type='OBBDoubleConvFCBBoxHead',
9+
start_bbox_type='hbb',
10+
end_bbox_type='obb',
11+
num_convs=4,
12+
num_fcs=2,
13+
in_channels=256,
14+
conv_out_channels=1024,
15+
fc_out_channels=1024,
16+
roi_feat_size=7,
17+
num_classes=15,
18+
bbox_coder=dict(
19+
type='HBB2OBBDeltaXYWHTCoder',
20+
target_means=[0., 0., 0., 0., 0.],
21+
target_stds=[0.1, 0.1, 0.2, 0.2, 1]),
22+
reg_class_agnostic=False,
23+
loss_cls=dict(
24+
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=2.0),
25+
loss_bbox=dict(type='SmoothL1Loss', beta=1.0, loss_weight=2.0))))

mmdet/models/roi_heads/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .obb.roitrans_roi_head import RoITransRoIHead
1919
from .obb.obb_standard_roi_head import OBBStandardRoIHead
2020
from .obb.gv_ratio_roi_head import GVRatioRoIHead
21+
from .obb.obb_double_roi_head import OBBDoubleHeadRoIHead
2122

2223
__all__ = [
2324
'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead',

mmdet/models/roi_heads/bbox_heads/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .obb.obbox_head import OBBoxHead
77
from .obb.obb_convfc_bbox_head import (OBBConvFCBBoxHead, OBBShared2FCBBoxHead,
88
OBBShared4Conv1FCBBoxHead)
9+
from .obb.obb_double_bbox_head import OBBDoubleConvFCBBoxHead
910
from .obb.gv_bbox_head import GVBBoxHead
1011

1112
__all__ = [
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import torch.nn as nn
2+
from mmcv.cnn import ConvModule, normal_init, xavier_init
3+
4+
from mmdet.models.backbones.resnet import Bottleneck
5+
from mmdet.models.builder import HEADS
6+
from .obbox_head import OBBoxHead
7+
8+
9+
class BasicResBlock(nn.Module):
10+
"""Basic residual block.
11+
12+
This block is a little different from the block in the ResNet backbone.
13+
The kernel size of conv1 is 1 in this block while 3 in ResNet BasicBlock.
14+
15+
Args:
16+
in_channels (int): Channels of the input feature map.
17+
out_channels (int): Channels of the output feature map.
18+
conv_cfg (dict): The config dict for convolution layers.
19+
norm_cfg (dict): The config dict for normalization layers.
20+
"""
21+
22+
def __init__(self,
23+
in_channels,
24+
out_channels,
25+
conv_cfg=None,
26+
norm_cfg=dict(type='BN')):
27+
super(BasicResBlock, self).__init__()
28+
29+
# main path
30+
self.conv1 = ConvModule(
31+
in_channels,
32+
in_channels,
33+
kernel_size=3,
34+
padding=1,
35+
bias=False,
36+
conv_cfg=conv_cfg,
37+
norm_cfg=norm_cfg)
38+
self.conv2 = ConvModule(
39+
in_channels,
40+
out_channels,
41+
kernel_size=1,
42+
bias=False,
43+
conv_cfg=conv_cfg,
44+
norm_cfg=norm_cfg,
45+
act_cfg=None)
46+
47+
# identity path
48+
self.conv_identity = ConvModule(
49+
in_channels,
50+
out_channels,
51+
kernel_size=1,
52+
conv_cfg=conv_cfg,
53+
norm_cfg=norm_cfg,
54+
act_cfg=None)
55+
56+
self.relu = nn.ReLU(inplace=True)
57+
58+
def forward(self, x):
59+
identity = x
60+
61+
x = self.conv1(x)
62+
x = self.conv2(x)
63+
64+
identity = self.conv_identity(identity)
65+
out = x + identity
66+
67+
out = self.relu(out)
68+
return out
69+
70+
71+
@HEADS.register_module()
72+
class OBBDoubleConvFCBBoxHead(OBBoxHead):
73+
r"""Bbox head used in Double-Head R-CNN
74+
75+
.. code-block:: none
76+
77+
/-> cls
78+
/-> shared convs ->
79+
\-> reg
80+
roi features
81+
/-> cls
82+
\-> shared fc ->
83+
\-> reg
84+
""" # noqa: W605
85+
86+
def __init__(self,
87+
num_convs=0,
88+
num_fcs=0,
89+
conv_out_channels=1024,
90+
fc_out_channels=1024,
91+
conv_cfg=None,
92+
norm_cfg=dict(type='BN'),
93+
**kwargs):
94+
kwargs.setdefault('with_avg_pool', True)
95+
super(OBBDoubleConvFCBBoxHead, self).__init__(**kwargs)
96+
assert self.with_avg_pool
97+
assert num_convs > 0
98+
assert num_fcs > 0
99+
self.num_convs = num_convs
100+
self.num_fcs = num_fcs
101+
self.conv_out_channels = conv_out_channels
102+
self.fc_out_channels = fc_out_channels
103+
self.conv_cfg = conv_cfg
104+
self.norm_cfg = norm_cfg
105+
106+
# increase the channel of input features
107+
self.res_block = BasicResBlock(self.in_channels,
108+
self.conv_out_channels)
109+
110+
# add conv heads
111+
self.conv_branch = self._add_conv_branch()
112+
# add fc heads
113+
self.fc_branch = self._add_fc_branch()
114+
115+
out_dim_reg = self.reg_dim if self.reg_class_agnostic else \
116+
self.reg_dim * self.num_classes
117+
self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg)
118+
119+
self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes + 1)
120+
self.relu = nn.ReLU(inplace=True)
121+
122+
def _add_conv_branch(self):
123+
"""Add the fc branch which consists of a sequential of conv layers"""
124+
branch_convs = nn.ModuleList()
125+
for i in range(self.num_convs):
126+
branch_convs.append(
127+
Bottleneck(
128+
inplanes=self.conv_out_channels,
129+
planes=self.conv_out_channels // 4,
130+
conv_cfg=self.conv_cfg,
131+
norm_cfg=self.norm_cfg))
132+
return branch_convs
133+
134+
def _add_fc_branch(self):
135+
"""Add the fc branch which consists of a sequential of fc layers"""
136+
branch_fcs = nn.ModuleList()
137+
for i in range(self.num_fcs):
138+
fc_in_channels = (
139+
self.in_channels *
140+
self.roi_feat_area if i == 0 else self.fc_out_channels)
141+
branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
142+
return branch_fcs
143+
144+
def init_weights(self):
145+
# conv layers are already initialized by ConvModule
146+
normal_init(self.fc_cls, std=0.01)
147+
normal_init(self.fc_reg, std=0.001)
148+
149+
for m in self.fc_branch.modules():
150+
if isinstance(m, nn.Linear):
151+
xavier_init(m, distribution='uniform')
152+
153+
def forward(self, x_cls, x_reg):
154+
# conv head
155+
x_conv = self.res_block(x_reg)
156+
157+
for conv in self.conv_branch:
158+
x_conv = conv(x_conv)
159+
160+
if self.with_avg_pool:
161+
x_conv = self.avg_pool(x_conv)
162+
163+
x_conv = x_conv.view(x_conv.size(0), -1)
164+
bbox_pred = self.fc_reg(x_conv)
165+
166+
# fc head
167+
x_fc = x_cls.view(x_cls.size(0), -1)
168+
for fc in self.fc_branch:
169+
x_fc = self.relu(fc(x_fc))
170+
171+
cls_score = self.fc_cls(x_fc)
172+
173+
return cls_score, bbox_pred
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from mmdet.models.builder import HEADS
2+
from .obb_standard_roi_head import OBBStandardRoIHead
3+
4+
5+
@HEADS.register_module()
6+
class OBBDoubleHeadRoIHead(OBBStandardRoIHead):
7+
"""RoI head for Double Head RCNN
8+
9+
https://arxiv.org/abs/1904.06493
10+
"""
11+
12+
def __init__(self, reg_roi_scale_factor, **kwargs):
13+
super(OBBDoubleHeadRoIHead, self).__init__(**kwargs)
14+
self.reg_roi_scale_factor = reg_roi_scale_factor
15+
16+
def _bbox_forward(self, x, rois):
17+
"""Box head forward function used in both training and testing time"""
18+
bbox_cls_feats = self.bbox_roi_extractor(
19+
x[:self.bbox_roi_extractor.num_inputs], rois)
20+
bbox_reg_feats = self.bbox_roi_extractor(
21+
x[:self.bbox_roi_extractor.num_inputs],
22+
rois,
23+
roi_scale_factor=self.reg_roi_scale_factor)
24+
if self.with_shared_head:
25+
bbox_cls_feats = self.shared_head(bbox_cls_feats)
26+
bbox_reg_feats = self.shared_head(bbox_reg_feats)
27+
cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
28+
29+
bbox_results = dict(
30+
cls_score=cls_score,
31+
bbox_pred=bbox_pred,
32+
bbox_feats=bbox_cls_feats)
33+
return bbox_results

0 commit comments

Comments
 (0)