Commit 53dfee4 by chenxi: "fix some bugs"
1 parent 3beb40f commit 53dfee4

10 files changed: +134 -111 lines

.gitignore (+2)

```diff
@@ -3,4 +3,6 @@ __pycache__
 *.jpg
 *.png
 log
+pytorch_playground.egg-info
+script/val224_compressed.pkl
 
```

README.md (+6 -12)

```diff
@@ -24,27 +24,21 @@ Also, if want to train the MLP model on mnist, simply run `python mnist/train.py`
 
 
 # Install
-- pytorch (>=0.1.11) and torchvision from [official website](http://pytorch.org/), for example, cuda8.0 for python3.5
-    - `pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp35-cp35m-linux_x86_64.whl`
-    - `pip install torchvision`
-- tqdm
-    - `pip install tqdm`
-- OpenCV
-    - `conda install -c menpo opencv3`
-- Setting PYTHONPATH
-    - `export PYTHONPATH=/path/to/pytorch-playground:$PYTHONPATH`
+```
+python3 setup.py develop --user
+```
 
 # ImageNet dataset
 We provide precomputed imagenet validation dataset with 224x224x3 size. We first resize the shorter size of image to 256, then we crop 224x224 image in the center. Then we encode the cropped images to jpg string and dump to pickle.
 - `cd script`
-- Download the [val224_compressed.pkl](https://drive.google.com/file/d/1U8ir2fOR4Sir3FCj9b7FQRPSVsycTfVc/view?usp=sharing)
-- `python convert.py`
+- Download the `val224_compressed.pkl` ([Tsinghua](http://ml.cs.tsinghua.edu.cn/~chenxi/dataset/val224_compressed.pkl) / [Google Drive](https://drive.google.com/file/d/1U8ir2fOR4Sir3FCj9b7FQRPSVsycTfVc/view?usp=sharing))
+- `python convert.py` (needs 48G memory, thanks [@jnorwood](https://github.com/aaron-xichen/pytorch-playground/issues/18))
 
 
 # Quantization
 We also provide a simple demo to quantize these models to specified bit-width with several methods, including linear method, minmax method and non-linear method.
 
-`python quantize.py --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1`
+`quantize --type cifar10 --quant_method linear --param_bits 8 --fwd_bits 8 --bn_bits 8 --ngpu 1`
 
 ## Top1 Accuracy
 We evaluate the performance of popular dataset and models with linear quantized method. The bit-width of running mean and running variance in BN are 10 bits for all results. (except for 32-float)
```
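
For reference, the preprocessing this README section describes (resize the shorter side to 256, center-crop 224x224, JPEG-encode, pickle) can be sketched as below. This is an illustration under stated assumptions, not code from the repo; `encode_val_image` is a hypothetical name and OpenCV's usual HxWxC layout is assumed.

```python
# Hypothetical sketch of the val-set preprocessing described above.
import cv2

def encode_val_image(img):
    h, w = img.shape[:2]
    scale = 256.0 / min(h, w)                      # shorter side -> 256
    img = cv2.resize(img, (int(round(w * scale)), int(round(h * scale))))
    h, w = img.shape[:2]
    top, left = (h - 224) // 2, (w - 224) // 2     # center 224x224 crop
    crop = img[top:top + 224, left:left + 224]
    ok, buf = cv2.imencode('.jpg', crop)           # the "jpg string" stored in the pickle
    assert ok
    return buf.tobytes()
```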

imagenet/dataset.py (+2 -1)

```diff
@@ -2,6 +2,7 @@
 import os
 import os.path
 import numpy as np
+import joblib
 
 
 def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=False, val=True, **kwargs):
@@ -26,7 +27,7 @@ def __init__(self, root, batch_size, train=False, input_size=224, **kwargs):
             pkl_file = os.path.join(root, 'train{}.pkl'.format(input_size))
         else:
             pkl_file = os.path.join(root, 'val{}.pkl'.format(input_size))
-        self.data_dict = misc.load_pickle(pkl_file)
+        self.data_dict = joblib.load(pkl_file)
 
         self.batch_size = batch_size
         self.idx = 0
```
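
The move from `misc.load_pickle` to `joblib.load` pairs with the change in `script/convert.py` below, which now writes the pickles with `joblib.dump`; joblib is generally more memory-friendly for dicts of large numpy arrays. A minimal round-trip sketch (path and shapes are made up for the demo):

```python
# Minimal joblib round-trip, mirroring what convert.py writes and
# what dataset.py now reads. Illustrative only.
import numpy as np
import joblib

data_dict = {'data': np.zeros((10, 3, 224, 224), dtype=np.uint8),
             'target': list(range(10))}
joblib.dump(data_dict, '/tmp/val224_demo.pkl')
loaded = joblib.load('/tmp/val224_demo.pkl')
assert loaded['data'].shape == (10, 3, 224, 224)
```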

imagenet/inception.py (+1 -1)

```diff
@@ -60,7 +60,7 @@ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False):
                 stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                 X = stats.truncnorm(-2, 2, scale=stddev)
                 values = torch.Tensor(X.rvs(m.weight.data.numel()))
-                m.weight.data.copy_(values)
+                m.weight.data.copy_(values.reshape(m.weight.shape))
             elif isinstance(m, nn.BatchNorm2d):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()
```
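
The reshape matters because `X.rvs(n)` returns a flat sample: older PyTorch let `copy_` accept any source with a matching element count, while newer versions expect the source to be broadcastable to the destination, so the flat tensor needs to be reshaped to `m.weight.shape`. A standalone sketch of the same init (assuming scipy and a toy Conv2d):

```python
# Standalone sketch of the patched init: rvs() yields a flat array,
# which is reshaped to the weight's shape before copy_.
import torch
import torch.nn as nn
from scipy import stats

m = nn.Conv2d(3, 8, kernel_size=3)
X = stats.truncnorm(-2, 2, scale=0.1)
values = torch.Tensor(X.rvs(m.weight.data.numel()))
m.weight.data.copy_(values.reshape(m.weight.shape))
```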

quantize.py (+80 -78)

```diff
@@ -5,93 +5,95 @@
 cudnn.benchmark =True
 from collections import OrderedDict
 
-parser = argparse.ArgumentParser(description='PyTorch SVHN Example')
-parser.add_argument('--type', default='cifar10', help='|'.join(selector.known_models))
-parser.add_argument('--quant_method', default='linear', help='linear|minmax|log|tanh')
-parser.add_argument('--batch_size', type=int, default=100, help='input batch size for training (default: 64)')
-parser.add_argument('--gpu', default=None, help='index of gpus to use')
-parser.add_argument('--ngpu', type=int, default=8, help='number of gpus to use')
-parser.add_argument('--seed', type=int, default=117, help='random seed (default: 1)')
-parser.add_argument('--model_root', default='~/.torch/models/', help='folder to save the model')
-parser.add_argument('--data_root', default='/tmp/public_dataset/pytorch/', help='folder to save the model')
-parser.add_argument('--logdir', default='log/default', help='folder to save to the log')
+def main():
+    parser = argparse.ArgumentParser(description='PyTorch SVHN Example')
+    parser.add_argument('--type', default='cifar10', help='|'.join(selector.known_models))
+    parser.add_argument('--quant_method', default='linear', help='linear|minmax|log|tanh')
+    parser.add_argument('--batch_size', type=int, default=100, help='input batch size for training (default: 64)')
+    parser.add_argument('--gpu', default=None, help='index of gpus to use')
+    parser.add_argument('--ngpu', type=int, default=8, help='number of gpus to use')
+    parser.add_argument('--seed', type=int, default=117, help='random seed (default: 1)')
+    parser.add_argument('--model_root', default='~/.torch/models/', help='folder to save the model')
+    parser.add_argument('--data_root', default='/data/public_dataset/pytorch/', help='folder to save the model')
+    parser.add_argument('--logdir', default='log/default', help='folder to save to the log')
 
-parser.add_argument('--input_size', type=int, default=224, help='input size of image')
-parser.add_argument('--n_sample', type=int, default=20, help='number of samples to infer the scaling factor')
-parser.add_argument('--param_bits', type=int, default=8, help='bit-width for parameters')
-parser.add_argument('--bn_bits', type=int, default=32, help='bit-width for running mean and std')
-parser.add_argument('--fwd_bits', type=int, default=8, help='bit-width for layer output')
-parser.add_argument('--overflow_rate', type=float, default=0.0, help='overflow rate')
-args = parser.parse_args()
+    parser.add_argument('--input_size', type=int, default=224, help='input size of image')
+    parser.add_argument('--n_sample', type=int, default=20, help='number of samples to infer the scaling factor')
+    parser.add_argument('--param_bits', type=int, default=8, help='bit-width for parameters')
+    parser.add_argument('--bn_bits', type=int, default=32, help='bit-width for running mean and std')
+    parser.add_argument('--fwd_bits', type=int, default=8, help='bit-width for layer output')
+    parser.add_argument('--overflow_rate', type=float, default=0.0, help='overflow rate')
+    args = parser.parse_args()
 
-args.gpu = misc.auto_select_gpu(utility_bound=0, num_gpu=args.ngpu, selected_gpus=args.gpu)
-args.ngpu = len(args.gpu)
-misc.ensure_dir(args.logdir)
-args.model_root = misc.expand_user(args.model_root)
-args.data_root = misc.expand_user(args.data_root)
-args.input_size = 299 if 'inception' in args.type else args.input_size
-assert args.quant_method in ['linear', 'minmax', 'log', 'tanh']
-print("=================FLAGS==================")
-for k, v in args.__dict__.items():
-    print('{}: {}'.format(k, v))
-print("========================================")
+    args.gpu = misc.auto_select_gpu(utility_bound=0, num_gpu=args.ngpu, selected_gpus=args.gpu)
+    args.ngpu = len(args.gpu)
+    misc.ensure_dir(args.logdir)
+    args.model_root = misc.expand_user(args.model_root)
+    args.data_root = misc.expand_user(args.data_root)
+    args.input_size = 299 if 'inception' in args.type else args.input_size
+    assert args.quant_method in ['linear', 'minmax', 'log', 'tanh']
+    print("=================FLAGS==================")
+    for k, v in args.__dict__.items():
+        print('{}: {}'.format(k, v))
+    print("========================================")
 
-assert torch.cuda.is_available(), 'no cuda'
-torch.manual_seed(args.seed)
-torch.cuda.manual_seed(args.seed)
+    assert torch.cuda.is_available(), 'no cuda'
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
 
-# load model and dataset fetcher
-model_raw, ds_fetcher, is_imagenet = selector.select(args.type, model_root=args.model_root)
-args.ngpu = args.ngpu if is_imagenet else 1
+    # load model and dataset fetcher
+    model_raw, ds_fetcher, is_imagenet = selector.select(args.type, model_root=args.model_root)
+    args.ngpu = args.ngpu if is_imagenet else 1
 
-# quantize parameters
-if args.param_bits < 32:
-    state_dict = model_raw.state_dict()
-    state_dict_quant = OrderedDict()
-    sf_dict = OrderedDict()
-    for k, v in state_dict.items():
-        if 'running' in k:
-            if args.bn_bits >=32:
-                print("Ignoring {}".format(k))
-                state_dict_quant[k] = v
-                continue
+    # quantize parameters
+    if args.param_bits < 32:
+        state_dict = model_raw.state_dict()
+        state_dict_quant = OrderedDict()
+        sf_dict = OrderedDict()
+        for k, v in state_dict.items():
+            if 'running' in k:
+                if args.bn_bits >=32:
+                    print("Ignoring {}".format(k))
+                    state_dict_quant[k] = v
+                    continue
+                else:
+                    bits = args.bn_bits
             else:
-                bits = args.bn_bits
-        else:
-            bits = args.param_bits
+                bits = args.param_bits
 
-        if args.quant_method == 'linear':
-            sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
-            v_quant = quant.linear_quantize(v, sf, bits=bits)
-        elif args.quant_method == 'log':
-            v_quant = quant.log_minmax_quantize(v, bits=bits)
-        elif args.quant_method == 'minmax':
-            v_quant = quant.min_max_quantize(v, bits=bits)
-        else:
-            v_quant = quant.tanh_quantize(v, bits=bits)
-        state_dict_quant[k] = v_quant
-        print(k, bits)
-    model_raw.load_state_dict(state_dict_quant)
-
-# quantize forward activation
-if args.fwd_bits < 32:
-    model_raw = quant.duplicate_model_with_quant(model_raw, bits=args.fwd_bits, overflow_rate=args.overflow_rate,
-                                                 counter=args.n_sample, type=args.quant_method)
-    print(model_raw)
-    val_ds_tmp = ds_fetcher(10, data_root=args.data_root, train=False, input_size=args.input_size)
-    misc.eval_model(model_raw, val_ds_tmp, ngpu=1, n_sample=args.n_sample, is_imagenet=is_imagenet)
+            if args.quant_method == 'linear':
+                sf = bits - 1. - quant.compute_integral_part(v, overflow_rate=args.overflow_rate)
+                v_quant = quant.linear_quantize(v, sf, bits=bits)
+            elif args.quant_method == 'log':
+                v_quant = quant.log_minmax_quantize(v, bits=bits)
+            elif args.quant_method == 'minmax':
+                v_quant = quant.min_max_quantize(v, bits=bits)
+            else:
+                v_quant = quant.tanh_quantize(v, bits=bits)
+            state_dict_quant[k] = v_quant
+            print(k, bits)
+        model_raw.load_state_dict(state_dict_quant)
 
-# eval model
-val_ds = ds_fetcher(args.batch_size, data_root=args.data_root, train=False, input_size=args.input_size)
-acc1, acc5 = misc.eval_model(model_raw, val_ds, ngpu=args.ngpu, is_imagenet=is_imagenet)
+    # quantize forward activation
+    if args.fwd_bits < 32:
+        model_raw = quant.duplicate_model_with_quant(model_raw, bits=args.fwd_bits, overflow_rate=args.overflow_rate,
+                                                     counter=args.n_sample, type=args.quant_method)
+        print(model_raw)
+        val_ds_tmp = ds_fetcher(10, data_root=args.data_root, train=False, input_size=args.input_size)
+        misc.eval_model(model_raw, val_ds_tmp, ngpu=1, n_sample=args.n_sample, is_imagenet=is_imagenet)
 
-# print sf
-print(model_raw)
-res_str = "type={}, quant_method={}, param_bits={}, bn_bits={}, fwd_bits={}, overflow_rate={}, acc1={:.4f}, acc5={:.4f}".format(
-    args.type, args.quant_method, args.param_bits, args.bn_bits, args.fwd_bits, args.overflow_rate, acc1, acc5)
-print(res_str)
-with open('acc1_acc5.txt', 'a') as f:
-    f.write(res_str + '\n')
+    # eval model
+    val_ds = ds_fetcher(args.batch_size, data_root=args.data_root, train=False, input_size=args.input_size)
+    acc1, acc5 = misc.eval_model(model_raw, val_ds, ngpu=args.ngpu, is_imagenet=is_imagenet)
 
+    # print sf
+    print(model_raw)
+    res_str = "type={}, quant_method={}, param_bits={}, bn_bits={}, fwd_bits={}, overflow_rate={}, acc1={:.4f}, acc5={:.4f}".format(
+        args.type, args.quant_method, args.param_bits, args.bn_bits, args.fwd_bits, args.overflow_rate, acc1, acc5)
+    print(res_str)
+    with open('acc1_acc5.txt', 'a') as f:
+        f.write(res_str + '\n')
 
 
+if __name__ == '__main__':
+    main()
```
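
For readers skimming the diff: the linear path above derives a power-of-two scale factor `sf` from the tensor's largest magnitude and then rounds onto a signed `bits`-wide grid. A simplified, self-contained sketch of that scheme (my reading of `quant.linear_quantize` and `compute_integral_part` at `overflow_rate=0`, not the repo's exact code):

```python
# Simplified power-of-two linear quantization, as invoked above.
import math
import torch

def linear_quantize_sketch(x, sf, bits):
    delta = 2.0 ** (-sf)                  # quantization step
    bound = 2.0 ** (bits - 1)             # one bit is the sign
    return torch.clamp(torch.round(x / delta), -bound, bound - 1) * delta

bits = 8
w = torch.randn(4, 4)
# mirrors `sf = bits - 1. - compute_integral_part(...)` at overflow_rate=0
integral = math.ceil(math.log2(w.abs().max().item() + 1e-12))
sf = bits - 1 - integral
w_q = linear_quantize_sketch(w, sf, bits)
print((w - w_q).abs().max().item())       # at most one quantization step
```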

script/convert.py (+17 -13)

```diff
@@ -4,41 +4,45 @@
 from utee import misc
 import argparse
 import cv2
+import joblib
 
-imagenet_urls = [
-    'http://ml.cs.tsinghua.edu.cn/~chenxi/dataset/val224_compressed.pkl'
-]
 parser = argparse.ArgumentParser(description='Extract the ILSVRC2012 val dataset')
 parser.add_argument('--in_file', default='val224_compressed.pkl', help='input file path')
-parser.add_argument('--out_root', default='/tmp/public_dataset/pytorch/imagenet-data/', help='output file path')
+parser.add_argument('--out_root', default='/data/public_dataset/pytorch/imagenet-data/', help='output file path')
 args = parser.parse_args()
 
 d = misc.load_pickle(args.in_file)
 assert len(d['data']) == 50000, len(d['data'])
 assert len(d['target']) == 50000, len(d['target'])
 
-data224 = []
+
 data299 = []
 for img, target in tqdm.tqdm(zip(d['data'], d['target']), total=50000):
     img224 = misc.str2img(img)
     img299 = cv2.resize(img224, (299, 299))
-    data224.append(img224)
     data299.append(img299)
-data_dict224 = dict(
-    data = np.array(data224).transpose(0, 3, 1, 2),
-    target = d['target']
-)
+
 data_dict299 = dict(
     data = np.array(data299).transpose(0, 3, 1, 2),
     target = d['target']
 )
-
 if not os.path.exists(args.out_root):
     os.makedirs(args.out_root)
-misc.dump_pickle(data_dict224, os.path.join(args.out_root, 'val224.pkl'))
-misc.dump_pickle(data_dict299, os.path.join(args.out_root, 'val299.pkl'))
+joblib.dump(data_dict299, os.path.join(args.out_root, 'val299.pkl'))
 
+data299.clear()
+data_dict299.clear()
 
+data224 = []
+for img, target in tqdm.tqdm(zip(d['data'], d['target']), total=50000):
+    img224 = misc.str2img(img)
+    data224.append(img224)
+
+data_dict224 = dict(
+    data = np.array(data224).transpose(0, 3, 1, 2),
+    target = d['target']
+)
+joblib.dump(data_dict224, os.path.join(args.out_root, 'val224.pkl'))
 
 
 
```
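
The restructuring addresses the "needs 48G memory" note added to the README: instead of building the 224 and 299 arrays simultaneously, the script now dumps and clears the 299 data before decoding the 224 pass, roughly halving peak memory. The pattern, scaled down to toy sizes (paths and shapes are illustrative):

```python
# Toy version of the two-pass pattern above: finish, dump, and free one
# large array before building the next.
import numpy as np
import joblib

def build_and_dump(shape, path):
    big = np.zeros(shape, dtype=np.uint8)        # stands in for decoded images
    joblib.dump({'data': big, 'target': []}, path)
    # 'big' goes out of scope here, so its memory can be reclaimed
    # before the next pass starts.

build_and_dump((500, 3, 299, 299), '/tmp/val299_demo.pkl')
build_and_dump((500, 3, 224, 224), '/tmp/val224_demo.pkl')
```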

setup.py (+19, new file)

```diff
@@ -0,0 +1,19 @@
+from setuptools import setup, find_packages
+
+with open("requirements.txt") as requirements_file:
+    REQUIREMENTS = requirements_file.readlines()
+
+setup(
+    name="pytorch-playground",
+    version="1.0.0",
+    author='Aaron Chen',
+    author_email='[email protected]',
+    packages=find_packages(),
+    entry_points = {
+        'console_scripts': [
+            'quantize=quantize:main',
+        ]
+    },
+    install_requires=REQUIREMENTS,
+
+)
```
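
The `console_scripts` entry is what turns `python quantize.py ...` into the bare `quantize ...` command used in the updated README: after `python3 setup.py develop --user`, setuptools generates an executable wrapper roughly equivalent to this (illustrative, not the literal generated file):

```python
#!/usr/bin/env python3
# Approximate shape of the wrapper setuptools generates for
# 'quantize=quantize:main': import the module and call main().
import sys
from quantize import main

if __name__ == '__main__':
    sys.exit(main())
```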

svhn/dataset.py (+1 -1)

```diff
@@ -10,7 +10,7 @@ def get(batch_size, data_root='/tmp/public_dataset/pytorch', train=True, val=True
     print("Building SVHN data loader with {} workers".format(num_workers))
 
     def target_transform(target):
-        return int(target[0]) - 1
+        return int(target) - 1
 
     ds = []
     if train:
```
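
The one-character fix suggests that torchvision's SVHN dataset now hands the transform a scalar label rather than a length-1 array, so `target[0]` stopped working; the `- 1` still remaps the 1..10 labels to 0..9. A quick check of the new form (assumption: scalar targets):

```python
# The transform now accepts scalar targets (int or numpy integer).
import numpy as np

def target_transform(target):
    return int(target) - 1   # SVHN labels 1..10 -> class indices 0..9

assert target_transform(1) == 0
assert target_transform(np.int64(10)) == 9
```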

utee/misc.py (+3 -2)

```diff
@@ -222,6 +222,7 @@ def load_state_dict(model, model_urls, model_root):
         own_state[name].copy_(param)
 
     missing = set(own_state.keys()) - set(state_dict.keys())
-    if len(missing) > 0:
-        raise KeyError('missing keys in state_dict: "{}"'.format(missing))
+    no_use = set(state_dict.keys()) - set(own_state.keys())
+    if len(no_use) > 0:
+        raise KeyError('some keys are not used: "{}"'.format(no_use))
 
```
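
Note the semantic flip: previously loading failed when the checkpoint lacked keys the model expected; now missing keys are tolerated and only checkpoint keys the model cannot place raise. A compact illustration of the new check:

```python
# New behavior in load_state_dict: missing model keys are tolerated,
# unused checkpoint keys are an error.
own_state = {'conv.weight': 0, 'fc.bias': 0}   # model parameters
state_dict = {'conv.weight': 0}                # checkpoint lacking fc.bias

missing = set(own_state) - set(state_dict)     # {'fc.bias'}: now allowed
no_use = set(state_dict) - set(own_state)      # empty: nothing raises
assert missing == {'fc.bias'} and not no_use
```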

utee/quant.py (+3 -3)

```diff
@@ -11,7 +11,7 @@ def compute_integral_part(input, overflow_rate):
     split_idx = int(overflow_rate * len(sorted_value))
     v = sorted_value[split_idx]
     if isinstance(v, Variable):
-        v = v.data.cpu().numpy()[0]
+        v = float(v.data.cpu())
     sf = math.ceil(math.log2(v+1e-12))
     return sf
 
@@ -35,7 +35,7 @@ def log_minmax_quantize(input, bits):
 
     s = torch.sign(input)
     input0 = torch.log(torch.abs(input) + 1e-20)
-    v = min_max_quantize(input0, bits)
+    v = min_max_quantize(input0, bits-1)
     v = torch.exp(v) * s
     return v
 
@@ -46,7 +46,7 @@ def log_linear_quantize(input, sf, bits):
 
     s = torch.sign(input)
     input0 = torch.log(torch.abs(input) + 1e-20)
-    v = linear_quantize(input0, sf, bits)
+    v = linear_quantize(input0, sf, bits-1)
     v = torch.exp(v) * s
     return v
```
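
Two distinct fixes here: `float(v.data.cpu())` tracks newer PyTorch scalar semantics (where `.numpy()[0]` on a 0-dim result fails), and the `bits-1` arguments plausibly account for the sign `s` that is split off before quantizing the log-magnitude, leaving one fewer bit for the magnitude itself. A schematic version of the log-minmax path (a sketch with a stand-in `min_max_quantize`, not the repo's exact helpers):

```python
# Schematic log-minmax quantization: the sign consumes one effective bit,
# so the magnitude is quantized with bits-1.
import torch

def min_max_quantize_sketch(x, bits):
    # uniform quantization between min and max with 2**bits levels (stand-in)
    lo, hi = x.min(), x.max()
    levels = 2 ** bits - 1
    return torch.round((x - lo) / (hi - lo) * levels) / levels * (hi - lo) + lo

def log_minmax_quantize_sketch(x, bits):
    s = torch.sign(x)                               # sign handled separately
    log_mag = torch.log(torch.abs(x) + 1e-20)
    v = min_max_quantize_sketch(log_mag, bits - 1)  # bits-1 left for magnitude
    return torch.exp(v) * s

print(log_minmax_quantize_sketch(torch.randn(8), 4))
```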
