Skip to content

Commit 6514ba7

Browse files
author
Your Name
committed
add damn it ppq
1 parent e002b88 commit 6514ba7

File tree

2 files changed

+150
-0
lines changed

2 files changed

+150
-0
lines changed

Diff for: deploy/quant_atom/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Output/

Diff for: deploy/quant_atom/qt_ppq_sinst.py

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
"""
2+
3+
Examples on how to quantize with PPQ
4+
5+
"""
6+
from typing import Iterable
7+
8+
from loguru import logger
9+
import torch
10+
from torch.utils.data import DataLoader
11+
from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform
12+
from ppq import graphwise_error_analyse, layerwise_error_analyse
13+
from ppq.api import (
14+
export_ppq_graph,
15+
quantize_onnx_model
16+
)
17+
import sys
18+
from torchvision import transforms
19+
import torchvision
20+
import torch
21+
from atomquant.onnx.dataloader import get_calib_dataloader_coco
22+
import os
23+
import cv2
24+
import numpy as np
25+
import onnxruntime as ort
26+
from torchvision.datasets.coco import CocoDetection
27+
from alfred.dl.torch.common import device
28+
29+
30+
def preprocess_func(img, target):
    """Resize an image to 640x640 and convert image and boxes to tensors.

    Args:
        img: input image in a format accepted by ``cv2.resize``.
        target: iterable of COCO annotation dicts, each carrying a "bbox" key.

    Returns:
        Tuple of (float32 image tensor in HWC layout, bbox tensor).
    """
    side = 640
    resized = cv2.resize(img, (side, side))
    img_arr = np.array(resized).astype(np.float32)
    bboxes = np.array([ann["bbox"] for ann in target])
    img_t = torch.as_tensor(img_arr)
    bbox_t = torch.as_tensor(bboxes)
    return img_t, bbox_t
42+
43+
44+
def collate_fn(batch):
    """Collate (image, target) pairs into a single image batch.

    Only the images are returned — the calibration dataloader feeds images
    alone to the quantizer, so targets are dropped here.

    Args:
        batch: sequence of (image, target) tuples.

    Returns:
        A stacked ``torch.Tensor`` when the images are tensors, otherwise a
        ``np.ndarray`` built from the images.
    """
    images, _targets = zip(*batch)
    if isinstance(images[0], torch.Tensor):
        # BUGFIX: the original also did `targets = torch.stack(targets)` —
        # dead code (the result was never returned) that crashes whenever two
        # images in the batch have a different number of boxes. Removed.
        images = torch.stack(images)
    else:
        images = np.array(images)
    return images
52+
53+
54+
if __name__ == "__main__":
    # Path of the FP32 ONNX model to quantize — single CLI argument.
    ONNX_PATH = sys.argv[1]

    coco_root = os.path.expanduser("~/data/coco/images/val2017")
    anno_f = os.path.expanduser(
        "~/data/coco/annotations/instances_val2017_val_val_train.json"
    )

    # coco_ds = CocoDetection(coco_root, anno_f, )

    # Probe the model's input name so the calibration dataloader feeds
    # batches under the correct input key.
    session = ort.InferenceSession(ONNX_PATH)
    input_name = session.get_inputs()[0].name

    calib_dataloader = get_calib_dataloader_coco(
        coco_root,
        anno_f,
        preprocess_func=preprocess_func,
        input_names=input_name,
        bs=1,
        max_step=50,
        collate_fn=collate_fn,
    )

    REQUIRE_ANALYSE = False
    BATCHSIZE = 1
    # INPUT_SHAPE = [3, 224, 224]
    # NOTE(review): HWC shape (640, 640, 3), not the usual CHW — presumably
    # the exported model takes HWC input directly; confirm against the graph.
    INPUT_SHAPE = [640, 640, 3]
    DEVICE = "cuda"
    PLATFORM = TargetPlatform.ORT_OOS_INT8
    EXECUTING_DEVICE = "cpu"  # 'cuda' or 'cpu'.

    # Build the quantization setting.
    # quant_setting = QuantizationSettingFactory.pplcuda_setting()
    quant_setting = QuantizationSettingFactory.default_setting()
    quant_setting.equalization = True  # use layerwise equalization algorithm.
    # Dispatch this network in a conservative way.
    quant_setting.dispatcher = "conservative"

    # Quantize the model against the calibration data.
    quantized = quantize_onnx_model(
        onnx_import_file=ONNX_PATH,
        calib_dataloader=calib_dataloader.dataloader_holder,
        calib_steps=120,
        input_shape=[BATCHSIZE] + INPUT_SHAPE,
        setting=quant_setting,
        # collate_fn=collate_fn,
        platform=PLATFORM,
        device=DEVICE,
        verbose=0,
    )

    # Quantization result is a PPQ BaseGraph instance.
    assert isinstance(quantized, BaseGraph)

    try:
        if REQUIRE_ANALYSE:
            print("正计算网络量化误差(SNR),最后一层的误差应小于 0.1 以保证量化精度:")
            reports = graphwise_error_analyse(
                graph=quantized,
                running_device=EXECUTING_DEVICE,
                steps=32,
                dataloader=calib_dataloader.dataloader_holder,
                collate_fn=lambda x: x.to(EXECUTING_DEVICE),
            )
            for op, snr in reports.items():
                if snr > 0.1:
                    logger.warning(f"层 {op} 的累计量化误差显著,请考虑进行优化")
            print("正计算逐层量化误差(SNR),每一层的独立量化误差应小于 0.1 以保证量化精度:")
            layerwise_error_analyse(
                graph=quantized,
                running_device=EXECUTING_DEVICE,
                interested_outputs=None,
                dataloader=calib_dataloader.dataloader_holder,
                collate_fn=lambda x: x.to(EXECUTING_DEVICE),
            )
    except Exception as e:
        # Best-effort analysis: keep going, but record WHAT failed — the
        # original caught `e` and then discarded it entirely.
        logger.warning(f'analyse got some error, but that is OK, pass it: {e}')

    # EXPORT_TARGET = TargetPlatform.ORT_OOS_INT8
    # NOTE(review): the graph was quantized for ORT_OOS_INT8 above but is
    # exported for TRT_INT8 here — confirm this mismatch is intentional.
    EXPORT_TARGET = TargetPlatform.TRT_INT8
    os.makedirs('Output/', exist_ok=True)
    # Export the quantized graph together with its quantization config.
    export_ppq_graph(
        graph=quantized,
        platform=EXPORT_TARGET,
        graph_save_to=f"Output/quantized_{EXPORT_TARGET}.onnx",
        config_save_to=f"Output/quantized_{EXPORT_TARGET}.json",
    )
148+
149+

0 commit comments

Comments
 (0)