Skip to content
This repository was archived by the owner on Aug 5, 2022. It is now read-only.

Commit ab5baad

Browse files
committed
Work for 1.1.2 release
1 parent 12c7e5e commit ab5baad

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+52281
-1220
lines changed

Makefile.config.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ DISABLE_CONV_RELU_FUSION:= 0
8787
# Use Bn + ReLU fusion to boost inference
8888
DISABLE_BN_RELU_FUSION := 0
8989

90+
# Use Conv + Concat fusion to boost inference.
91+
ENABLE_CONCAT_FUSION := 0
92+
9093
# Use Conv + Eltwise + Relu layer fusion to boost inference.
9194
DISABLE_CONV_SUM_FUSION := 0
9295

cmake/Misc.cmake

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ if(DISABLE_BN_RELU_FUSION)
2121
add_definitions("-DDISABLE_BN_RELU_FUSION")
2222
endif()
2323

24+
if(ENABLE_CONCAT_FUSION)
25+
message(STATUS "conv/concat fusion is enabled!")
26+
add_definitions("-DENABLE_CONCAT_FUSION")
27+
endif()
28+
2429
if(DISABLE_CONV_SUM_FUSION)
2530
message(STATUS "conv/eltwise/relu fusion is disabled!")
2631
add_definitions("-DDISABLE_CONV_SUM_FUSION")
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#
2+
# All modification made by Intel Corporation: Copyright (c) 2016 Intel Corporation
3+
#
4+
# All contributions by the University of California:
5+
# Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
6+
# All rights reserved.
7+
#
8+
# All other contributions:
9+
# Copyright (c) 2014, 2015, the respective contributors
10+
# All rights reserved.
11+
# For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md
12+
#
13+
#
14+
# Redistribution and use in source and binary forms, with or without
15+
# modification, are permitted provided that the following conditions are met:
16+
#
17+
# * Redistributions of source code must retain the above copyright notice,
18+
# this list of conditions and the following disclaimer.
19+
# * Redistributions in binary form must reproduce the above copyright
20+
# notice, this list of conditions and the following disclaimer in the
21+
# documentation and/or other materials provided with the distribution.
22+
# * Neither the name of Intel Corporation nor the names of its contributors
23+
# may be used to endorse or promote products derived from this software
24+
# without specific prior written permission.
25+
#
26+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
29+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
30+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
31+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
32+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
33+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
34+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
35+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36+
#
37+
import caffe
38+
import google.protobuf.text_format as txtf
39+
from caffe.proto import caffe_pb2
40+
41+
def update_conv_quantized_dict(conv_quantized_dict, tmp_conv_quantized_dict):
42+
for conv in conv_quantized_dict:
43+
if tmp_conv_quantized_dict[conv][0] > conv_quantized_dict[conv][0]:
44+
conv_quantized_dict[conv][0] = tmp_conv_quantized_dict[conv][0]
45+
46+
if tmp_conv_quantized_dict[conv][1] > conv_quantized_dict[conv][1]:
47+
conv_quantized_dict[conv][1] = tmp_conv_quantized_dict[conv][1]
48+
49+
50+
def create_quantized_net(raw_net, quantized_net, conv_quantized_dict):
51+
net_param = caffe_pb2.NetParameter()
52+
with open(raw_net) as f:
53+
txtf.Merge(f.read(), net_param)
54+
#skip first conv layer when quantizing net
55+
first_conv = True
56+
for layer_param in net_param.layer:
57+
if layer_param.type == "Convolution":
58+
if first_conv:
59+
first_conv = False
60+
continue
61+
layer_param.quantization_param.bw_layer_in = 8
62+
layer_param.quantization_param.bw_layer_out = 8
63+
layer_param.quantization_param.bw_params = 8
64+
layer_param.quantization_param.scale_in.append(conv_quantized_dict[layer_param.name][0])
65+
layer_param.quantization_param.scale_out.append(conv_quantized_dict[layer_param.name][1])
66+
for param_scale in conv_quantized_dict[layer_param.name][2]:
67+
layer_param.quantization_param.scale_params.append(param_scale)
68+
with open(quantized_net, 'w') as f:
69+
f.write(str(net_param))

examples/faster-rcnn/lib/fast_rcnn/test.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,20 @@
1010
from fast_rcnn.config import cfg, get_output_dir
1111
from fast_rcnn.bbox_transform import clip_boxes, bbox_transform_inv
1212
import argparse
13+
from fast_rcnn import net_sample_utils
1314
from utils.timer import Timer
1415
import numpy as np
1516
import cv2
1617
import caffe
18+
import os
19+
import sys
20+
utilpath = os.path.split(os.path.realpath(__file__))[0] + "/../../../../scripts/"
21+
print utilpath
22+
sys.path.insert(0, utilpath)
23+
import sampling
1724
from fast_rcnn.nms_wrapper import nms
1825
import cPickle
1926
from utils.blob import im_list_to_blob
20-
import os
2127

