Skip to content

Commit d1cb3b8

Browse files
committed
Bring activations mostly in line with timm, update some docstrings, fixup ONNX export and caffe2 scripts.
1 parent 8795d32 commit d1cb3b8

15 files changed

+411
-255
lines changed

caffe2_benchmark.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
""" Caffe2 validation script
2-
This script runs Caffe2 benchmark on exported model.
2+
3+
This script runs Caffe2 benchmark on exported ONNX model.
4+
It is a useful tool for reporting model FLOPS.
5+
6+
Copyright 2020 Ross Wightman
37
"""
48
import argparse
59
from caffe2.python import core, workspace, model_helper
610
from caffe2.proto import caffe2_pb2
711

812

913
parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
14+
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
15+
help='caffe2 model pb name prefix')
1016
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
11-
help='path to latest checkpoint (default: none)')
17+
help='caffe2 model init .pb')
1218
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
13-
help='path to latest checkpoint (default: none)')
19+
help='caffe2 model predict .pb')
1420
parser.add_argument('-b', '--batch-size', default=1, type=int,
1521
metavar='N', help='mini-batch size (default: 1)')
1622
parser.add_argument('--img-size', default=224, type=int,
@@ -20,20 +26,23 @@
2026
def main():
2127
args = parser.parse_args()
2228
args.gpu_id = 0
29+
if args.c2_prefix:
30+
args.c2_init = args.c2_prefix + '.init.pb'
31+
args.c2_predict = args.c2_prefix + '.predict.pb'
2332

2433
model = model_helper.ModelHelper(name="le_net", init_params=False)
2534

2635
# Bring in the init net from init_net.pb
2736
init_net_proto = caffe2_pb2.NetDef()
2837
with open(args.c2_init, "rb") as f:
2938
init_net_proto.ParseFromString(f.read())
30-
model.param_init_net = core.Net(init_net_proto) # model.param_init_net.AppendNet(core.Net(init_net_proto)) #
39+
model.param_init_net = core.Net(init_net_proto)
3140

3241
# bring in the predict net from predict_net.pb
3342
predict_net_proto = caffe2_pb2.NetDef()
3443
with open(args.c2_predict, "rb") as f:
3544
predict_net_proto.ParseFromString(f.read())
36-
model.net = core.Net(predict_net_proto) # model.net.AppendNet(core.Net(predict_net_proto))
45+
model.net = core.Net(predict_net_proto)
3746

3847
# CUDA performance not impressive
3948
#device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id)

caffe2_validate.py

+16-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
""" Caffe2 validation script
2-
This script is intended to verify exported models running in Caffe2
3-
It utilizes the same PyTorch dataloader/processing pipeline for comparison against
4-
the originals, I also have no desire to write that code in Caffe2.
2+
3+
This script is created to verify exported ONNX models running in Caffe2
4+
It utilizes the same PyTorch dataloader/processing pipeline for a
5+
fair comparison against the originals.
6+
7+
Copyright 2020 Ross Wightman
58
"""
69
import argparse
710
import numpy as np
@@ -14,12 +17,12 @@
1417
parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation')
1518
parser.add_argument('data', metavar='DIR',
1619
help='path to dataset')
17-
parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00',
18-
help='model architecture (default: dpn92)')
20+
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
21+
help='caffe2 model pb name prefix')
1922
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
20-
help='path to latest checkpoint (default: none)')
23+
help='caffe2 model init .pb')
2124
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
22-
help='path to latest checkpoint (default: none)')
25+
help='caffe2 model predict .pb')
2326
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
2427
help='number of data loading workers (default: 2)')
2528
parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -43,22 +46,25 @@
4346
def main():
4447
args = parser.parse_args()
4548
args.gpu_id = 0
49+
if args.c2_prefix:
50+
args.c2_init = args.c2_prefix + '.init.pb'
51+
args.c2_predict = args.c2_prefix + '.predict.pb'
4652

4753
model = model_helper.ModelHelper(name="validation_net", init_params=False)
4854

