|
23 | 23 | "import matplotlib.pyplot as plt\n",
|
24 | 24 | "\n",
|
25 | 25 | "import torch.nn as nn\n",
|
26 |
| - "import torch.optim as optim\n" |
| 26 | + "import torch.optim as optim\n", |
| 27 | + "\n", |
| 28 | + "torch.manual_seed(0)" |
27 | 29 | ]
|
28 | 30 | },
|
29 | 31 | {
|
|
669 | 671 | }
|
670 | 672 | ],
|
671 | 673 | "source": [
|
672 |
| - "dataiter = iter(test_loader)\n", |
673 |
| - "images, labels = dataiter.next()\n", |
| 674 | + "images, labels = next(iter(test_loader))\n", |
674 | 675 | "\n",
|
675 | 676 | "outputs = model(images.to(device))\n",
|
676 | 677 | "\n",
|
|
679 | 680 | "figure = plt.figure()\n",
|
680 | 681 | "num_of_images = 20\n",
|
681 | 682 | "for index in range(num_of_images):\n",
|
682 |
| - " plt.subplot(4, 5, index+1)\n", |
683 |
| - " plt.axis('off')\n", |
684 |
| - " plt.imshow(images[index,0,:,:], cmap='gray')\n", |
| 683 | + " plt.subplot(4, 5, index + 1)\n", |
| 684 | + " plt.axis(\"off\")\n", |
| 685 | + " plt.imshow(images[index, 0, :, :], cmap=\"gray\")\n", |
685 | 686 | " prd = int(predicted[index])\n",
|
686 | 687 | " gt = int(labels[index])\n",
|
687 |
| - " plt.title(f\"pred:{prd}\\n gt:{gt}\" +\"\\nwrong!\"*(prd!=gt))\n", |
| 688 | + " plt.title(f\"pred:{prd}\\n gt:{gt}\" + \"\\nwrong!\" * (prd != gt))\n", |
688 | 689 | "plt.tight_layout()"
|
689 | 690 | ]
|
690 | 691 | }
|
|
0 commit comments