2228
def _get_image_blob(im):
2329
"""Converts an image into a network input.
@@ -105,7 +111,7 @@ def _get_blobs(im, rois):
105111
blobs['rois'] = _get_rois_blob(rois, im_scale_factors)
106112
return blobs, im_scale_factors
107113

108-
def im_detect(net, im, boxes=None):
114+
def im_detect(net, im, boxes=None, sample_blobs = None):
109115
"""Detect object classes in an image given object proposals.
110116
111117
Arguments:
@@ -152,7 +158,19 @@ def im_detect(net, im, boxes=None):
152158
else:
153159
forward_kwargs['rois'] = blobs['rois'].astype(np.float32, copy=False)
154160
blobs_out = net.forward(**forward_kwargs)
155-
161+
162+
163+
164+
if sample_blobs != None:
165+
for k, _ in net.blobs.items(): # top blob
166+
output = np.array(net.blobs[k].data)
167+
if k not in sample_blobs.keys():
168+
sample_blobs[k] = [output]
169+
else:
170+
new_outputs = sample_blobs[k]
171+
#print "new_outputs type: " + str(type(new_outputs))
172+
new_outputs.append(output)
173+
sample_blobs[k] = new_outputs
156174
if cfg.TEST.HAS_RPN:
157175
assert len(im_scales) == 1, "Only single-image batch implemented"
158176
rois = net.blobs['rois'].data.copy()
@@ -224,6 +242,30 @@ def apply_nms(all_boxes, thresh):
224242
nms_boxes[cls_ind][im_ind] = dets[keep, :].copy()
225243
return nms_boxes
226244

245+
246+
def sample_net(raw_net_prototxt, net, imdb, sampling_iterations, quant_mode, enable_first_conv = False, winograd_algo = False):
247+
248+
(conv_layers, test_net_top_names, test_net_bottom_names, conv_top_blob_layer_map, conv_bottom_blob_layer_map) = sampling.get_blob_map(net, enable_first_conv)
249+
#currently out sample_net only supports TEST with HAS_RPN and AGNOSTIC flag
250+
num_images = len(imdb.image_index)
251+
image_index = 0
252+
box_proposals = None
253+
sample_blobs = {}
254+
for iter_ in xrange(sampling_iterations):
255+
im = cv2.imread(imdb.image_path_at(image_index))
256+
im_detect(net, im, box_proposals, sample_blobs)
257+
image_index += 1
258+
259+
params = {}
260+
for k, _ in net.params.items():
261+
if k not in conv_layers:
262+
continue
263+
param = np.abs(net.params[k][0].data) # ignore bias
264+
params[k] = [param]
265+
266+
(winograd_bottoms, winograd_convolutions) = sampling.get_winograd_info(raw_net_prototxt, conv_bottom_blob_layer_map, winograd_algo)
267+
return (sample_blobs, params, test_net_top_names, test_net_bottom_names, conv_top_blob_layer_map, conv_bottom_blob_layer_map, winograd_bottoms, winograd_convolutions)
268+
227269
def test_net(net, imdb, max_per_image=100, thresh=0.05, vis=False):
228270
"""Test a Fast R-CNN network on an image database."""
229271
num_images = len(imdb.image_index)

examples/faster-rcnn/tools/test_net.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,19 @@
1111

1212
import _init_paths
1313
from fast_rcnn.test import test_net
14+
from fast_rcnn.test import sample_net
15+
from fast_rcnn import net_sample_utils
1416
from fast_rcnn.config import cfg, cfg_from_file, cfg_from_list
1517
from datasets.factory import get_imdb
1618
import caffe
19+
20+
import time, os, sys
21+
utilpath = os.path.split(os.path.realpath(__file__))[0] + "/../../../scripts/"
22+
sys.path.insert(0, utilpath)
23+
import calibrator
24+
import sampling
1725
import argparse
1826
import pprint
19-
import time, os, sys
2027

