@@ -115,6 +115,8 @@ def __init__(self, block, layers, num_classes=1000):
115
115
self .layer2 = self ._make_layer (block , 128 , layers [1 ], stride = 2 )
116
116
self .layer3 = self ._make_layer (block , 256 , layers [2 ], stride = 2 )
117
117
self .layer4 = self ._make_layer (block , 512 , layers [3 ], stride = 2 )
118
+ # it is slightly better whereas slower to set stride = 1
119
+ # self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
118
120
self .avgpool = nn .AvgPool2d (7 )
119
121
self .fc = nn .Linear (512 * block .expansion , num_classes )
120
122
@@ -221,8 +223,8 @@ def __init__(self, classes, num_layers=101, pretrained=False, class_agnostic=Fal
221
223
self .dout_base_model = 1024
222
224
self .pretrained = pretrained
223
225
self .class_agnostic = class_agnostic
224
-
225
- _fasterRCNN .__init__ (self , classes , class_agnostic )
226
+
227
+ _fasterRCNN .__init__ (self , classes , class_agnostic )
226
228
227
229
def _init_modules (self ):
228
230
resnet = resnet101 ()
@@ -233,7 +235,7 @@ def _init_modules(self):
233
235
resnet .load_state_dict ({k :v for k ,v in state_dict .items () if k in resnet .state_dict ()})
234
236
235
237
# Build resnet.
236
- self .RCNN_base = nn .Sequential (resnet .conv1 , resnet .bn1 ,resnet .relu ,
238
+ self .RCNN_base = nn .Sequential (resnet .conv1 , resnet .bn1 ,resnet .relu ,
237
239
resnet .maxpool ,resnet .layer1 ,resnet .layer2 ,resnet .layer3 )
238
240
239
241
self .RCNN_top = nn .Sequential (resnet .layer4 )
@@ -242,9 +244,9 @@ def _init_modules(self):
242
244
if self .class_agnostic :
243
245
self .RCNN_bbox_pred = nn .Linear (2048 , 4 )
244
246
else :
245
- self .RCNN_bbox_pred = nn .Linear (2048 , 4 * self .n_classes )
247
+ self .RCNN_bbox_pred = nn .Linear (2048 , 4 * self .n_classes )
246
248
247
- # Fix blocks
249
+ # Fix blocks
248
250
for p in self .RCNN_base [0 ].parameters (): p .requires_grad = False
249
251
for p in self .RCNN_base [1 ].parameters (): p .requires_grad = False
250
252
@@ -277,10 +279,10 @@ def set_bn_eval(m):
277
279
classname = m .__class__ .__name__
278
280
if classname .find ('BatchNorm' ) != - 1 :
279
281
m .eval ()
280
-
282
+
281
283
self .RCNN_base .apply (set_bn_eval )
282
284
self .RCNN_top .apply (set_bn_eval )
283
-
285
+
284
286
def _head_to_tail (self , pool5 ):
285
287
fc7 = self .RCNN_top (pool5 ).mean (3 ).mean (2 )
286
288
return fc7
0 commit comments