Commit 64e96e3

Support onnxruntime fp16 (#2269)
* support ort-fp16
* update configs
* update
* update reg ci
* fix mmrotate mmdet3d ort fp16
* fix dead links
1 parent 553f9b8 commit 64e96e3

29 files changed

+191
-16
lines changed

.github/workflows/regression-test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ jobs:
111111
apt update && apt install unzip
112112
python -V
113113
python -m pip install --upgrade pip
114-
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable
114+
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable onnxconverter-common
115115
python -m pip install opencv-python==4.5.4.60 opencv-python-headless==4.5.4.60 opencv-contrib-python==4.5.4.60
116116
python .github/scripts/prepare_reg_test.py --torch-version ${{ matrix.torch_version }} --codebases ${{ matrix.codebase}}
117117
python -m pip install -r requirements.txt
@@ -221,7 +221,7 @@ jobs:
221221
conda activate $env:TEMP_ENV
222222
python -V
223223
python -m pip install --upgrade pip
224-
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable
224+
python -m pip install openmim numpy pycuda xlsxwriter packaging prettytable onnxconverter-common
225225
python -m pip install opencv-python==4.5.4.60 opencv-python-headless==4.5.4.60 opencv-contrib-python==4.5.4.60
226226
python .github/scripts/prepare_reg_test.py --torch-version ${{ matrix.torch_version }} --codebases ${{ matrix.codebase}}
227227
python -m pip install -r requirements.txt
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
+backend_config = dict(
+    type='onnxruntime',
+    precision='fp16',
+    common_config=dict(
+        min_positive_val=1e-7,
+        max_finite_val=1e4,
+        keep_io_types=False,
+        disable_shape_infer=False,
+        op_block_list=None,
+        node_block_list=None))
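
A minimal sketch of what `min_positive_val` and `max_finite_val` plausibly control, assuming they act as clamps applied to finite, non-zero values before the cast to half precision. That reading comes from the parameter names and is not confirmed by this commit. The helper `to_fp16` is hypothetical and uses only the standard library's `struct` module, whose `'e'` format is IEEE 754 binary16.

```python
import math
import struct


def to_fp16(value: float,
            min_positive_val: float = 1e-7,
            max_finite_val: float = 1e4) -> float:
    """Clamp a value with the two thresholds, then round it to float16.

    Hypothetical helper: a simplified model of the clamping, not the
    converter's actual code.
    """
    if value != 0.0 and math.isfinite(value):
        sign = 1.0 if value > 0 else -1.0
        magnitude = abs(value)
        # Tiny non-zero magnitudes are raised so they do not flush to zero.
        magnitude = max(magnitude, min_positive_val)
        # Large magnitudes are capped so they do not overflow to infinity.
        magnitude = min(magnitude, max_finite_val)
        value = sign * magnitude
    # Round-trip through the binary16 struct format to apply fp16 rounding.
    return struct.unpack('<e', struct.pack('<e', value))[0]


if __name__ == '__main__':
    for sample in (0.0, 0.5, 1e-9, 1e5, -1e5):
        print(sample, '->', to_fp16(sample))
```

With the defaults from the config above, a weight of `1e5` would be stored as `10000.0` and a weight of `1e-9` would stay strictly positive instead of collapsing to zero.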
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
_base_ = [
2+
'./video-recognition_static.py',
3+
'../../_base_/backends/onnxruntime-fp16.py'
4+
]
5+
6+
onnx_config = dict(input_shape=None)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'./super-resolution_dynamic.py',
3+
'../../_base_/backends/onnxruntime-fp16.py'
4+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_base_ = [
2+
'../_base_/base_dynamic.py', '../../_base_/backends/onnxruntime-fp16.py'
3+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'../_base_/base_instance-seg_dynamic.py',
3+
'../../_base_/backends/onnxruntime-fp16.py'
4+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_base_ = [
2+
'./voxel-detection_dynamic.py', '../../_base_/backends/onnxruntime-fp16.py'
3+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_base_ = [
2+
'./text-detection_dynamic.py', '../../_base_/backends/onnxruntime-fp16.py'
3+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
_base_ = [
2+
'./text-recognition_dynamic.py',
3+
'../../_base_/backends/onnxruntime-fp16.py'
4+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_base_ = [
2+
'./pose-detection_static.py', '../_base_/backends/onnxruntime-fp16.py'
3+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
+_base_ = [
+    './pose-detection_static.py', '../_base_/backends/onnxruntime-fp16.py'
+]
+
+onnx_config = dict(
+    input_shape=[192, 256],
+    output_names=['simcc_x', 'simcc_y'],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+        },
+        'simcc_x': {
+            0: 'batch'
+        },
+        'simcc_y': {
+            0: 'batch'
+        }
+    })
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_base_ = [
2+
'./classification_dynamic.py', '../_base_/backends/onnxruntime-fp16.py'
3+
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
+_base_ = [
+    './rotated-detection_static.py', '../_base_/backends/onnxruntime-fp16.py'
+]
+
+onnx_config = dict(
+    output_names=['dets', 'labels'],
+    input_shape=[1024, 1024],
+    dynamic_axes={
+        'input': {
+            0: 'batch',
+            2: 'height',
+            3: 'width'
+        },
+        'dets': {
+            0: 'batch',
+            1: 'num_dets',
+        },
+        'labels': {
+            0: 'batch',
+            1: 'num_dets',
+        },
+    })
+
+backend_config = dict(
+    common_config=dict(op_block_list=['NMSRotated', 'Resize']))
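
The rotated-detection config above only sets `op_block_list` and relies on `_base_` inheritance for everything else. The sketch below is a simplified model of that behaviour, assuming nested dicts are merged key by key so the child only replaces the keys it names. `merge_config` is a hypothetical helper, not the config loader's API, and it ignores loader features such as `_delete_`.

```python
from copy import deepcopy


def merge_config(base: dict, override: dict) -> dict:
    """Recursively merge `override` into a copy of `base` (hypothetical)."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both sides are dicts: descend instead of replacing wholesale.
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Values copied from the base fp16 backend config in this commit.
base_backend = dict(
    type='onnxruntime',
    precision='fp16',
    common_config=dict(
        min_positive_val=1e-7,
        max_finite_val=1e4,
        keep_io_types=False,
        disable_shape_infer=False,
        op_block_list=None,
        node_block_list=None))

# Override from the rotated-detection config.
child_backend = dict(
    common_config=dict(op_block_list=['NMSRotated', 'Resize']))

merged = merge_config(base_backend, child_backend)

if __name__ == '__main__':
    print(merged)
```

Under this model the merged config keeps `precision='fp16'` and the two thresholds, while `NMSRotated` and `Resize` are the ops excluded from the half-precision conversion.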
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
_base_ = [
2+
'./segmentation_dynamic.py', '../_base_/backends/onnxruntime-fp16.py'
3+
]

docs/en/05-supported-backends/onnxruntime.md

+8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ pip install onnxruntime==1.8.1 # if you want to use cpu version
2222
pip install onnxruntime-gpu==1.8.1 # if you want to use gpu version
2323
```
2424

25+
### Install float16 conversion tool (optional)
26+
27+
If you want to use float16 precision, install the tool by running the following script:
28+
29+
```bash
30+
pip install onnx onnxconverter-common
31+
```
32+
2533
## Build custom ops
2634

2735
### Download ONNXRuntime Library

docs/en/05-supported-backends/openvino.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Notes:
8383

8484
- Custom operations from OpenVINO use the domain `org.openvinotoolkit`.
8585
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
86-
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
86+
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvino.ai/2022.3/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
8787
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.
8888

8989
## Deployment config

docs/zh_cn/05-supported-backends/onnxruntime.md

+8
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ pip install onnxruntime==1.8.1 # if you want to use cpu version
2222
pip install onnxruntime-gpu==1.8.1 # if you want to use gpu version
2323
```
2424

25+
### Install float16 conversion tool (optional)
26+
27+
If you want to use float16 precision, install the tool by running the following script:
28+
29+
```bash
30+
pip install onnx onnxconverter-common
31+
```
32+
2533
## Build custom ops
2634

2735
### Download ONNXRuntime Library

docs/zh_cn/05-supported-backends/openvino.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ Notes:
8383

8484
- Custom operations from OpenVINO use the domain `org.openvinotoolkit`.
8585
- For faster work in OpenVINO in the Faster-RCNN, Mask-RCNN, Cascade-RCNN, Cascade-Mask-RCNN models
86-
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvinotoolkit.org/latest/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
86+
the RoiAlign operation is replaced with the [ExperimentalDetectronROIFeatureExtractor](https://docs.openvino.ai/2022.3/openvino_docs_ops_detection_ExperimentalDetectronROIFeatureExtractor_6.html) operation in the ONNX graph.
8787
- Models "VFNet" and "Faster R-CNN + DCN" use the custom "DeformableConv2D" operation.
8888

8989
## Deployment config

mmdeploy/backend/onnxruntime/backend_manager.py

+14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
 import os.path as osp
 from typing import Any, Callable, Optional, Sequence

+from mmdeploy.utils import get_backend_config, get_common_config
 from ..base import BACKEND_MANAGERS, BaseBackendManager


@@ -125,6 +126,7 @@ def check_env(cls, log_callback: Callable = lambda _: _) -> str:
     def to_backend(cls,
                    ir_files: Sequence[str],
                    work_dir: str,
+                   deploy_cfg: Any,
                    log_level: int = logging.INFO,
                    device: str = 'cpu',
                    **kwargs) -> Sequence[str]:
@@ -134,9 +136,21 @@ def to_backend(cls,
             ir_files (Sequence[str]): The intermediate representation files.
             work_dir (str): The work directory, backend files and logs should
                 be saved in this directory.
+            deploy_cfg (Any): The deploy config.
             log_level (int, optional): The log level. Defaults to logging.INFO.
             device (str, optional): The device type. Defaults to 'cpu'.
         Returns:
             Sequence[str]: Backend files.
         """
+        backend_cfg = get_backend_config(deploy_cfg)
+
+        precision = backend_cfg.get('precision', 'fp32')
+        if precision == 'fp16':
+            import onnx
+            from onnxconverter_common import float16
+
+            common_cfg = get_common_config(deploy_cfg)
+            model = onnx.load(ir_files[0])
+            model_fp16 = float16.convert_float_to_float16(model, **common_cfg)
+            onnx.save(model_fp16, ir_files[0])
         return ir_files

mmdeploy/backend/onnxruntime/wrapper.py

+7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
 import os.path as osp
 from typing import Dict, Optional, Sequence

+import numpy as np
 import onnxruntime as ort
 import torch

@@ -58,6 +59,7 @@ def __init__(self,
         if output_names is None:
             output_names = [_.name for _ in sess.get_outputs()]
         self.sess = sess
+        self._input_metas = {_.name: _ for _ in sess.get_inputs()}
         self.io_binding = sess.io_binding()
         self.device_id = device_id
         self.device_type = 'cpu' if device == 'cpu' else 'cuda'
@@ -75,6 +77,9 @@ def forward(self, inputs: Dict[str,
         """
         for name, input_tensor in inputs.items():
             # set io binding for inputs/outputs
+            input_type = self._input_metas[name].type
+            if 'float16' in input_type:
+                input_tensor = input_tensor.to(torch.float16)
             input_tensor = input_tensor.contiguous()
             if self.device_type == 'cpu':
                 input_tensor = input_tensor.cpu()
@@ -98,6 +103,8 @@ def forward(self, inputs: Dict[str,
         output_list = self.io_binding.copy_outputs_to_cpu()
         outputs = {}
         for output_name, numpy_tensor in zip(self._output_names, output_list):
+            if numpy_tensor.dtype == np.float16:
+                numpy_tensor = numpy_tensor.astype(np.float32)
             outputs[output_name] = torch.from_numpy(numpy_tensor)

         return outputs
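
The wrapper change casts inputs down to float16 only when the session declares a float16 input, and casts float16 outputs back up to float32 so downstream code keeps seeing float32 tensors. The sketch below mirrors that boundary logic with NumPy arrays in place of torch tensors and plain type strings such as `'tensor(float16)'` in place of the session's input metadata. `cast_inputs` and `restore_outputs` are hypothetical names, not part of the wrapper.

```python
from typing import Dict

import numpy as np


def cast_inputs(inputs: Dict[str, np.ndarray],
                input_types: Dict[str, str]) -> Dict[str, np.ndarray]:
    """Cast each input to float16 when its declared type asks for it."""
    cast = {}
    for name, array in inputs.items():
        # Type strings are assumed to look like 'tensor(float16)'.
        if 'float16' in input_types[name]:
            array = array.astype(np.float16)
        cast[name] = np.ascontiguousarray(array)
    return cast


def restore_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Promote float16 outputs to float32 and leave other dtypes alone."""
    return {
        name: array.astype(np.float32) if array.dtype == np.float16 else array
        for name, array in outputs.items()
    }


if __name__ == '__main__':
    feeds = cast_inputs({'input': np.ones((1, 3), dtype=np.float32)},
                        {'input': 'tensor(float16)'})
    print(feeds['input'].dtype)
    results = restore_outputs({
        'dets': np.zeros((1, 5), dtype=np.float16),
        'labels': np.zeros((1, ), dtype=np.int64),
    })
    print({name: array.dtype for name, array in results.items()})
```

Integer outputs such as label indices pass through unchanged, which matches the dtype check in the diff: only `np.float16` arrays are converted.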

tests/regression/mmaction.yml

+5-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ onnxruntime:
2828
convert_image: *convert_image
2929
deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime_static.py
3030
backend_test: *default_backend_test
31+
pipeline_ort_static_fp16: &pipeline_ort_static_fp16
32+
convert_image: *convert_image
33+
deploy_config: configs/mmaction/video-recognition/video-recognition_onnxruntime-fp16_static.py
34+
backend_test: *default_backend_test
3135

3236
torchscript:
3337
pipeline_torchscript_fp32: &pipeline_torchscript_fp32
@@ -51,7 +55,7 @@ models:
5155
model_configs:
5256
- configs/recognition/tsn/tsn_imagenet-pretrained-r50_8xb32-1x1x3-100e_kinetics400-rgb.py
5357
pipelines:
54-
- *pipeline_ort_static_fp32
58+
- *pipeline_ort_static_fp16
5559
- *pipeline_trt_2d_static_fp32
5660
- *pipeline_torchscript_fp32
5761

tests/regression/mmagic.yml

+6-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,11 @@ onnxruntime:
4646
convert_image: *convert_image
4747
deploy_config: configs/mmagic/super-resolution/super-resolution_onnxruntime_dynamic.py
4848

49+
pipeline_ort_dynamic_fp16: &pipeline_ort_dynamic_fp16
50+
convert_image: *convert_image
51+
deploy_config: configs/mmagic/super-resolution/super-resolution_onnxruntime-fp16_dynamic.py
52+
53+
4954
tensorrt:
5055
pipeline_trt_static_fp32: &pipeline_trt_static_fp32
5156
convert_image: *convert_image
@@ -114,7 +119,7 @@ models:
114119
- configs/srcnn/srcnn_x4k915_1xb16-1000k_div2k.py
115120
pipelines:
116121
- *pipeline_ts_fp32
117-
- *pipeline_ort_dynamic_fp32
122+
- *pipeline_ort_dynamic_fp16
118123
# - *pipeline_trt_dynamic_fp32
119124
- *pipeline_trt_dynamic_fp16
120125
# - *pipeline_trt_dynamic_int8

tests/regression/mmdet.yml

+12-2
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ onnxruntime:
3838
backend_test: False
3939
deploy_config: configs/mmdet/detection/detection_onnxruntime_dynamic.py
4040

41+
pipeline_ort_dynamic_fp16: &pipeline_ort_dynamic_fp16
42+
convert_image: *convert_image
43+
backend_test: False
44+
deploy_config: configs/mmdet/detection/detection_onnxruntime-fp16_dynamic.py
45+
4146
pipeline_seg_ort_static_fp32: &pipeline_seg_ort_static_fp32
4247
convert_image: *convert_image
4348
backend_test: False
@@ -48,6 +53,11 @@ onnxruntime:
4853
backend_test: False
4954
deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime_dynamic.py
5055

56+
pipeline_seg_ort_dynamic_fp16: &pipeline_seg_ort_dynamic_fp16
57+
convert_image: *convert_image
58+
backend_test: False
59+
deploy_config: configs/mmdet/instance-seg/instance-seg_onnxruntime-fp16_dynamic.py
60+
5161
tensorrt:
5262
pipeline_trt_static_fp32: &pipeline_trt_static_fp32
5363
convert_image: *convert_image
@@ -203,7 +213,7 @@ models:
203213
- configs/retinanet/retinanet_r50_fpn_1x_coco.py
204214
pipelines:
205215
- *pipeline_ts_fp32
206-
- *pipeline_ort_dynamic_fp32
216+
- *pipeline_ort_dynamic_fp16
207217
- *pipeline_trt_dynamic_fp32
208218
- *pipeline_ncnn_static_fp32
209219
- *pipeline_pplnn_dynamic_fp32
@@ -323,7 +333,7 @@ models:
323333
- configs/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py
324334
pipelines:
325335
- *pipeline_seg_ts_fp32
326-
- *pipeline_seg_ort_dynamic_fp32
336+
- *pipeline_seg_ort_dynamic_fp16
327337
- *pipeline_seg_trt_dynamic_fp32
328338
- *pipeline_seg_openvino_dynamic_fp32
329339

tests/regression/mmdet3d.yml

+6-1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,11 @@ onnxruntime:
4242
backend_test: False
4343
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py
4444

45+
pipeline_ort_dynamic_kitti_fp16: &pipeline_ort_dynamic_kitti_fp16
46+
convert_image: *convert_image
47+
backend_test: False
48+
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime-fp16_dynamic.py
49+
4550
pipeline_ort_dynamic_nus_fp32: &pipeline_ort_dynamic_nus_fp32
4651
convert_image: *convert_image_nus
4752
backend_test: False
@@ -86,7 +91,7 @@ models:
8691
- configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py
8792
- configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-car.py
8893
pipelines:
89-
- *pipeline_ort_dynamic_kitti_fp32
94+
- *pipeline_ort_dynamic_kitti_fp16
9095
- *pipeline_openvino_dynamic_kitti_fp32
9196
- *pipeline_trt_dynamic_kitti_fp32
9297
- name: PointPillars

tests/regression/mmocr.yml

+11-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ onnxruntime:
4343
convert_image: *convert_image_det
4444
deploy_config: configs/mmocr/text-detection/text-detection_onnxruntime_dynamic.py
4545

46+
pipeline_ort_detection_dynamic_fp16: &pipeline_ort_detection_dynamic_fp16
47+
convert_image: *convert_image_det
48+
deploy_config: configs/mmocr/text-detection/text-detection_onnxruntime-fp16_dynamic.py
49+
4650
pipeline_ort_detection_mrcnn_dynamic_fp32: &pipeline_ort_detection_mrcnn_dynamic_fp32
4751
convert_image: *convert_image_det
4852
deploy_config: configs/mmocr/text-detection/text-detection_mrcnn_onnxruntime_dynamic.py
@@ -56,6 +60,11 @@ onnxruntime:
5660
convert_image: *convert_image_rec
5761
deploy_config: configs/mmocr/text-recognition/text-recognition_onnxruntime_dynamic.py
5862

63+
pipeline_ort_recognition_dynamic_fp16: &pipeline_ort_recognition_dynamic_fp16
64+
convert_image: *convert_image_rec
65+
deploy_config: configs/mmocr/text-recognition/text-recognition_onnxruntime-fp16_dynamic.py
66+
67+
5968
tensorrt:
6069
# ======= detection =======
6170
pipeline_trt_detection_static_fp32: &pipeline_trt_detection_static_fp32
@@ -239,7 +248,7 @@ models:
239248
- configs/textdet/dbnet/dbnet_resnet18_fpnc_1200e_icdar2015.py
240249
pipelines:
241250
- *pipeline_ts_detection_fp32
242-
- *pipeline_ort_detection_dynamic_fp32
251+
- *pipeline_ort_detection_dynamic_fp16
243252
- *pipeline_trt_detection_dynamic_fp16
244253
- *pipeline_ncnn_detection_static_fp32
245254
- *pipeline_pplnn_detection_dynamic_fp32
@@ -303,7 +312,7 @@ models:
303312
- configs/textrecog/crnn/crnn_mini-vgg_5e_mj.py
304313
pipelines:
305314
- *pipeline_ts_recognition_fp32
306-
- *pipeline_ort_recognition_dynamic_fp32
315+
- *pipeline_ort_recognition_dynamic_fp16
307316
- *pipeline_trt_recognition_dynamic_fp16_H32_C1
308317
- *pipeline_ncnn_recognition_static_fp32
309318
- *pipeline_pplnn_recognition_dynamic_fp32
