Commit c2589e6

make basic cnn identical

1 parent 880b6b0 commit c2589e6

2 files changed: +180 -157 lines changed

pytorch_ipynb/cnn/cnn-basic.ipynb (+139 -130)
@@ -23,7 +23,7 @@
     "CPython 3.6.8\n",
     "IPython 7.2.0\n",
     "\n",
-    "torch 1.0.0\n"
+    "torch 1.1.0\n"
    ]
   }
  ],
@@ -171,7 +171,7 @@
     "        \n",
     "        # 28x28x1 => 28x28x4\n",
     "        self.conv_1 = torch.nn.Conv2d(in_channels=1,\n",
-    "                                      out_channels=4,\n",
+    "                                      out_channels=8,\n",
     "                                      kernel_size=(3, 3),\n",
     "                                      stride=(1, 1),\n",
     "                                      padding=1) # (1(28-1) - 28 + 3) / 2 = 1\n",
@@ -180,18 +180,27 @@
     "                                         stride=(2, 2),\n",
     "                                         padding=0) # (2(14-1) - 28 + 2) = 0 \n",
     "        # 14x14x4 => 14x14x8\n",
-    "        self.conv_2 = torch.nn.Conv2d(in_channels=4,\n",
-    "                                      out_channels=8,\n",
+    "        self.conv_2 = torch.nn.Conv2d(in_channels=8,\n",
+    "                                      out_channels=16,\n",
     "                                      kernel_size=(3, 3),\n",
     "                                      stride=(1, 1),\n",
     "                                      padding=1) # (1(14-1) - 14 + 3) / 2 = 1 \n",
     "        # 14x14x8 => 7x7x8 \n",
     "        self.pool_2 = torch.nn.MaxPool2d(kernel_size=(2, 2),\n",
     "                                         stride=(2, 2),\n",
     "                                         padding=0) # (2(7-1) - 14 + 2) = 0\n",
-    "        \n",
-    "        self.linear_1 = torch.nn.Linear(7*7*8, num_classes)\n",
     "\n",
+    "        self.linear_1 = torch.nn.Linear(7*7*16, num_classes)\n",
+    "\n",
+    "        # optionally initialize weights from Gaussian;\n",
+    "        # Gaussian weight init is not recommended and only for demonstration purposes\n",
+    "        for m in self.modules():\n",
+    "            if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):\n",
+    "                m.weight.data.normal_(0.0, 0.01)\n",
+    "                m.bias.data.zero_()\n",
+    "                if m.bias is not None:\n",
+    "                    m.bias.detach().zero_()\n",
+    "        \n",
     "        \n",
     "    def forward(self, x):\n",
     "        out = self.conv_1(x)\n",
@@ -202,7 +211,7 @@
     "        out = F.relu(out)\n",
     "        out = self.pool_2(out)\n",
     "        \n",
-    "        logits = self.linear_1(out.view(-1, 7*7*8))\n",
+    "        logits = self.linear_1(out.view(-1, 7*7*16))\n",
     "        probas = F.softmax(logits, dim=1)\n",
     "        return logits, probas\n",
     "\n",
@@ -231,127 +240,127 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-    "Epoch: 001/010 | Batch 000/469 | Cost: 2.3016\n",
-    "Epoch: 001/010 | Batch 050/469 | Cost: 2.2714\n",
-    "Epoch: 001/010 | Batch 100/469 | Cost: 1.6118\n",
-    "Epoch: 001/010 | Batch 150/469 | Cost: 0.7966\n",
-    "Epoch: 001/010 | Batch 200/469 | Cost: 0.5077\n",
-    "Epoch: 001/010 | Batch 250/469 | Cost: 0.3221\n",
-    "Epoch: 001/010 | Batch 300/469 | Cost: 0.2850\n",
-    "Epoch: 001/010 | Batch 350/469 | Cost: 0.3116\n",
-    "Epoch: 001/010 | Batch 400/469 | Cost: 0.2836\n",
-    "Epoch: 001/010 | Batch 450/469 | Cost: 0.3169\n",
-    "Epoch: 001/010 training accuracy: 92.72%\n",
-    "Time elapsed: 0.21 min\n",
-    "Epoch: 002/010 | Batch 000/469 | Cost: 0.2469\n",
-    "Epoch: 002/010 | Batch 050/469 | Cost: 0.2342\n",
-    "Epoch: 002/010 | Batch 100/469 | Cost: 0.2883\n",
-    "Epoch: 002/010 | Batch 150/469 | Cost: 0.2920\n",
-    "Epoch: 002/010 | Batch 200/469 | Cost: 0.1798\n",
-    "Epoch: 002/010 | Batch 250/469 | Cost: 0.2277\n",
-    "Epoch: 002/010 | Batch 300/469 | Cost: 0.1747\n",
-    "Epoch: 002/010 | Batch 350/469 | Cost: 0.2430\n",
-    "Epoch: 002/010 | Batch 400/469 | Cost: 0.1578\n",
-    "Epoch: 002/010 | Batch 450/469 | Cost: 0.1279\n",
-    "Epoch: 002/010 training accuracy: 95.07%\n",
-    "Time elapsed: 0.41 min\n",
-    "Epoch: 003/010 | Batch 000/469 | Cost: 0.1223\n",
-    "Epoch: 003/010 | Batch 050/469 | Cost: 0.1999\n",
-    "Epoch: 003/010 | Batch 100/469 | Cost: 0.2212\n",
-    "Epoch: 003/010 | Batch 150/469 | Cost: 0.0905\n",
-    "Epoch: 003/010 | Batch 200/469 | Cost: 0.1502\n",
-    "Epoch: 003/010 | Batch 250/469 | Cost: 0.2391\n",
-    "Epoch: 003/010 | Batch 300/469 | Cost: 0.1108\n",
-    "Epoch: 003/010 | Batch 350/469 | Cost: 0.1734\n",
-    "Epoch: 003/010 | Batch 400/469 | Cost: 0.1426\n",
-    "Epoch: 003/010 | Batch 450/469 | Cost: 0.1253\n",
-    "Epoch: 003/010 training accuracy: 96.21%\n",
+    "Epoch: 001/010 | Batch 000/469 | Cost: 2.3026\n",
+    "Epoch: 001/010 | Batch 050/469 | Cost: 2.3036\n",
+    "Epoch: 001/010 | Batch 100/469 | Cost: 2.3001\n",
+    "Epoch: 001/010 | Batch 150/469 | Cost: 2.3050\n",
+    "Epoch: 001/010 | Batch 200/469 | Cost: 2.2984\n",
+    "Epoch: 001/010 | Batch 250/469 | Cost: 2.2986\n",
+    "Epoch: 001/010 | Batch 300/469 | Cost: 2.2983\n",
+    "Epoch: 001/010 | Batch 350/469 | Cost: 2.2941\n",
+    "Epoch: 001/010 | Batch 400/469 | Cost: 2.2962\n",
+    "Epoch: 001/010 | Batch 450/469 | Cost: 2.2265\n",
+    "Epoch: 001/010 training accuracy: 65.38%\n",
+    "Time elapsed: 0.29 min\n",
+    "Epoch: 002/010 | Batch 000/469 | Cost: 1.8989\n",
+    "Epoch: 002/010 | Batch 050/469 | Cost: 0.6029\n",
+    "Epoch: 002/010 | Batch 100/469 | Cost: 0.6099\n",
+    "Epoch: 002/010 | Batch 150/469 | Cost: 0.4786\n",
+    "Epoch: 002/010 | Batch 200/469 | Cost: 0.4518\n",
+    "Epoch: 002/010 | Batch 250/469 | Cost: 0.3553\n",
+    "Epoch: 002/010 | Batch 300/469 | Cost: 0.3167\n",
+    "Epoch: 002/010 | Batch 350/469 | Cost: 0.2241\n",
+    "Epoch: 002/010 | Batch 400/469 | Cost: 0.2259\n",
+    "Epoch: 002/010 | Batch 450/469 | Cost: 0.3056\n",
+    "Epoch: 002/010 training accuracy: 93.11%\n",
     "Time elapsed: 0.62 min\n",
-    "Epoch: 004/010 | Batch 000/469 | Cost: 0.1368\n",
-    "Epoch: 004/010 | Batch 050/469 | Cost: 0.1984\n",
-    "Epoch: 004/010 | Batch 100/469 | Cost: 0.1296\n",
-    "Epoch: 004/010 | Batch 150/469 | Cost: 0.1439\n",
-    "Epoch: 004/010 | Batch 200/469 | Cost: 0.1141\n",
-    "Epoch: 004/010 | Batch 250/469 | Cost: 0.0566\n",
-    "Epoch: 004/010 | Batch 300/469 | Cost: 0.1119\n",
-    "Epoch: 004/010 | Batch 350/469 | Cost: 0.1777\n",
-    "Epoch: 004/010 | Batch 400/469 | Cost: 0.2209\n",
-    "Epoch: 004/010 | Batch 450/469 | Cost: 0.1390\n",
-    "Epoch: 004/010 training accuracy: 96.77%\n",
-    "Time elapsed: 0.82 min\n",
-    "Epoch: 005/010 | Batch 000/469 | Cost: 0.1305\n",
-    "Epoch: 005/010 | Batch 050/469 | Cost: 0.0445\n",
-    "Epoch: 005/010 | Batch 100/469 | Cost: 0.1327\n",
-    "Epoch: 005/010 | Batch 150/469 | Cost: 0.0846\n",
-    "Epoch: 005/010 | Batch 200/469 | Cost: 0.0760\n",
-    "Epoch: 005/010 | Batch 250/469 | Cost: 0.0795\n",
-    "Epoch: 005/010 | Batch 300/469 | Cost: 0.1364\n",
-    "Epoch: 005/010 | Batch 350/469 | Cost: 0.1419\n",
-    "Epoch: 005/010 | Batch 400/469 | Cost: 0.0903\n",
-    "Epoch: 005/010 | Batch 450/469 | Cost: 0.0599\n",
-    "Epoch: 005/010 training accuracy: 97.15%\n",
-    "Time elapsed: 1.03 min\n",
-    "Epoch: 006/010 | Batch 000/469 | Cost: 0.0721\n",
-    "Epoch: 006/010 | Batch 050/469 | Cost: 0.0481\n",
-    "Epoch: 006/010 | Batch 100/469 | Cost: 0.0386\n",
-    "Epoch: 006/010 | Batch 150/469 | Cost: 0.0421\n",
-    "Epoch: 006/010 | Batch 200/469 | Cost: 0.1176\n",
-    "Epoch: 006/010 | Batch 250/469 | Cost: 0.0719\n",
-    "Epoch: 006/010 | Batch 300/469 | Cost: 0.0534\n",
-    "Epoch: 006/010 | Batch 350/469 | Cost: 0.0230\n",
-    "Epoch: 006/010 | Batch 400/469 | Cost: 0.0941\n",
-    "Epoch: 006/010 | Batch 450/469 | Cost: 0.0848\n",
-    "Epoch: 006/010 training accuracy: 97.43%\n",
-    "Time elapsed: 1.23 min\n",
-    "Epoch: 007/010 | Batch 000/469 | Cost: 0.1986\n",
-    "Epoch: 007/010 | Batch 050/469 | Cost: 0.0445\n",
-    "Epoch: 007/010 | Batch 100/469 | Cost: 0.0524\n",
-    "Epoch: 007/010 | Batch 150/469 | Cost: 0.0639\n",
-    "Epoch: 007/010 | Batch 200/469 | Cost: 0.0667\n",
-    "Epoch: 007/010 | Batch 250/469 | Cost: 0.0952\n",
-    "Epoch: 007/010 | Batch 300/469 | Cost: 0.0294\n",
-    "Epoch: 007/010 | Batch 350/469 | Cost: 0.0974\n",
-    "Epoch: 007/010 | Batch 400/469 | Cost: 0.1130\n",
-    "Epoch: 007/010 | Batch 450/469 | Cost: 0.0552\n",
-    "Epoch: 007/010 training accuracy: 97.77%\n",
-    "Time elapsed: 1.43 min\n",
-    "Epoch: 008/010 | Batch 000/469 | Cost: 0.1190\n",
-    "Epoch: 008/010 | Batch 050/469 | Cost: 0.1556\n",
-    "Epoch: 008/010 | Batch 100/469 | Cost: 0.0912\n",
-    "Epoch: 008/010 | Batch 150/469 | Cost: 0.0401\n",
-    "Epoch: 008/010 | Batch 200/469 | Cost: 0.0832\n",
-    "Epoch: 008/010 | Batch 250/469 | Cost: 0.0418\n",
-    "Epoch: 008/010 | Batch 300/469 | Cost: 0.0886\n",
-    "Epoch: 008/010 | Batch 350/469 | Cost: 0.0844\n",
-    "Epoch: 008/010 | Batch 400/469 | Cost: 0.0673\n",
-    "Epoch: 008/010 | Batch 450/469 | Cost: 0.1391\n",
-    "Epoch: 008/010 training accuracy: 97.55%\n",
-    "Time elapsed: 1.64 min\n",
-    "Epoch: 009/010 | Batch 000/469 | Cost: 0.0826\n",
-    "Epoch: 009/010 | Batch 050/469 | Cost: 0.1026\n",
-    "Epoch: 009/010 | Batch 100/469 | Cost: 0.1812\n",
-    "Epoch: 009/010 | Batch 150/469 | Cost: 0.0658\n",
-    "Epoch: 009/010 | Batch 200/469 | Cost: 0.0883\n",
-    "Epoch: 009/010 | Batch 250/469 | Cost: 0.1577\n",
-    "Epoch: 009/010 | Batch 300/469 | Cost: 0.0479\n",
-    "Epoch: 009/010 | Batch 350/469 | Cost: 0.0779\n",
-    "Epoch: 009/010 | Batch 400/469 | Cost: 0.0407\n",
-    "Epoch: 009/010 | Batch 450/469 | Cost: 0.0236\n",
-    "Epoch: 009/010 training accuracy: 97.82%\n",
-    "Time elapsed: 1.84 min\n",
-    "Epoch: 010/010 | Batch 000/469 | Cost: 0.0183\n",
-    "Epoch: 010/010 | Batch 050/469 | Cost: 0.0740\n",
-    "Epoch: 010/010 | Batch 100/469 | Cost: 0.0425\n",
-    "Epoch: 010/010 | Batch 150/469 | Cost: 0.0332\n",
-    "Epoch: 010/010 | Batch 200/469 | Cost: 0.0795\n",
-    "Epoch: 010/010 | Batch 250/469 | Cost: 0.0568\n",
-    "Epoch: 010/010 | Batch 300/469 | Cost: 0.1070\n",
-    "Epoch: 010/010 | Batch 350/469 | Cost: 0.1660\n",
-    "Epoch: 010/010 | Batch 400/469 | Cost: 0.0204\n",
-    "Epoch: 010/010 | Batch 450/469 | Cost: 0.0613\n",
-    "Epoch: 010/010 training accuracy: 97.77%\n",
-    "Time elapsed: 2.04 min\n",
-    "Total Training Time: 2.04 min\n"
+    "Epoch: 003/010 | Batch 000/469 | Cost: 0.3313\n",
+    "Epoch: 003/010 | Batch 050/469 | Cost: 0.1042\n",
+    "Epoch: 003/010 | Batch 100/469 | Cost: 0.1328\n",
+    "Epoch: 003/010 | Batch 150/469 | Cost: 0.2803\n",
+    "Epoch: 003/010 | Batch 200/469 | Cost: 0.0975\n",
+    "Epoch: 003/010 | Batch 250/469 | Cost: 0.1839\n",
+    "Epoch: 003/010 | Batch 300/469 | Cost: 0.1774\n",
+    "Epoch: 003/010 | Batch 350/469 | Cost: 0.1143\n",
+    "Epoch: 003/010 | Batch 400/469 | Cost: 0.1753\n",
+    "Epoch: 003/010 | Batch 450/469 | Cost: 0.1543\n",
+    "Epoch: 003/010 training accuracy: 95.68%\n",
+    "Time elapsed: 0.93 min\n",
+    "Epoch: 004/010 | Batch 000/469 | Cost: 0.1057\n",
+    "Epoch: 004/010 | Batch 050/469 | Cost: 0.1035\n",
+    "Epoch: 004/010 | Batch 100/469 | Cost: 0.1851\n",
+    "Epoch: 004/010 | Batch 150/469 | Cost: 0.1608\n",
+    "Epoch: 004/010 | Batch 200/469 | Cost: 0.1458\n",
+    "Epoch: 004/010 | Batch 250/469 | Cost: 0.1913\n",
+    "Epoch: 004/010 | Batch 300/469 | Cost: 0.1295\n",
+    "Epoch: 004/010 | Batch 350/469 | Cost: 0.1518\n",
+    "Epoch: 004/010 | Batch 400/469 | Cost: 0.1717\n",
+    "Epoch: 004/010 | Batch 450/469 | Cost: 0.0792\n",
+    "Epoch: 004/010 training accuracy: 96.46%\n",
+    "Time elapsed: 1.24 min\n",
+    "Epoch: 005/010 | Batch 000/469 | Cost: 0.0905\n",
+    "Epoch: 005/010 | Batch 050/469 | Cost: 0.1622\n",
+    "Epoch: 005/010 | Batch 100/469 | Cost: 0.1934\n",
+    "Epoch: 005/010 | Batch 150/469 | Cost: 0.1874\n",
+    "Epoch: 005/010 | Batch 200/469 | Cost: 0.0742\n",
+    "Epoch: 005/010 | Batch 250/469 | Cost: 0.1056\n",
+    "Epoch: 005/010 | Batch 300/469 | Cost: 0.0997\n",
+    "Epoch: 005/010 | Batch 350/469 | Cost: 0.0948\n",
+    "Epoch: 005/010 | Batch 400/469 | Cost: 0.0575\n",
+    "Epoch: 005/010 | Batch 450/469 | Cost: 0.1157\n",
+    "Epoch: 005/010 training accuracy: 96.97%\n",
+    "Time elapsed: 1.56 min\n",
+    "Epoch: 006/010 | Batch 000/469 | Cost: 0.1326\n",
+    "Epoch: 006/010 | Batch 050/469 | Cost: 0.1549\n",
+    "Epoch: 006/010 | Batch 100/469 | Cost: 0.0784\n",
+    "Epoch: 006/010 | Batch 150/469 | Cost: 0.0898\n",
+    "Epoch: 006/010 | Batch 200/469 | Cost: 0.0991\n",
+    "Epoch: 006/010 | Batch 250/469 | Cost: 0.0965\n",
+    "Epoch: 006/010 | Batch 300/469 | Cost: 0.0477\n",
+    "Epoch: 006/010 | Batch 350/469 | Cost: 0.0712\n",
+    "Epoch: 006/010 | Batch 400/469 | Cost: 0.1109\n",
+    "Epoch: 006/010 | Batch 450/469 | Cost: 0.0325\n",
+    "Epoch: 006/010 training accuracy: 97.60%\n",
+    "Time elapsed: 1.88 min\n",
+    "Epoch: 007/010 | Batch 000/469 | Cost: 0.0665\n",
+    "Epoch: 007/010 | Batch 050/469 | Cost: 0.0868\n",
+    "Epoch: 007/010 | Batch 100/469 | Cost: 0.0427\n",
+    "Epoch: 007/010 | Batch 150/469 | Cost: 0.0385\n",
+    "Epoch: 007/010 | Batch 200/469 | Cost: 0.0611\n",
+    "Epoch: 007/010 | Batch 250/469 | Cost: 0.0484\n",
+    "Epoch: 007/010 | Batch 300/469 | Cost: 0.1288\n",
+    "Epoch: 007/010 | Batch 350/469 | Cost: 0.0309\n",
+    "Epoch: 007/010 | Batch 400/469 | Cost: 0.0359\n",
+    "Epoch: 007/010 | Batch 450/469 | Cost: 0.0139\n",
+    "Epoch: 007/010 training accuracy: 97.64%\n",
+    "Time elapsed: 2.19 min\n",
+    "Epoch: 008/010 | Batch 000/469 | Cost: 0.0939\n",
+    "Epoch: 008/010 | Batch 050/469 | Cost: 0.1478\n",
+    "Epoch: 008/010 | Batch 100/469 | Cost: 0.0769\n",
+    "Epoch: 008/010 | Batch 150/469 | Cost: 0.0713\n",
+    "Epoch: 008/010 | Batch 200/469 | Cost: 0.1272\n",
+    "Epoch: 008/010 | Batch 250/469 | Cost: 0.0446\n",
+    "Epoch: 008/010 | Batch 300/469 | Cost: 0.0525\n",
+    "Epoch: 008/010 | Batch 350/469 | Cost: 0.1729\n",
+    "Epoch: 008/010 | Batch 400/469 | Cost: 0.0672\n",
+    "Epoch: 008/010 | Batch 450/469 | Cost: 0.0754\n",
+    "Epoch: 008/010 training accuracy: 96.67%\n",
+    "Time elapsed: 2.50 min\n",
+    "Epoch: 009/010 | Batch 000/469 | Cost: 0.0988\n",
+    "Epoch: 009/010 | Batch 050/469 | Cost: 0.0409\n",
+    "Epoch: 009/010 | Batch 100/469 | Cost: 0.1046\n",
+    "Epoch: 009/010 | Batch 150/469 | Cost: 0.0523\n",
+    "Epoch: 009/010 | Batch 200/469 | Cost: 0.0815\n",
+    "Epoch: 009/010 | Batch 250/469 | Cost: 0.0811\n",
+    "Epoch: 009/010 | Batch 300/469 | Cost: 0.0416\n",
+    "Epoch: 009/010 | Batch 350/469 | Cost: 0.0747\n",
+    "Epoch: 009/010 | Batch 400/469 | Cost: 0.0467\n",
+    "Epoch: 009/010 | Batch 450/469 | Cost: 0.0669\n",
+    "Epoch: 009/010 training accuracy: 97.90%\n",
+    "Time elapsed: 2.78 min\n",
+    "Epoch: 010/010 | Batch 000/469 | Cost: 0.0257\n",
+    "Epoch: 010/010 | Batch 050/469 | Cost: 0.0357\n",
+    "Epoch: 010/010 | Batch 100/469 | Cost: 0.1469\n",
+    "Epoch: 010/010 | Batch 150/469 | Cost: 0.0170\n",
+    "Epoch: 010/010 | Batch 200/469 | Cost: 0.0493\n",
+    "Epoch: 010/010 | Batch 250/469 | Cost: 0.0489\n",
+    "Epoch: 010/010 | Batch 300/469 | Cost: 0.1348\n",
+    "Epoch: 010/010 | Batch 350/469 | Cost: 0.0815\n",
+    "Epoch: 010/010 | Batch 400/469 | Cost: 0.0552\n",
+    "Epoch: 010/010 | Batch 450/469 | Cost: 0.0422\n",
+    "Epoch: 010/010 training accuracy: 97.99%\n",
+    "Time elapsed: 3.02 min\n",
+    "Total Training Time: 3.02 min\n"
    ]
   }
  ],
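Note on the new log: with weights drawn from N(0, 0.01^2), the initial logits are close to zero, so the softmax starts near-uniform over the 10 classes and the cross-entropy cost starts at about -ln(1/10) = 2.3026, which is where epoch 1 sits for most of its batches before training takes off:

    # Sketch: expected initial cost for a near-uniform 10-class softmax
    import math
    print(-math.log(1 / 10))  # 2.302585..., matching the early epoch-1 costs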
@@ -418,7 +427,7 @@
     "name": "stdout",
     "output_type": "stream",
     "text": [
-     "Test accuracy: 97.77%\n"
+     "Test accuracy: 97.97%\n"
    ]
   }
  ],
@@ -437,7 +446,7 @@
     "output_type": "stream",
     "text": [
      "numpy 1.15.4\n",
-     "torch 1.0.0\n",
+     "torch 1.1.0\n",
      "\n"
    ]
   }
@@ -463,7 +472,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.7.1"
+   "version": "3.6.8"
   },
   "toc": {
    "nav_menu": {},

0 commit comments

Comments
 (0)