2128
def parse_args():
2229
"""
@@ -49,7 +56,27 @@ def parse_args():
4956
parser.add_argument('--num_dets', dest='max_per_image',
5057
help='max number of detections per image',
5158
default=100, type=int)
59+
parser.add_argument('--quantized_net', action='store', dest='quantized_prototxt',
60+
default=None, type=str)
61+
parser.add_argument('--sample_iters', action='store', dest='sample_iters',
62+
default=100, type=int)
63+
parser.add_argument('--quant_mode', action='store', dest='quant_mode',
64+
default='single', type=str)
65+
66+
parser.add_argument('-u', '--unsigned_range', dest='unsigned_range', action="store_true", default=False,
67+
help='to quantize using unsigned range for activation')
5268

69+
parser.add_argument('-t', '--concat_use_fp32', dest='concat_use_fp32', action="store_true", default=False,
70+
help='to use fp32 for concat')
71+
72+
parser.add_argument('-f', '--unify_concat_scales', dest='unify_concat_scales', action="store_true", default=False,
73+
help='to unify concat scales')
74+
75+
parser.add_argument('-a', '--calibration_algos', dest='calibration_algos', action='store', default="DIRECT",
76+
help='to choose the calibration alogorithm')
77+
78+
parser.add_argument('-wi', '--conv_algo', dest='conv_algo', action="store_true", default=False,
79+
help='to choose the convolution algorithm')
5380
if len(sys.argv) == 1:
5481
parser.print_help()
5582
sys.exit(1)
@@ -87,4 +114,17 @@ def parse_args():
87114
if not cfg.TEST.HAS_RPN:
88115
imdb.set_proposal_method(cfg.TEST.PROPOSAL_METHOD)
89116

90-
test_net(net, imdb, max_per_image=args.max_per_image, vis=args.vis)
117+
if args.quantized_prototxt == None:
118+
test_net(net, imdb, max_per_image=args.max_per_image, vis=args.vis)
119+
else:
120+
(blobs, params, top_blobs_map, bottom_blobs_map, conv_top_blob_layer_map, conv_bottom_blob_layer_map, winograd_bottoms, winograd_convolutions) = sample_net(args.prototxt, net, imdb, args.sample_iters, args.quant_mode)
121+
122+
(inputs_max, outputs_max, inputs_min) = sampling.calibrate_activations(blobs, conv_top_blob_layer_map, conv_bottom_blob_layer_map, winograd_bottoms, args.calibration_algos, "SINGLE", args.conv_algo)
123+
params_max = sampling.calibrate_parameters(params, winograd_convolutions, "DIRECT", args.quant_mode.upper(), args.conv_algo)
124+
calibrator.generate_sample_impl(args.prototxt, args.quantized_prototxt, inputs_max, outputs_max, inputs_min, params_max, False)
125+
compiled_net_str = caffe.compile_net(args.prototxt, caffe.TEST, "MKLDNN")
126+
raw_net_basename = os.path.basename(args.prototxt)
127+
compile_net_path = "./compiled_" + raw_net_basename
128+
with open(compile_net_path, "w") as f:
129+
f.write(compiled_net_str)
130+
calibrator.transform_convolutions(args.quantized_prototxt, compile_net_path, top_blobs_map, bottom_blobs_map, args.unsigned_range, args.concat_use_fp32, args.unify_concat_scales, args.conv_algo, False)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
#
2+
# All modification made by Intel Corporation: Copyright (c) 2016 Intel Corporation
3+
#
4+
# All contributions by the University of California:
5+
# Copyright (c) 2014, 2015, The Regents of the University of California (Regents)
6+
# All rights reserved.
7+
#
8+
# All other contributions:
9+
# Copyright (c) 2014, 2015, the respective contributors
10+
# All rights reserved.
11+
# For the list of contributors go to https://github.com/BVLC/caffe/blob/master/CONTRIBUTORS.md
12+
#
13+
#
14+
# Redistribution and use in source and binary forms, with or without
15+
# modification, are permitted provided that the following conditions are met:
16+
#
17+
# * Redistributions of source code must retain the above copyright notice,
18+
# this list of conditions and the following disclaimer.
19+
# * Redistributions in binary form must reproduce the above copyright
20+
# notice, this list of conditions and the following disclaimer in the
21+
# documentation and/or other materials provided with the distribution.
22+
# * Neither the name of Intel Corporation nor the names of its contributors
23+
# may be used to endorse or promote products derived from this software
24+
# without specific prior written permission.
25+
#
26+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
27+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
28+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
29+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
30+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
31+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
32+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
33+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
34+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
35+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
36+
#
37+
import caffe
38+
import google.protobuf.text_format as txtf
39+
from caffe.proto import caffe_pb2
40+
41+
def update_conv_quantized_dict(conv_quantized_dict, tmp_conv_quantized_dict):
42+
for conv in conv_quantized_dict:
43+
if tmp_conv_quantized_dict[conv][0] > conv_quantized_dict[conv][0]:
44+
conv_quantized_dict[conv][0] = tmp_conv_quantized_dict[conv][0]
45+
46+
if tmp_conv_quantized_dict[conv][1] > conv_quantized_dict[conv][1]:
47+
conv_quantized_dict[conv][1] = tmp_conv_quantized_dict[conv][1]
48+
49+
50+
def create_quantized_net(raw_net, quantized_net, conv_quantized_dict):
51+
net_param = caffe_pb2.NetParameter()
52+
with open(raw_net) as f:
53+
txtf.Merge(f.read(), net_param)
54+
#skip first conv layer when quantizing net
55+
first_conv = True
56+
for layer_param in net_param.layer:
57+
if layer_param.type == "Convolution":
58+
if first_conv:
59+
first_conv = False
60+
continue
61+
layer_param.quantization_param.bw_layer_in = 8
62+
layer_param.quantization_param.bw_layer_out = 8
63+
layer_param.quantization_param.bw_params = 8
64+
layer_param.quantization_param.scale_in.append(conv_quantized_dict[layer_param.name][0])
65+
layer_param.quantization_param.scale_out.append(conv_quantized_dict[layer_param.name][1])
66+
for param_scale in conv_quantized_dict[layer_param.name][2]:
67+
layer_param.quantization_param.scale_params.append(param_scale)
68+
with open(quantized_net, 'w') as f:
69+
f.write(str(net_param))

examples/rfcn/lib/fast_rcnn/test.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,21 @@
1010
from fast_rcnn.config import cfg, get_output_dir
1111
from fast_rcnn.bbox_transform import clip_boxes, bbox_transform_inv
1212
import argparse
13+
from fast_rcnn import net_sample_utils
1314
from utils.timer import Timer
1415
import numpy as np
1516
import cv2
1617
import caffe
18+
import os
19+
import sys
20+
utilpath = os.path.split(os.path.realpath(__file__))[0] + "/../../../../scripts/"
21+
print utilpath
22+
sys.path.insert(0, utilpath)
23+
import sampling
1724
from fast_rcnn.nms_wrapper import nms
1825
import cPickle
1926
import gzip
2027
from utils.blob import im_list_to_blob
21-
import os
2228

2329
def _get_image_blob(im):
2430
"""Converts an image into a network input.
@@ -106,7 +112,7 @@ def _get_blobs(im, rois):
106112
blobs['rois'] = _get_rois_blob(rois, im_scale_factors)
107113
return blobs, im_scale_factors
108114

