Skip to content

Commit 62d9d32

Browse files
feat: implement script to export mnist model from pytorch to onnx
1 parent 371b191 commit 62d9d32

File tree

5 files changed

+31
-3
lines changed

5 files changed

+31
-3
lines changed

.vscode/settings.json

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
{
2+
"python.analysis.extraPaths": ["./pytorch/.venv/lib/python3.12/site-packages/"]
3+
}

pytorch/.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
venv/
1+
.venv/
22
data/
3+
__pycache__/
34
mnist_cnn.pt
5+
mnist_cnn.onnx

pytorch/README.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22

33
## How to Run
44

5-
- `python3 -m venv venv`
6-
- `source venv/bin/activate`
5+
- `python3 -m venv .venv`
6+
- `source .venv/bin/activate`
77
- `pip install -r requirements.txt`
88
- `python mnist.py`
99

1010
The code is from https://github.com/pytorch/examples/tree/main/mnist.
11+
12+
### Export Model from Pytorch to ONNX
13+
14+
- `python export_to_onnx.py`

pytorch/export_to_onnx.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# Convert pt file to onnx file
2+
import torch
3+
from mnist import Net
4+
from pathlib import Path
5+
6+
7+
def main():
8+
MODEL_PATH = Path("./mnist_cnn.pt")
9+
mnist_model = Net()
10+
mnist_model.load_state_dict(torch.load(MODEL_PATH))
11+
mnist_model.eval()
12+
dymmy_input = torch.zeros(1, 1, 28, 28)
13+
torch.onnx.export(mnist_model, dymmy_input, "mnist_cnn.onnx", verbose=True)
14+
15+
16+
if __name__ == "__main__":
17+
main()

pytorch/requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
torch==2.4.0
22
torchvision==0.19.0
3+
onnx==1.16.2
4+
onnxruntime==1.19.0

0 commit comments

Comments
 (0)