4955
# Bring in the init net from init_net.pb
5056
init_net_proto = caffe2_pb2.NetDef()
5157
with open(args.c2_init, "rb") as f:
5258
init_net_proto.ParseFromString(f.read())
53-
model.param_init_net = core.Net(init_net_proto) # model.param_init_net.AppendNet(core.Net(init_net_proto)) #
59+
model.param_init_net = core.Net(init_net_proto)
5460

5561
# bring in the predict net from predict_net.pb
5662
predict_net_proto = caffe2_pb2.NetDef()
5763
with open(args.c2_predict, "rb") as f:
5864
predict_net_proto.ParseFromString(f.read())
59-
model.net = core.Net(predict_net_proto) # model.net.AppendNet(core.Net(predict_net_proto))
65+
model.net = core.Net(predict_net_proto)
6066

61-
data_config = resolve_data_config(args.model, args)
67+
data_config = resolve_data_config(None, args)
6268
loader = create_loader(
6369
Dataset(args.data, load_bytes=args.tf_preprocessing),
6470
input_size=data_config['input_size'],

data/tf_preprocessing.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
""" Tensorflow Preprocessing Adapter
2+
3+
Allows use of Tensorflow preprocessing pipeline in PyTorch Transform
4+
5+
Copyright of original Tensorflow code below.
6+
7+
Hacked together by / Copyright 2020 Ross Wightman
8+
"""
19
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
210
#
311
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,7 +20,6 @@
1220
# See the License for the specific language governing permissions and
1321
# limitations under the License.
1422
# ==============================================================================
15-
"""ImageNet preprocessing for MnasNet."""
1623
from __future__ import absolute_import
1724
from __future__ import division
1825
from __future__ import print_function

data/transforms.py

+3-23
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def resolve_data_config(model, args, default_cfg={}, verbose=True):
1818
new_config = {}
1919
default_cfg = default_cfg
20-
if not default_cfg and hasattr(model, 'default_cfg'):
20+
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
2121
default_cfg = model.default_cfg
2222

2323
# Resolve input/image size
@@ -40,7 +40,7 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
4040
new_config['interpolation'] = default_cfg['interpolation']
4141

4242
# resolve dataset + model mean for normalization
43-
new_config['mean'] = get_mean_by_model(args.model)
43+
new_config['mean'] = IMAGENET_DEFAULT_MEAN
4444
if args.mean is not None:
4545
mean = tuple(args.mean)
4646
if len(mean) == 1:
@@ -52,7 +52,7 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
5252
new_config['mean'] = default_cfg['mean']
5353

5454
# resolve dataset + model std deviation for normalization
55-
new_config['std'] = get_std_by_model(args.model)
55+
new_config['std'] = IMAGENET_DEFAULT_STD
5656
if args.std is not None:
5757
std = tuple(args.std)
5858
if len(std) == 1:
@@ -78,26 +78,6 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
7878
return new_config
7979

8080

81-
def get_mean_by_model(model_name):
82-
model_name = model_name.lower()
83-
if 'dpn' in model_name:
84-
return IMAGENET_DPN_STD
85-
elif 'ception' in model_name:
86-
return IMAGENET_INCEPTION_MEAN
87-
else:
88-
return IMAGENET_DEFAULT_MEAN
89-
90-
91-
def get_std_by_model(model_name):
92-
model_name = model_name.lower()
93-
if 'dpn' in model_name:
94-
return IMAGENET_DEFAULT_STD
95-
elif 'ception' in model_name:
96-
return IMAGENET_INCEPTION_STD
97-
else:
98-
return IMAGENET_DEFAULT_STD
99-
100-
10181
class ToNumpy:
10282

10383
def __call__(self, pil_img):

geffnet/activations/__init__.py

+30-28
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from geffnet import config
2-
from geffnet.activations.activations_autofn import *
2+
from geffnet.activations.activations_me import *
33
from geffnet.activations.activations_jit import *
44
from geffnet.activations.activations import *
55

@@ -15,16 +15,16 @@
1515
hard_swish=hard_swish,
1616
)
1717

