Skip to content

Commit 371b191

Browse files
committedAug 27, 2024
initial commit
0 parents  commit 371b191

File tree

6 files changed

+217
-0
lines changed

6 files changed

+217
-0
lines changed
 

‎.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.DS_Store

‎README.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# ONNX MNIST on Web

‎pytorch/.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
venv/
2+
data/
3+
mnist_cnn.pt

‎pytorch/README.md

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# MNIST Model
2+
3+
## How to Run
4+
5+
- `python3 -m venv venv`
6+
- `source venv/bin/activate`
7+
- `pip install -r requirements.txt`
8+
- `python mnist.py`
9+
10+
The code is from https://github.com/pytorch/examples/tree/main/mnist.

‎pytorch/mnist.py

+200
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
# https://github.com/pytorch/examples/tree/main/mnist
2+
3+
import argparse
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
import torch.optim as optim
8+
from torchvision import datasets, transforms
9+
from torch.optim.lr_scheduler import StepLR
10+
11+
12+
class Net(nn.Module):
13+
def __init__(self):
14+
super(Net, self).__init__()
15+
self.conv1 = nn.Conv2d(1, 32, 3, 1)
16+
self.conv2 = nn.Conv2d(32, 64, 3, 1)
17+
self.dropout1 = nn.Dropout(0.25)
18+
self.dropout2 = nn.Dropout(0.5)
19+
self.fc1 = nn.Linear(9216, 128)
20+
self.fc2 = nn.Linear(128, 10)
21+
22+
def forward(self, x):
23+
x = self.conv1(x)
24+
x = F.relu(x)
25+
x = self.conv2(x)
26+
x = F.relu(x)
27+
x = F.max_pool2d(x, 2)
28+
x = self.dropout1(x)
29+
x = torch.flatten(x, 1)
30+
x = self.fc1(x)
31+
x = F.relu(x)
32+
x = self.dropout2(x)
33+
x = self.fc2(x)
34+
output = F.log_softmax(x, dim=1)
35+
return output
36+
37+
38+
def train(args, model, device, train_loader, optimizer, epoch):
39+
model.train()
40+
for batch_idx, (data, target) in enumerate(train_loader):
41+
data, target = data.to(device), target.to(device)
42+
optimizer.zero_grad()
43+
output = model(data)
44+
loss = F.nll_loss(output, target)
45+
loss.backward()
46+
optimizer.step()
47+
if batch_idx % args.log_interval == 0:
48+
print(
49+
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
50+
epoch,
51+
batch_idx * len(data),
52+
len(train_loader.dataset),
53+
100.0 * batch_idx / len(train_loader),
54+
loss.item(),
55+
)
56+
)
57+
if args.dry_run:
58+
break
59+
60+
61+
def test(model, device, test_loader):
62+
model.eval()
63+
test_loss = 0
64+
correct = 0
65+
with torch.no_grad():
66+
for data, target in test_loader:
67+
data, target = data.to(device), target.to(device)
68+
output = model(data)
69+
test_loss += F.nll_loss(
70+
output, target, reduction="sum"
71+
).item() # sum up batch loss
72+
pred = output.argmax(
73+
dim=1, keepdim=True
74+
) # get the index of the max log-probability
75+
correct += pred.eq(target.view_as(pred)).sum().item()
76+
77+
test_loss /= len(test_loader.dataset)
78+
79+
print(
80+
"\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
81+
test_loss,
82+
correct,
83+
len(test_loader.dataset),
84+
100.0 * correct / len(test_loader.dataset),
85+
)
86+
)
87+
88+
89+
def main():
90+
# Training settings
91+
parser = argparse.ArgumentParser(description="PyTorch MNIST Example")
92+
parser.add_argument(
93+
"--batch-size",
94+
type=int,
95+
default=64,
96+
metavar="N",
97+
help="input batch size for training (default: 64)",
98+
)
99+
parser.add_argument(
100+
"--test-batch-size",
101+
type=int,
102+
default=1000,
103+
metavar="N",
104+
help="input batch size for testing (default: 1000)",
105+
)
106+
parser.add_argument(
107+
"--epochs",
108+
type=int,
109+
default=14,
110+
metavar="N",
111+
help="number of epochs to train (default: 14)",
112+
)
113+
parser.add_argument(
114+
"--lr",
115+
type=float,
116+
default=1.0,
117+
metavar="LR",
118+
help="learning rate (default: 1.0)",
119+
)
120+
parser.add_argument(
121+
"--gamma",
122+
type=float,
123+
default=0.7,
124+
metavar="M",
125+
help="Learning rate step gamma (default: 0.7)",
126+
)
127+
parser.add_argument(
128+
"--no-cuda", action="store_true", default=False, help="disables CUDA training"
129+
)
130+
parser.add_argument(
131+
"--no-mps",
132+
action="store_true",
133+
default=False,
134+
help="disables macOS GPU training",
135+
)
136+
parser.add_argument(
137+
"--dry-run",
138+
action="store_true",
139+
default=False,
140+
help="quickly check a single pass",
141+
)
142+
parser.add_argument(
143+
"--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
144+
)
145+
parser.add_argument(
146+
"--log-interval",
147+
type=int,
148+
default=10,
149+
metavar="N",
150+
help="how many batches to wait before logging training status",
151+
)
152+
parser.add_argument(
153+
"--save-model",
154+
action="store_true",
155+
default=False,
156+
help="For Saving the current Model",
157+
)
158+
args = parser.parse_args()
159+
use_cuda = not args.no_cuda and torch.cuda.is_available()
160+
use_mps = not args.no_mps and torch.backends.mps.is_available()
161+
162+
torch.manual_seed(args.seed)
163+
164+
if use_cuda:
165+
device = torch.device("cuda")
166+
elif use_mps:
167+
device = torch.device("mps")
168+
else:
169+
device = torch.device("cpu")
170+
171+
train_kwargs = {"batch_size": args.batch_size}
172+
test_kwargs = {"batch_size": args.test_batch_size}
173+
if use_cuda:
174+
cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
175+
train_kwargs.update(cuda_kwargs)
176+
test_kwargs.update(cuda_kwargs)
177+
178+
transform = transforms.Compose(
179+
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
180+
)
181+
dataset1 = datasets.MNIST("./data", train=True, download=True, transform=transform)
182+
dataset2 = datasets.MNIST("./data", train=False, transform=transform)
183+
train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
184+
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
185+
186+
model = Net().to(device)
187+
optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
188+
189+
scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
190+
for epoch in range(1, args.epochs + 1):
191+
train(args, model, device, train_loader, optimizer, epoch)
192+
test(model, device, test_loader)
193+
scheduler.step()
194+
195+
if args.save_model:
196+
torch.save(model.state_dict(), "mnist_cnn.pt")
197+
198+
199+
if __name__ == "__main__":
200+
main()

‎pytorch/requirements.txt

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
torch==2.4.0
2+
torchvision==0.19.0

0 commit comments

Comments
 (0)
Please sign in to comment.