@@ -89,6 +89,7 @@ class Detector(object):
89
89
calibration, trt_calib_mode need to set True
90
90
cpu_threads (int): cpu threads
91
91
enable_mkldnn (bool): whether to open MKLDNN
92
+ enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
92
93
output_dir (str): The path of output
93
94
threshold (float): The threshold of score for visualization
94
95
"""
@@ -105,6 +106,7 @@ def __init__(
105
106
trt_calib_mode = False ,
106
107
cpu_threads = 1 ,
107
108
enable_mkldnn = False ,
109
+ enable_mkldnn_bfloat16 = False ,
108
110
output_dir = 'output' ,
109
111
threshold = 0.5 , ):
110
112
self .pred_config = self .set_config (model_dir )
@@ -120,7 +122,8 @@ def __init__(
120
122
trt_opt_shape = trt_opt_shape ,
121
123
trt_calib_mode = trt_calib_mode ,
122
124
cpu_threads = cpu_threads ,
123
- enable_mkldnn = enable_mkldnn )
125
+ enable_mkldnn = enable_mkldnn ,
126
+ enable_mkldnn_bfloat16 = enable_mkldnn_bfloat16 )
124
127
self .det_times = Timer ()
125
128
self .cpu_mem , self .gpu_mem , self .gpu_util = 0 , 0 , 0
126
129
self .batch_size = batch_size
@@ -323,6 +326,7 @@ class DetectorSOLOv2(Detector):
323
326
calibration, trt_calib_mode need to set True
324
327
cpu_threads (int): cpu threads
325
328
enable_mkldnn (bool): whether to open MKLDNN
329
+ enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16
326
330
output_dir (str): The path of output
327
331
threshold (float): The threshold of score for visualization
328
332
@@ -340,6 +344,7 @@ def __init__(
340
344
trt_calib_mode = False ,
341
345
cpu_threads = 1 ,
342
346
enable_mkldnn = False ,
347
+ enable_mkldnn_bfloat16 = False ,
343
348
output_dir = './' ,
344
349
threshold = 0.5 , ):
345
350
super (DetectorSOLOv2 , self ).__init__ (
@@ -353,6 +358,7 @@ def __init__(
353
358
trt_calib_mode = trt_calib_mode ,
354
359
cpu_threads = cpu_threads ,
355
360
enable_mkldnn = enable_mkldnn ,
361
+ enable_mkldnn_bfloat16 = enable_mkldnn_bfloat16 ,
356
362
output_dir = output_dir ,
357
363
threshold = threshold , )
358
364
@@ -399,7 +405,8 @@ class DetectorPicoDet(Detector):
399
405
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
400
406
calibration, trt_calib_mode need to set True
401
407
cpu_threads (int): cpu threads
402
- enable_mkldnn (bool): whether to open MKLDNN
408
+ enable_mkldnn (bool): whether to turn on MKLDNN
409
+ enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
403
410
"""
404
411
405
412
def __init__ (
@@ -414,6 +421,7 @@ def __init__(
414
421
trt_calib_mode = False ,
415
422
cpu_threads = 1 ,
416
423
enable_mkldnn = False ,
424
+ enable_mkldnn_bfloat16 = False ,
417
425
output_dir = './' ,
418
426
threshold = 0.5 , ):
419
427
super (DetectorPicoDet , self ).__init__ (
@@ -427,6 +435,7 @@ def __init__(
427
435
trt_calib_mode = trt_calib_mode ,
428
436
cpu_threads = cpu_threads ,
429
437
enable_mkldnn = enable_mkldnn ,
438
+ enable_mkldnn_bfloat16 = enable_mkldnn_bfloat16 ,
430
439
output_dir = output_dir ,
431
440
threshold = threshold , )
432
441
@@ -571,7 +580,8 @@ def load_predictor(model_dir,
571
580
trt_opt_shape = 640 ,
572
581
trt_calib_mode = False ,
573
582
cpu_threads = 1 ,
574
- enable_mkldnn = False ):
583
+ enable_mkldnn = False ,
584
+ enable_mkldnn_bfloat16 = False ):
575
585
"""set AnalysisConfig, generate AnalysisPredictor
576
586
Args:
577
587
model_dir (str): root path of __model__ and __params__
@@ -611,6 +621,8 @@ def load_predictor(model_dir,
611
621
# cache 10 different shapes for mkldnn to avoid memory leak
612
622
config .set_mkldnn_cache_capacity (10 )
613
623
config .enable_mkldnn ()
624
+ if enable_mkldnn_bfloat16 :
625
+ config .enable_mkldnn_bfloat16 ()
614
626
except Exception as e :
615
627
print (
616
628
"The current environment does not support `mkldnn`, so disable mkldnn."
@@ -747,6 +759,7 @@ def main():
747
759
trt_calib_mode = FLAGS .trt_calib_mode ,
748
760
cpu_threads = FLAGS .cpu_threads ,
749
761
enable_mkldnn = FLAGS .enable_mkldnn ,
762
+ enable_mkldnn_bfloat16 = FLAGS .enable_mkldnn_bfloat16 ,
750
763
threshold = FLAGS .threshold ,
751
764
output_dir = FLAGS .output_dir )
752
765
@@ -781,4 +794,6 @@ def main():
781
794
], "device should be CPU, GPU or XPU"
782
795
assert not FLAGS .use_gpu , "use_gpu has been deprecated, please use --device"
783
796
797
+ assert not (FLAGS .enable_mkldnn == False and FLAGS .enable_mkldnn_bfloat16 == True ), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
798
+
784
799
main ()
0 commit comments