1
+ """
2
+
3
+ Examples on how to quantize with PPQ
4
+
5
+ """
6
+ from typing import Iterable
7
+
8
+ from loguru import logger
9
+ import torch
10
+ from torch .utils .data import DataLoader
11
+ from ppq import BaseGraph , QuantizationSettingFactory , TargetPlatform
12
+ from ppq import graphwise_error_analyse , layerwise_error_analyse
13
+ from ppq .api import (
14
+ export_ppq_graph ,
15
+ quantize_onnx_model
16
+ )
17
+ import sys
18
+ from torchvision import transforms
19
+ import torchvision
20
+ import torch
21
+ from atomquant .onnx .dataloader import get_calib_dataloader_coco
22
+ import os
23
+ import cv2
24
+ import numpy as np
25
+ import onnxruntime as ort
26
+ from torchvision .datasets .coco import CocoDetection
27
+ from alfred .dl .torch .common import device
28
+
29
+
30
def preprocess_func(img, target):
    """Resize one COCO sample to the fixed 640x640 network input size.

    Returns a float32 image tensor (HWC layout, no normalization) together
    with a tensor holding the raw COCO ``bbox`` entries of that image.
    """
    # Fixed network input resolution (width == height == 640).
    side = 640
    resized = cv2.resize(img, (side, side))
    image_tensor = torch.as_tensor(np.asarray(resized, dtype=np.float32))
    # Gather the raw [x, y, w, h] box of every annotation record.
    box_tensor = torch.as_tensor(np.array([t["bbox"] for t in target]))
    return image_tensor, box_tensor
42
+
43
+
44
def collate_fn(batch):
    """Collate COCO calibration samples into a single image batch.

    Parameters
    ----------
    batch : sequence of ``(image, target)`` pairs as produced by
        ``preprocess_func``.

    Returns
    -------
    A stacked ``torch.Tensor`` of images when the samples are tensors,
    otherwise a ``numpy.ndarray``.  Targets are intentionally dropped:
    PTQ calibration only needs the network inputs.
    """
    images, _targets = zip(*batch)
    if isinstance(images[0], torch.Tensor):
        # Bugfix: the original also ran ``torch.stack(targets)`` — its
        # result was discarded, and it raises RuntimeError whenever the
        # per-image box counts differ (the common case in detection).
        return torch.stack(images)
    return np.array(images)
52
+
53
+
54
if __name__ == "__main__":
    # Path to the FP32 ONNX model to quantize, taken from the command line.
    ONNX_PATH = sys.argv[1]

    # COCO val2017 images + annotation file used as the calibration set.
    coco_root = os.path.expanduser("~/data/coco/images/val2017")
    anno_f = os.path.expanduser(
        "~/data/coco/annotations/instances_val2017_val_val_train.json"
    )

    # coco_ds = CocoDetection(coco_root, anno_f, )

    # Open the model with onnxruntime only to discover its input name,
    # which the calibration dataloader keys its feed dict on.
    session = ort.InferenceSession(ONNX_PATH)
    input_name = session.get_inputs()[0].name

    # Build the calibration dataloader: single-image batches, at most 50 steps.
    calib_dataloader = get_calib_dataloader_coco(
        coco_root,
        anno_f,
        preprocess_func=preprocess_func,
        input_names=input_name,
        bs=1,
        max_step=50,
        collate_fn=collate_fn
    )

    REQUIRE_ANALYSE = False        # set True to run the SNR analyses below
    BATCHSIZE = 1
    # INPUT_SHAPE = [3, 224, 224]
    # NOTE(review): channels-last (HWC) input shape — assumes the exported
    # model takes NHWC input; verify against the ONNX graph.
    INPUT_SHAPE = [640, 640, 3]
    DEVICE = "cuda"                # device used for the quantization passes
    PLATFORM = (TargetPlatform.ORT_OOS_INT8)
    EXECUTING_DEVICE = "cpu"  # 'cuda' or 'cpu'.

    # Create a setting for quantizing the network.
    # quant_setting = QuantizationSettingFactory.pplcuda_setting()
    quant_setting = QuantizationSettingFactory.default_setting()
    quant_setting.equalization = True  # use the layerwise equalization algorithm.
    quant_setting.dispatcher = (
        "conservative"  # dispatch this network in a conservative way.
    )

    # Quantize the model with the calibration data prepared above.
    quantized = quantize_onnx_model(
        onnx_import_file=ONNX_PATH,
        calib_dataloader=calib_dataloader.dataloader_holder,
        calib_steps=120,
        input_shape=[BATCHSIZE] + INPUT_SHAPE,
        setting=quant_setting,
        # collate_fn=collate_fn,
        platform=PLATFORM,
        device=DEVICE,
        verbose=0,
    )

    # Quantization result is a PPQ BaseGraph instance.
    assert isinstance(quantized, BaseGraph)

    try:
        if REQUIRE_ANALYSE:
            # Graph-wise error: cumulative SNR through the network; the final
            # layer should stay below 0.1 to preserve accuracy.
            print("正计算网络量化误差(SNR),最后一层的误差应小于 0.1 以保证量化精度:")
            reports = graphwise_error_analyse(
                graph=quantized,
                running_device=EXECUTING_DEVICE,
                steps=32,
                dataloader=calib_dataloader.dataloader_holder,
                collate_fn=lambda x: x.to(EXECUTING_DEVICE),
            )
            for op, snr in reports.items():
                if snr > 0.1:
                    logger.warning(f"层 {op} 的累计量化误差显著,请考虑进行优化")
            # Layer-wise error: each layer's independent SNR should be < 0.1.
            print("正计算逐层量化误差(SNR),每一层的独立量化误差应小于 0.1 以保证量化精度:")
            layerwise_error_analyse(
                graph=quantized,
                running_device=EXECUTING_DEVICE,
                interested_outputs=None,
                dataloader=calib_dataloader.dataloader_holder,
                collate_fn=lambda x: x.to(EXECUTING_DEVICE),
            )
    except Exception as e:
        # Analysis is best-effort; a failure here must not abort the export.
        logger.warning('analyse got some error, but that is OK, pass it.')

    # EXPORT_TARGET = TargetPlatform.ORT_OOS_INT8
    # NOTE(review): the graph was quantized for ORT_OOS_INT8 above but is
    # exported as TRT_INT8 — confirm this platform mismatch is intended.
    EXPORT_TARGET = TargetPlatform.TRT_INT8
    os.makedirs('Output/', exist_ok=True)
    # Export the quantized graph plus its quantization config.
    export_ppq_graph(
        graph=quantized,
        platform=EXPORT_TARGET,
        graph_save_to=f"Output/quantized_{EXPORT_TARGET}.onnx",
        config_save_to=f"Output/quantized_{EXPORT_TARGET}.json",
    )
148
+
149
+
0 commit comments