Skip to content

Commit d6fe9ad

Browse files
feat: enable quantization on onnx mnist model (#2)
1 parent 3f88f84 commit d6fe9ad

File tree

6 files changed

+33
-3
lines changed

6 files changed

+33
-3
lines changed

README.md

+7
Original file line numberDiff line numberDiff line change
@@ -1 +1,8 @@
# ONNX MNIST on Web

Run MNIST model on web using ONNX runtime.

## How to Run

- [Pytorch README](./pytorch/README.md)
- [Web README](./web/README.md)

pytorch/.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ data/
__pycache__/
mnist_cnn.pt
mnist_cnn.onnx
mnist_cnn.infer.onnx

pytorch/export_to_onnx.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,40 @@
11
# Convert pt file to onnx file
22
import torch
33
from mnist import Net
4+
from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process
45
from pathlib import Path
56

7+
# Paths and filenames for the export + quantization pipeline.
MODEL_PATH = Path("./mnist_cnn.pt")  # trained PyTorch weights (loaded via load_state_dict below)
INTERMEDIATE_OUTPUT_DIR = Path(".")  # scratch dir for intermediate ONNX artifacts
WEB_OUTPUT_DIR = Path("../web/public")  # final model goes where the web app serves static files
ONNX_OUTPUT = "mnist_cnn.onnx"  # raw ONNX export of the model
ONNX_QUANT_PREPROCESS_OUTPUT = "mnist_cnn.infer.onnx"  # pre-processed model, ready for quantization
ONNX_QUANT_OUTPUT = "mnist_cnn.quant.onnx"  # final dynamically-quantized model
13+
14+
15+
def quantizate_onnx_model():
    """Dynamically quantize the exported ONNX model for use on the web.

    Runs the recommended pre-processing pass over the raw ONNX export,
    then writes a dynamically quantized copy into the web app's public dir.
    Reference:
    https://github.com/microsoft/onnxruntime-inference-examples/blob/main/quantization/image_classification/cpu/ReadMe.md
    """
    raw_model = INTERMEDIATE_OUTPUT_DIR / ONNX_OUTPUT
    preprocessed_model = INTERMEDIATE_OUTPUT_DIR / ONNX_QUANT_PREPROCESS_OUTPUT
    quantized_model = WEB_OUTPUT_DIR / ONNX_QUANT_OUTPUT

    # Pre-processing (shape inference / optimization) recommended before quantization.
    quant_pre_process(
        input_model=raw_model,
        output_model_path=preprocessed_model,
    )
    quantize_dynamic(
        model_input=preprocessed_model,
        model_output=quantized_model,
        # QUInt8 works around a QInt8 bug:
        # https://github.com/microsoft/onnxruntime/issues/15888#issuecomment-1856864610
        weight_type=QuantType.QUInt8,
    )
27+
628

729
def main():
    """Export the trained MNIST model to ONNX, then produce a quantized copy."""
    mnist_model = Net()
    mnist_model.load_state_dict(torch.load(MODEL_PATH))
    mnist_model.eval()  # inference mode: freezes dropout/batch-norm behavior for export
    # Fix: variable was misspelled `dymmy_input`.
    # Tracing input for the exporter: batch of 1, 1 channel, 28x28 pixels (MNIST).
    dummy_input = torch.zeros(1, 1, 28, 28)
    torch.onnx.export(
        mnist_model, dummy_input, INTERMEDIATE_OUTPUT_DIR / ONNX_OUTPUT, verbose=True
    )
    quantizate_onnx_model()
1638

1739

1840
if __name__ == "__main__":

web/public/mnist_cnn.onnx

-4.6 MB
Binary file not shown.

web/public/mnist_cnn.quant.onnx

1.16 MB
Binary file not shown.

web/src/utils/mnist.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ort.env.wasm.wasmPaths = './dist/'; // defined in vite.config.ts as viteStaticCo
55
export const MNIST_IMAGE_SIDE_SIZE = 28;
66

77
export const initOnnx = (): Promise<ort.InferenceSession> => {
8-
const session = ort.InferenceSession.create('./mnist_cnn.onnx', {
8+
const session = ort.InferenceSession.create('./mnist_cnn.quant.onnx', {
99
enableProfiling: true,
1010
executionProviders: ['wasm'],
1111
});

0 commit comments

Comments
 (0)