Skip to content

Commit c2cf8da

Browse files
authored
command line api (onnx#112)
* command line api Signed-off-by: haoshengqiang <[email protected]> * support positional args Signed-off-by: haoshengqiang <[email protected]> * update Signed-off-by: haoshengqiang <[email protected]> --------- Signed-off-by: haoshengqiang <[email protected]>
1 parent 8e6d2aa commit c2cf8da

File tree

5 files changed

+130
-3
lines changed

5 files changed

+130
-3
lines changed

README.md

+24-1
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,32 @@ pip3 install -e .
3838

3939
Note that you need to install protobuf before building from source.
4040

41+
42+
## Command-line API
43+
Now you can use command-line api in terminal instead of python script.
44+
45+
```
46+
python -m onnxoptimizer input_model.onnx output_model.onnx
47+
```
48+
49+
Arguments list is following:
50+
```
51+
# python3 -m onnxoptimizer -h
52+
usage: python -m onnxoptimizer input_model.onnx output_model.onnx
53+
54+
onnxoptimizer command-line api
55+
56+
optional arguments:
57+
-h, --help show this help message and exit
58+
--print_all_passes print all available passes
59+
--print_fuse_elimination_passes
60+
print all fuse and elimination passes
61+
-p [PASSES ...], --passes [PASSES ...]
62+
list of optimization passes name, if no set, fuse_and_elimination_passes will be used
63+
--fixed_point fixed point
64+
```
4165
## Roadmap
4266

43-
* Command-line API (e.g. `python3 -m onnxoptimizer model.onnx output.onnx`)
4467
* More built-in pass
4568
* Separate graph rewriting and constant folding (or a pure graph rewriting mode, see [issue #9](https://github.com/onnx/optimizer/issues/9) for the details)
4669

onnxoptimizer/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .version import version as __version__ # noqa
1414
from onnx import ModelProto
1515
from typing import Text, Sequence, Optional
16-
16+
from onnxoptimizer.onnxoptimizer_main import main
1717
import tempfile
1818
import os
1919

@@ -71,4 +71,4 @@ def optimize(model, passes=None, fixed_point=False): # type: (ModelProto, Optio
7171
os.remove(data_file_dest.name)
7272

7373

74-
__all__ = ['optimize', 'get_available_passes', 'get_fuse_and_elimination_passes']
74+
__all__ = ['optimize', 'get_available_passes', 'get_fuse_and_elimination_passes', 'main']

onnxoptimizer/__main__.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# ATTENTION: The code in this file is highly EXPERIMENTAL.
4+
# Adventurous users should note that the APIs will probably change.
5+
6+
"""onnx optimizer
7+
8+
This enables users to optimize their models.
9+
"""
10+
11+
from . import main
12+
13+
14+
if __name__ == '__main__':
15+
main()

onnxoptimizer/onnxoptimizer_main.py

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# ATTENTION: The code in this file is highly EXPERIMENTAL.
4+
# Adventurous users should note that the APIs will probably change.
5+
6+
"""onnx optimizer
7+
8+
This enables users to optimize their models.
9+
"""
10+
11+
import onnx
12+
import onnx.checker
13+
import argparse
14+
import sys
15+
import onnxoptimizer
16+
import pathlib
17+
18+
usage = 'python -m onnxoptimizer input_model.onnx output_model.onnx '
19+
20+
21+
def format_argv(argv):
22+
argv_ = argv[1:]
23+
if len(argv_) == 1:
24+
return argv_
25+
elif len(argv_) >= 2:
26+
return argv_[2:]
27+
else:
28+
print('please check arguments!')
29+
sys.exit(1)
30+
31+
32+
def main():
33+
parser = argparse.ArgumentParser(
34+
prog='onnxoptimizer',
35+
usage=usage,
36+
description='onnxoptimizer command-line api')
37+
parser.add_argument('--print_all_passes', action='store_true', default=False, help='print all available passes')
38+
parser.add_argument('--print_fuse_elimination_passes', action='store_true', default=False, help='print all fuse and elimination passes')
39+
parser.add_argument('-p', '--passes', nargs='*', default=None, help='list of optimization passes name, if no set, fuse_and_elimination_passes will be used')
40+
parser.add_argument('--fixed_point', action='store_true', default=False, help='fixed point')
41+
argv = sys.argv.copy()
42+
args = parser.parse_args(format_argv(sys.argv))
43+
44+
all_available_passes = onnxoptimizer.get_available_passes()
45+
fuse_and_elimination_passes = onnxoptimizer.get_fuse_and_elimination_passes()
46+
47+
if args.print_all_passes:
48+
print(*all_available_passes)
49+
sys.exit(0)
50+
51+
if args.print_fuse_elimination_passes:
52+
print(*fuse_and_elimination_passes)
53+
sys.exit(0)
54+
55+
passes = []
56+
if args.passes is None:
57+
passes = fuse_and_elimination_passes
58+
59+
if len(argv[1:]) < 2:
60+
print('usage:{}'.format(usage))
61+
print('please check arguments!')
62+
sys.exit(1)
63+
64+
input_file = argv[1]
65+
output_file = argv[2]
66+
67+
if not pathlib.Path(input_file).exists():
68+
print("input file: {0} no exist!".format(input_file))
69+
sys.exit(1)
70+
71+
model = onnx.load(input_file)
72+
73+
# when model size large than 2G bytes, onnx.checker.check_model(model) will fail.
74+
# we use onnx.check.check_model(input_file) as workaround
75+
onnx.checker.check_model(input_file)
76+
model = onnxoptimizer.optimize(model=model, passes=passes, fixed_point=args.fixed_point)
77+
if model is None:
78+
print('onnxoptimizer failed')
79+
sys.exit(1)
80+
try:
81+
onnx.save(proto=model, f=output_file)
82+
except:
83+
onnx.save(proto=model, f=output_file, save_as_external_data=True)
84+
onnx.checker.check_model(output_file)

setup.py

+5
Original file line numberDiff line numberDiff line change
@@ -335,4 +335,9 @@ def run(self):
335335
keywords='deep-learning ONNX',
336336
long_description=long_description,
337337
long_description_content_type='text/markdown',
338+
entry_points={
339+
'console_scripts': [
340+
'onnxoptimizer=onnxoptimizer:main',
341+
],
342+
},
338343
)

0 commit comments

Comments
 (0)