6
6
from .adaptive_avgmax_pool import *
7
7
from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
8
8
9
- _models = ['inception_resnet_v2' ]
9
+ _models = ['inception_resnet_v2' , 'ens_adv_inception_resnet_v2' ]
10
10
__all__ = ['InceptionResnetV2' ] + _models
11
11
12
12
default_cfgs = {
13
+ # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
13
14
'inception_resnet_v2' : {
14
- 'url' : 'http ://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4 .pth' ,
15
+ 'url' : 'https ://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6 .pth' ,
15
16
'num_classes' : 1001 , 'input_size' : (3 , 299 , 299 ), 'pool_size' : (8 , 8 ),
16
17
'crop_pct' : 0.8975 , 'interpolation' : 'bicubic' ,
17
18
'mean' : IMAGENET_INCEPTION_MEAN , 'std' : IMAGENET_INCEPTION_STD ,
18
- 'first_conv' : 'conv2d_1a.conv' , 'classifier' : 'last_linear' ,
19
+ 'first_conv' : 'conv2d_1a.conv' , 'classifier' : 'classif' ,
20
+ },
21
+ # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
22
+ 'ens_adv_inception_resnet_v2' : {
23
+ 'url' : 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth' ,
24
+ 'num_classes' : 1001 , 'input_size' : (3 , 299 , 299 ), 'pool_size' : (8 , 8 ),
25
+ 'crop_pct' : 0.8975 , 'interpolation' : 'bicubic' ,
26
+ 'mean' : IMAGENET_INCEPTION_MEAN , 'std' : IMAGENET_INCEPTION_STD ,
27
+ 'first_conv' : 'conv2d_1a.conv' , 'classifier' : 'classif' ,
19
28
}
20
29
}
21
30
@@ -274,19 +283,20 @@ def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'
274
283
)
275
284
self .block8 = Block8 (noReLU = True )
276
285
self .conv2d_7b = BasicConv2d (2080 , self .num_features , kernel_size = 1 , stride = 1 )
277
- self .last_linear = nn .Linear (self .num_features , num_classes )
286
+ # NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
287
+ self .classif = nn .Linear (self .num_features , num_classes )
278
288
279
289
def get_classifier (self ):
280
- return self .last_linear
290
+ return self .classif
281
291
282
292
def reset_classifier (self , num_classes , global_pool = 'avg' ):
283
293
self .global_pool = global_pool
284
294
self .num_classes = num_classes
285
- del self .last_linear
295
+ del self .classif
286
296
if num_classes :
287
- self .last_linear = torch .nn .Linear (self .num_features , num_classes )
297
+ self .classif = torch .nn .Linear (self .num_features , num_classes )
288
298
else :
289
- self .last_linear = None
299
+ self .classif = None
290
300
291
301
def forward_features (self , x , pool = True ):
292
302
x = self .conv2d_1a (x )
@@ -314,13 +324,13 @@ def forward(self, x):
314
324
x = self .forward_features (x , pool = True )
315
325
if self .drop_rate > 0 :
316
326
x = F .dropout (x , p = self .drop_rate , training = self .training )
317
- x = self .last_linear (x )
327
+ x = self .classif (x )
318
328
return x
319
329
320
330
321
331
def inception_resnet_v2 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
322
332
r"""InceptionResnetV2 model architecture from the
323
- `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
333
+ `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
324
334
"""
325
335
default_cfg = default_cfgs ['inception_resnet_v2' ]
326
336
model = InceptionResnetV2 (num_classes = num_classes , in_chans = in_chans , ** kwargs )
@@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
330
340
331
341
return model
332
342
343
+
344
+ def ens_adv_inception_resnet_v2 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
345
+ r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
346
+ As per https://arxiv.org/abs/1705.07204 and
347
+ https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
348
+ """
349
+ default_cfg = default_cfgs ['ens_adv_inception_resnet_v2' ]
350
+ model = InceptionResnetV2 (num_classes = num_classes , in_chans = in_chans , ** kwargs )
351
+ model .default_cfg = default_cfg
352
+ if pretrained :
353
+ load_pretrained (model , default_cfg , num_classes , in_chans )
354
+
355
+ return model
0 commit comments