|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | + |
| 3 | +# ATTENTION: The code in this file is highly EXPERIMENTAL. |
| 4 | +# Adventurous users should note that the APIs will probably change. |
| 5 | + |
| 6 | +"""onnx optimizer |
| 7 | +
|
| 8 | +This enables users to optimize their models. |
| 9 | +""" |
| 10 | + |
| 11 | +import onnx |
| 12 | +import onnx.checker |
| 13 | +import argparse |
| 14 | +import sys |
| 15 | +import onnxoptimizer |
| 16 | +import pathlib |
| 17 | + |
| 18 | +usage = 'python -m onnxoptimizer input_model.onnx output_model.onnx ' |
| 19 | + |
| 20 | + |
| 21 | +def format_argv(argv): |
| 22 | + argv_ = argv[1:] |
| 23 | + if len(argv_) == 1: |
| 24 | + return argv_ |
| 25 | + elif len(argv_) >= 2: |
| 26 | + return argv_[2:] |
| 27 | + else: |
| 28 | + print('please check arguments!') |
| 29 | + sys.exit(1) |
| 30 | + |
| 31 | + |
| 32 | +def main(): |
| 33 | + parser = argparse.ArgumentParser( |
| 34 | + prog='onnxoptimizer', |
| 35 | + usage=usage, |
| 36 | + description='onnxoptimizer command-line api') |
| 37 | + parser.add_argument('--print_all_passes', action='store_true', default=False, help='print all available passes') |
| 38 | + parser.add_argument('--print_fuse_elimination_passes', action='store_true', default=False, help='print all fuse and elimination passes') |
| 39 | + parser.add_argument('-p', '--passes', nargs='*', default=None, help='list of optimization passes name, if no set, fuse_and_elimination_passes will be used') |
| 40 | + parser.add_argument('--fixed_point', action='store_true', default=False, help='fixed point') |
| 41 | + argv = sys.argv.copy() |
| 42 | + args = parser.parse_args(format_argv(sys.argv)) |
| 43 | + |
| 44 | + all_available_passes = onnxoptimizer.get_available_passes() |
| 45 | + fuse_and_elimination_passes = onnxoptimizer.get_fuse_and_elimination_passes() |
| 46 | + |
| 47 | + if args.print_all_passes: |
| 48 | + print(*all_available_passes) |
| 49 | + sys.exit(0) |
| 50 | + |
| 51 | + if args.print_fuse_elimination_passes: |
| 52 | + print(*fuse_and_elimination_passes) |
| 53 | + sys.exit(0) |
| 54 | + |
| 55 | + passes = [] |
| 56 | + if args.passes is None: |
| 57 | + passes = fuse_and_elimination_passes |
| 58 | + |
| 59 | + if len(argv[1:]) < 2: |
| 60 | + print('usage:{}'.format(usage)) |
| 61 | + print('please check arguments!') |
| 62 | + sys.exit(1) |
| 63 | + |
| 64 | + input_file = argv[1] |
| 65 | + output_file = argv[2] |
| 66 | + |
| 67 | + if not pathlib.Path(input_file).exists(): |
| 68 | + print("input file: {0} no exist!".format(input_file)) |
| 69 | + sys.exit(1) |
| 70 | + |
| 71 | + model = onnx.load(input_file) |
| 72 | + |
| 73 | + # when model size large than 2G bytes, onnx.checker.check_model(model) will fail. |
| 74 | + # we use onnx.check.check_model(input_file) as workaround |
| 75 | + onnx.checker.check_model(input_file) |
| 76 | + model = onnxoptimizer.optimize(model=model, passes=passes, fixed_point=args.fixed_point) |
| 77 | + if model is None: |
| 78 | + print('onnxoptimizer failed') |
| 79 | + sys.exit(1) |
| 80 | + try: |
| 81 | + onnx.save(proto=model, f=output_file) |
| 82 | + except: |
| 83 | + onnx.save(proto=model, f=output_file, save_as_external_data=True) |
| 84 | + onnx.checker.check_model(output_file) |
0 commit comments