Skip to content

Commit 236184f

Browse files
authored
Fixing invalid index of a 0-dim tensor.
Fixing invalid index of a 0-dim tensor. This will be an error in PyTorch 0.5. Use tensor.item() to convert a 0-dim tensor to a Python number
1 parent 29f5b4f commit 236184f

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

chapter3_NN/deep-nn.ipynb

+4-4
Original file line numberDiff line numberDiff line change
@@ -507,10 +507,10 @@
507507
" loss.backward()\n",
508508
" optimizer.step()\n",
509509
" # 记录误差\n",
510-
" train_loss += loss.data[0]\n",
510+
" train_loss += loss.item()\n",
511511
" # 计算分类的准确率\n",
512512
" _, pred = out.max(1)\n",
513-
" num_correct = (pred == label).sum().data[0]\n",
513+
" num_correct = (pred == label).sum().item()\n",
514514
" acc = num_correct / im.shape[0]\n",
515515
" train_acc += acc\n",
516516
" \n",
@@ -526,10 +526,10 @@
526526
" out = net(im)\n",
527527
" loss = criterion(out, label)\n",
528528
" # 记录误差\n",
529-
" eval_loss += loss.data[0]\n",
529+
" eval_loss += loss.item()\n",
530530
" # 记录准确率\n",
531531
" _, pred = out.max(1)\n",
532-
" num_correct = (pred == label).sum().data[0]\n",
532+
" num_correct = (pred == label).sum().item()\n",
533533
" acc = num_correct / im.shape[0]\n",
534534
" eval_acc += acc\n",
535535
" \n",

0 commit comments

Comments
 (0)