1 | 1 | # Convert a .pt file to an .onnx file
2 | 2 | import torch
3 | 3 | from mnist import Net
  | 4 | +from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process
4 | 5 | from pathlib import Path
5 | 6 |
  | 7 | +MODEL_PATH = Path("./mnist_cnn.pt")
  | 8 | +INTERMEDIATE_OUTPUT_DIR = Path(".")
  | 9 | +WEB_OUTPUT_DIR = Path("../web/public")
  | 10 | +ONNX_OUTPUT = "mnist_cnn.onnx"
  | 11 | +ONNX_QUANT_PREPROCESS_OUTPUT = "mnist_cnn.infer.onnx"
  | 12 | +ONNX_QUANT_OUTPUT = "mnist_cnn.quant.onnx"
  | 13 | +
  | 14 | +
  | 15 | +def quantize_onnx_model():
  | 16 | +    # Dynamic quantization, following the onnxruntime example:
  | 17 | +    # https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md
  | 18 | +    quant_pre_process(
  | 19 | +        input_model=INTERMEDIATE_OUTPUT_DIR / ONNX_OUTPUT,
  | 20 | +        output_model_path=INTERMEDIATE_OUTPUT_DIR / ONNX_QUANT_PREPROCESS_OUTPUT,
  | 21 | +    )
  | 22 | +    quantize_dynamic(
  | 23 | +        model_input=INTERMEDIATE_OUTPUT_DIR / ONNX_QUANT_PREPROCESS_OUTPUT,
  | 24 | +        model_output=WEB_OUTPUT_DIR / ONNX_QUANT_OUTPUT,
  | 25 | +        weight_type=QuantType.QUInt8,  # QUInt8 avoids this bug: https://github.com/microsoft/onnxruntime/issues/15888#issuecomment-1856864610
  | 26 | +    )
  | 27 | +
6 | 28 |
7 | 29 | def main():
8 |    | -    MODEL_PATH = Path("./mnist_cnn.pt")
9 | 30 |     mnist_model = Net()
10 | 31 |     mnist_model.load_state_dict(torch.load(MODEL_PATH))
11 | 32 |     mnist_model.eval()
12 | 33 |     dummy_input = torch.zeros(1, 1, 28, 28)  # one 1x28x28 grayscale image, as in MNIST
13 | 34 |     torch.onnx.export(
14 |    | -        mnist_model, dummy_input, "../web/public/mnist_cnn.onnx", verbose=True
   | 35 | +        mnist_model, dummy_input, INTERMEDIATE_OUTPUT_DIR / ONNX_OUTPUT, verbose=True
15 | 36 |     )
   | 37 | +    quantize_onnx_model()
16 | 38 |
17 | 39 |
18 | 40 | if __name__ == "__main__":
19 | 41 |     main()
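Because `torch.onnx.export` traces the model through the dummy input, a parity check between the PyTorch model and the exported float ONNX file is a cheap way to catch export bugs before quantization. A sketch under the same assumptions (paths as defined above, `onnxruntime` and `numpy` available):

```python
# Compare PyTorch and float ONNX outputs on the same input (sketch, not in the commit).
import numpy as np
import onnxruntime as ort
import torch

from mnist import Net

model = Net()
model.load_state_dict(torch.load("./mnist_cnn.pt"))
model.eval()

dummy = torch.zeros(1, 1, 28, 28)
with torch.no_grad():
    torch_out = model(dummy).numpy()

session = ort.InferenceSession("./mnist_cnn.onnx")
(onnx_out,) = session.run(None, {session.get_inputs()[0].name: dummy.numpy()})

# Allow for small numerical drift between the two runtimes.
print(np.allclose(torch_out, onnx_out, atol=1e-5))
```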