import os
import subprocess
import argparse

import paddle
from paddleslim.analysis import TableLatencyPredictor

from paddle.vision.models import mobilenet_v1, mobilenet_v2

# Select the opt tool binary that matches the host platform.
opt_tool = 'opt_ubuntu'  # use on Linux
# opt_tool = 'opt_M1_mac'  # use on macOS with an Apple M1 chip
# opt_tool = 'opt_intel_mac'  # use on macOS with an Intel chip

parser = argparse.ArgumentParser(description='latency predictor')
parser.add_argument('--model', type=str, help='which model to test.')
parser.add_argument('--data_type', type=str, default='fp32')

args = parser.parse_args()

# Download the opt tool if it is not already present, then make it executable.
if not os.path.exists(opt_tool):
    subprocess.call(
        f'wget https://paddle-slim-models.bj.bcebos.com/LatencyPredictor/{opt_tool}',
        shell=True)
    subprocess.call(f'chmod +x {opt_tool}', shell=True)


def get_latency(model, data_type):
    paddle.disable_static()
    predictor = TableLatencyPredictor(
        f'./{opt_tool}', hardware='845', threads=4, power_mode=3, batchsize=1)
    latency = predictor.predict_latency(
        model,
        input_shape=[1, 3, 224, 224],
        save_dir='./tmp_model',
        data_type=data_type,
        task_type='cls')
    print('{} latency : {}'.format(data_type, latency))

    # Remove the temporary model files produced by the predictor.
    subprocess.call('rm -rf ./tmp_model', shell=True)
    return latency


if __name__ == '__main__':
    if args.model == 'mobilenet_v1':
        model = mobilenet_v1()
    elif args.model == 'mobilenet_v2':
        model = mobilenet_v2()
    else:
        assert False, 'model should be mobilenet_v1 or mobilenet_v2'

    latency = get_latency(model, args.data_type)

    # Check the predicted latency against known reference values.
    if args.model == 'mobilenet_v1' and args.data_type == 'fp32':
        assert latency == 41.92806607483133
    elif args.model == 'mobilenet_v1' and args.data_type == 'int8':
        assert latency == 36.64814722993898
    elif args.model == 'mobilenet_v2' and args.data_type == 'fp32':
        assert latency == 27.847896889217566
    elif args.model == 'mobilenet_v2' and args.data_type == 'int8':
        assert latency == 23.967800360138803
    else:
        assert False, 'model or data_type wrong.'
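# Example invocations (a sketch; the filename 'latency_predictor_demo.py' is
# assumed for illustration and may differ from the actual script name):
#
#   python latency_predictor_demo.py --model mobilenet_v1 --data_type fp32
#   python latency_predictor_demo.py --model mobilenet_v2 --data_type int8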