Skip to content

Commit 87b92c5

Browse files
committed
Some pretrianed URL changes
* host some of Cadene's weights on github instead of .fr for speed * add my old port of ensemble adversarial inception resnet v2 * switch to my TF port of normal inception res v2 and change FC layer back to 'classif' for compat with ens_adv
1 parent 827a3d6 commit 87b92c5

File tree

5 files changed

+50
-27
lines changed

5 files changed

+50
-27
lines changed

timm/models/inception_resnet_v2.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,25 @@
66
from .adaptive_avgmax_pool import *
77
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
88

9-
_models = ['inception_resnet_v2']
9+
_models = ['inception_resnet_v2', 'ens_adv_inception_resnet_v2']
1010
__all__ = ['InceptionResnetV2'] + _models
1111

1212
default_cfgs = {
13+
# ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
1314
'inception_resnet_v2': {
14-
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
15+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
1516
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
1617
'crop_pct': 0.8975, 'interpolation': 'bicubic',
1718
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
18-
'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
19+
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
20+
},
21+
# ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
22+
'ens_adv_inception_resnet_v2': {
23+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
24+
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
25+
'crop_pct': 0.8975, 'interpolation': 'bicubic',
26+
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
27+
'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
1928
}
2029
}
2130

@@ -274,19 +283,20 @@ def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'
274283
)
275284
self.block8 = Block8(noReLU=True)
276285
self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
277-
self.last_linear = nn.Linear(self.num_features, num_classes)
286+
# NOTE some variants/checkpoints for this model may have 'last_linear' as the name for the FC
287+
self.classif = nn.Linear(self.num_features, num_classes)
278288

279289
def get_classifier(self):
280-
return self.last_linear
290+
return self.classif
281291

282292
def reset_classifier(self, num_classes, global_pool='avg'):
283293
self.global_pool = global_pool
284294
self.num_classes = num_classes
285-
del self.last_linear
295+
del self.classif
286296
if num_classes:
287-
self.last_linear = torch.nn.Linear(self.num_features, num_classes)
297+
self.classif = torch.nn.Linear(self.num_features, num_classes)
288298
else:
289-
self.last_linear = None
299+
self.classif = None
290300

291301
def forward_features(self, x, pool=True):
292302
x = self.conv2d_1a(x)
@@ -314,13 +324,13 @@ def forward(self, x):
314324
x = self.forward_features(x, pool=True)
315325
if self.drop_rate > 0:
316326
x = F.dropout(x, p=self.drop_rate, training=self.training)
317-
x = self.last_linear(x)
327+
x = self.classif(x)
318328
return x
319329

320330

321331
def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
322332
r"""InceptionResnetV2 model architecture from the
323-
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
333+
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>` paper.
324334
"""
325335
default_cfg = default_cfgs['inception_resnet_v2']
326336
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
@@ -330,3 +340,16 @@ def inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs
330340

331341
return model
332342

343+
344+
def ens_adv_inception_resnet_v2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
345+
r""" Ensemble Adversarially trained InceptionResnetV2 model architecture
346+
As per https://arxiv.org/abs/1705.07204 and
347+
https://github.com/tensorflow/models/tree/master/research/adv_imagenet_models.
348+
"""
349+
default_cfg = default_cfgs['ens_adv_inception_resnet_v2']
350+
model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
351+
model.default_cfg = default_cfg
352+
if pretrained:
353+
load_pretrained(model, default_cfg, num_classes, in_chans)
354+
355+
return model

timm/models/inception_v4.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
default_cfgs = {
1313
'inception_v4': {
14-
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
14+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
1515
'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
1616
'crop_pct': 0.875, 'interpolation': 'bicubic',
1717
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,

timm/models/pnasnet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
default_cfgs = {
2222
'pnasnet5large': {
23-
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
23+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/pnasnet5large-bf079911.pth',
2424
'input_size': (3, 331, 331),
2525
'pool_size': (11, 11),
2626
'crop_pct': 0.875,

timm/models/senet.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -37,20 +37,20 @@ def _cfg(url='', **kwargs):
3737
default_cfgs = {
3838
'senet154':
3939
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
40-
'seresnet18':
41-
_cfg(url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
42-
interpolation='bicubic'),
43-
'seresnet34':
44-
_cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
45-
'seresnet50':
46-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
47-
'seresnet101':
48-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
49-
'seresnet152':
50-
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
51-
'seresnext26_32x4d':
52-
_cfg(url='https://www.dropbox.com/s/zaeruz2bejcdhh3/seresnext26_32x4d-65ebdb501.pth?dl=1',
53-
interpolation='bicubic'),
40+
'seresnet18': _cfg(
41+
url='https://www.dropbox.com/s/3o3nd8mfhxod7rq/seresnet18-4bb0ce65.pth?dl=1',
42+
interpolation='bicubic'),
43+
'seresnet34': _cfg(
44+
url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
45+
'seresnet50': _cfg(
46+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet50-ce0d4300.pth'),
47+
'seresnet101': _cfg(
48+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet101-7e38fcc6.pth'),
49+
'seresnet152': _cfg(
50+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/se_resnet152-d17c99b7.pth'),
51+
'seresnext26_32x4d': _cfg(
52+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26_32x4d-65ebdb501.pth',
53+
interpolation='bicubic'),
5454
'seresnext50_32x4d':
5555
_cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
5656
'seresnext101_32x4d':

timm/models/xception.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
default_cfgs = {
3737
'xception': {
38-
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
38+
'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/xception-43020ad28.pth',
3939
'input_size': (3, 299, 299),
4040
'crop_pct': 0.8975,
4141
'interpolation': 'bicubic',

0 commit comments

Comments
 (0)