Skip to content

Commit 8fe111b

Browse files
authored
change threshold for ptq hpo (#1254)
1 parent c590123 commit 8fe111b

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

paddleslim/auto_compression/auto_strategy.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@
7777
MAGIC_SPARSE_RATIO = 0.75
7878
### TODO: 0.02 threshold maybe not suitable, need to check
7979
### NOTE: reduce magic data to choose quantization aware training.
80-
MAGIC_MAX_EMD_DISTANCE = 0.0002 #0.02
81-
MAGIC_MIN_EMD_DISTANCE = 0.0001 #0.01
80+
MAGIC_MAX_EMD_DISTANCE = 0.00002 #0.02
81+
MAGIC_MIN_EMD_DISTANCE = 0.00001 #0.01
8282

8383
DEFAULT_TRANSFORMER_STRATEGY = 'prune_0.25_int8'
8484
DEFAULT_STRATEGY = 'origin_int8'

paddleslim/quant/post_quant_hpo.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,12 @@ def standardization(data):
144144
"""standardization numpy array"""
145145
mu = np.mean(data, axis=0)
146146
sigma = np.std(data, axis=0)
147-
sigma = 1e-13 if sigma == 0. else sigma
147+
if isinstance(sigma, list) or isinstance(sigma, np.ndarray):
148+
for idx, sig in enumerate(sigma):
149+
if sig == 0.:
150+
sigma[idx] = 1e-13
151+
else:
152+
sigma = 1e-13 if sigma == 0. else sigma
148153
return (data - mu) / sigma
149154

150155

@@ -241,18 +246,15 @@ def eval_quant_model():
241246
if have_invalid_num(out_float) or have_invalid_num(out_quant):
242247
continue
243248

244-
try:
245-
out_float = standardization(out_float)
246-
out_quant = standardization(out_quant)
247-
except:
248-
continue
249-
out_float_list.append(out_float)
250-
out_quant_list.append(out_quant)
249+
out_float_list.append(list(out_float))
250+
out_quant_list.append(list(out_quant))
251251
valid_data_num += 1
252252

253253
if valid_data_num >= max_eval_data_num:
254254
break
255255

256+
out_float_list = standardization(out_float_list)
257+
out_quant_list = standardization(out_quant_list)
256258
emd_sum = cal_emd_lose(out_float_list, out_quant_list,
257259
out_len_sum / float(valid_data_num))
258260
_logger.info("output diff: {}".format(emd_sum))

0 commit comments

Comments
 (0)