@@ -144,7 +144,12 @@ def standardization(data):
144
144
"""standardization numpy array"""
145
145
mu = np .mean (data , axis = 0 )
146
146
sigma = np .std (data , axis = 0 )
147
- sigma = 1e-13 if sigma == 0. else sigma
147
+ if isinstance (sigma , list ) or isinstance (sigma , np .ndarray ):
148
+ for idx , sig in enumerate (sigma ):
149
+ if sig == 0. :
150
+ sigma [idx ] = 1e-13
151
+ else :
152
+ sigma = 1e-13 if sigma == 0. else sigma
148
153
return (data - mu ) / sigma
149
154
150
155
@@ -241,18 +246,15 @@ def eval_quant_model():
241
246
if have_invalid_num (out_float ) or have_invalid_num (out_quant ):
242
247
continue
243
248
244
- try :
245
- out_float = standardization (out_float )
246
- out_quant = standardization (out_quant )
247
- except :
248
- continue
249
- out_float_list .append (out_float )
250
- out_quant_list .append (out_quant )
249
+ out_float_list .append (list (out_float ))
250
+ out_quant_list .append (list (out_quant ))
251
251
valid_data_num += 1
252
252
253
253
if valid_data_num >= max_eval_data_num :
254
254
break
255
255
256
+ out_float_list = standardization (out_float_list )
257
+ out_quant_list = standardization (out_quant_list )
256
258
emd_sum = cal_emd_lose (out_float_list , out_quant_list ,
257
259
out_len_sum / float (valid_data_num ))
258
260
_logger .info ("output diff: {}" .format (emd_sum ))
0 commit comments