18-
_ACT_FN_AUTO = dict(
19-
swish=swish_auto,
20-
mish=mish_auto,
21-
)
22-
2318
_ACT_FN_JIT = dict(
2419
swish=swish_jit,
2520
mish=mish_jit,
26-
#hard_swish=hard_swish_jit,
27-
#hard_sigmoid_jit=hard_sigmoid_jit,
21+
)
22+
23+
_ACT_FN_ME = dict(
24+
swish=swish_me,
25+
mish=mish_me,
26+
hard_swish=hard_swish_me,
27+
hard_sigmoid_jit=hard_sigmoid_me,
2828
)
2929

3030
_ACT_LAYER_DEFAULT = dict(
@@ -38,16 +38,16 @@
3838
hard_swish=HardSwish,
3939
)
4040

41-
_ACT_LAYER_AUTO = dict(
42-
swish=SwishAuto,
43-
mish=MishAuto,
44-
)
45-
4641
_ACT_LAYER_JIT = dict(
4742
swish=SwishJit,
4843
mish=MishJit,
49-
#hard_swish=HardSwishJit,
50-
#hard_sigmoid=HardSigmoidJit
44+
)
45+
46+
_ACT_LAYER_ME = dict(
47+
swish=SwishMe,
48+
mish=MishMe,
49+
hard_swish=HardSwishMe,
50+
hard_sigmoid=HardSigmoidMe
5151
)
5252

5353
_OVERRIDE_FN = dict()
@@ -92,14 +92,15 @@ def get_act_fn(name='relu'):
9292
"""
9393
if name in _OVERRIDE_FN:
9494
return _OVERRIDE_FN[name]
95-
if not config.is_exportable() and not config.is_scriptable():
96-
# If not exporting or scripting the model, first look for a JIT optimized version
97-
# of our activation, then a custom autograd.Function variant before defaulting to
98-
# a Python or Torch builtin impl
99-
if name in _ACT_FN_JIT:
100-
return _ACT_FN_JIT[name]
101-
if name in _ACT_FN_AUTO:
102-
return _ACT_FN_AUTO[name]
95+
no_me = config.is_exportable() or config.is_scriptable() or config.is_no_jit()
96+
if not no_me and name in _ACT_FN_ME:
97+
# If not exporting or scripting the model, first look for a memory optimized version
98+
# activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin
99+
return _ACT_FN_ME[name]
100+
no_jit = config.is_exportable() or config.is_no_jit()
101+
# NOTE: export tracing should work with jit scripted components, but I keep running into issues
102+
if no_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting
103+
return _ACT_FN_JIT[name]
103104
return _ACT_FN_DEFAULT[name]
104105

105106

@@ -110,11 +111,12 @@ def get_act_layer(name='relu'):
110111
"""
111112
if name in _OVERRIDE_LAYER:
112113
return _OVERRIDE_LAYER[name]
113-
if not config.is_exportable() and not config.is_scriptable():
114-
if name in _ACT_LAYER_JIT:
115-
return _ACT_LAYER_JIT[name]
116-
if name in _ACT_LAYER_AUTO:
117-
return _ACT_LAYER_AUTO[name]
114+
no_me = config.is_exportable() or config.is_scriptable() or config.is_no_jit()
115+
if not no_me and name in _ACT_LAYER_ME:
116+
return _ACT_LAYER_ME[name]
117+
no_jit = config.is_exportable() or config.is_no_jit()
118+
if not no_jit and name in _ACT_LAYER_JIT: # jit scripted models should be okay for export/scripting
119+
return _ACT_LAYER_JIT[name]
118120
return _ACT_LAYER_DEFAULT[name]
119121

120122

geffnet/activations/activations.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,19 @@
1+
""" Activations
2+
3+
A collection of activations fn and modules with a common interface so that they can
4+
easily be swapped. All have an `inplace` arg even if not used.
5+
6+
Copyright 2020 Ross Wightman
7+
"""
18
from torch import nn as nn
29
from torch.nn import functional as F
310

411

512
def swish(x, inplace: bool = False):
6-
"""Swish - Described in: https://arxiv.org/abs/1710.05941
13+
"""Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3)
14+
and also as Swish (https://arxiv.org/abs/1710.05941).
15+
16+
TODO Rename to SiLU with addition to PyTorch
717
"""
818
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
919

geffnet/activations/activations_autofn.py

-72
This file was deleted.

0 commit comments

Comments
 (0)