109-
def im_detect(net, im, boxes=None):
115+
def im_detect(net, im, boxes=None, sample_blobs = None):
110116
"""Detect object classes in an image given object proposals.
111117
112118
Arguments:
@@ -154,6 +160,16 @@ def im_detect(net, im, boxes=None):
154160
forward_kwargs['rois'] = blobs['rois'].astype(np.float32, copy=False)
155161
blobs_out = net.forward(**forward_kwargs)
156162

163+
if sample_blobs != None:
164+
for k, _ in net.blobs.items(): # top blob
165+
output = np.array(net.blobs[k].data)
166+
if k not in sample_blobs.keys():
167+
sample_blobs[k] = [output]
168+
else:
169+
new_outputs = sample_blobs[k]
170+
#print "new_outputs type: " + str(type(new_outputs))
171+
new_outputs.append(output)
172+
sample_blobs[k] = new_outputs
157173
if cfg.TEST.HAS_RPN:
158174
assert len(im_scales) == 1, "Only single-image batch implemented"
159175
rois = net.blobs['rois'].data.copy()
@@ -225,6 +241,28 @@ def apply_nms(all_boxes, thresh):
225241
nms_boxes[cls_ind][im_ind] = dets[keep, :].copy()
226242
return nms_boxes
227243

244+
def sample_net(raw_net_prototxt, net, imdb, sampling_iterations, quant_mode, enable_first_conv = False, winograd_algo = False):
245+
246+
(conv_layers, test_net_top_names, test_net_bottom_names, conv_top_blob_layer_map, conv_bottom_blob_layer_map) = sampling.get_blob_map(net, enable_first_conv)
247+
#currently out sample_net only supports TEST with HAS_RPN and AGNOSTIC flag
248+
num_images = len(imdb.image_index)
249+
image_index = 0
250+
box_proposals = None
251+
sample_blobs = {}
252+
for iter_ in xrange(sampling_iterations):
253+
im = cv2.imread(imdb.image_path_at(image_index))
254+
im_detect(net, im, box_proposals, sample_blobs)
255+
image_index += 1
256+
257+
params = {}
258+
for k, _ in net.params.items():
259+
if k not in conv_layers:
260+
continue
261+
param = np.abs(net.params[k][0].data) # ignore bias
262+
params[k] = [param]
263+
264+
(winograd_bottoms, winograd_convolutions) = sampling.get_winograd_info(raw_net_prototxt, conv_bottom_blob_layer_map, winograd_algo)
265+
return (sample_blobs, params, test_net_top_names, test_net_bottom_names, conv_top_blob_layer_map, conv_bottom_blob_layer_map, winograd_bottoms, winograd_convolutions)
228266
def test_net(net, imdb, max_per_image=400, thresh=-np.inf, vis=False):
229267
"""Test a Fast R-CNN network on an image database."""
230268
num_images = len(imdb.image_index)

0 commit comments

Comments
 (0)