Skip to content

Commit 236403c

Browse files
author
Jianwei Yang
committed
tried stride=1 at layer4
1 parent d1dca5a commit 236403c

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

lib/model/faster_rcnn/resnet.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ def __init__(self, block, layers, num_classes=1000):
115115
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
116116
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
117117
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
118+
# it is slightly better whereas slower to set stride = 1
119+
# self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
118120
self.avgpool = nn.AvgPool2d(7)
119121
self.fc = nn.Linear(512 * block.expansion, num_classes)
120122

@@ -221,8 +223,8 @@ def __init__(self, classes, num_layers=101, pretrained=False, class_agnostic=Fal
221223
self.dout_base_model = 1024
222224
self.pretrained = pretrained
223225
self.class_agnostic = class_agnostic
224-
225-
_fasterRCNN.__init__(self, classes, class_agnostic)
226+
227+
_fasterRCNN.__init__(self, classes, class_agnostic)
226228

227229
def _init_modules(self):
228230
resnet = resnet101()
@@ -233,7 +235,7 @@ def _init_modules(self):
233235
resnet.load_state_dict({k:v for k,v in state_dict.items() if k in resnet.state_dict()})
234236

235237
# Build resnet.
236-
self.RCNN_base = nn.Sequential(resnet.conv1, resnet.bn1,resnet.relu,
238+
self.RCNN_base = nn.Sequential(resnet.conv1, resnet.bn1,resnet.relu,
237239
resnet.maxpool,resnet.layer1,resnet.layer2,resnet.layer3)
238240

239241
self.RCNN_top = nn.Sequential(resnet.layer4)
@@ -242,9 +244,9 @@ def _init_modules(self):
242244
if self.class_agnostic:
243245
self.RCNN_bbox_pred = nn.Linear(2048, 4)
244246
else:
245-
self.RCNN_bbox_pred = nn.Linear(2048, 4 * self.n_classes)
247+
self.RCNN_bbox_pred = nn.Linear(2048, 4 * self.n_classes)
246248

247-
# Fix blocks
249+
# Fix blocks
248250
for p in self.RCNN_base[0].parameters(): p.requires_grad=False
249251
for p in self.RCNN_base[1].parameters(): p.requires_grad=False
250252

@@ -277,10 +279,10 @@ def set_bn_eval(m):
277279
classname = m.__class__.__name__
278280
if classname.find('BatchNorm') != -1:
279281
m.eval()
280-
282+
281283
self.RCNN_base.apply(set_bn_eval)
282284
self.RCNN_top.apply(set_bn_eval)
283-
285+
284286
def _head_to_tail(self, pool5):
285287
fc7 = self.RCNN_top(pool5).mean(3).mean(2)
286288
return fc7

0 commit comments

Comments
 (0)