Skip to content

Commit 0ef099d

Browse files
authored
Update train.py
解决模型评估过程出现显存爆炸
1 parent 0df12e1 commit 0ef099d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

train.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def train():
8080
optimizer.step()
8181

8282
acc = calculat_acc(output, target)
83-
acc_history.append(acc)
84-
loss_history.append(loss)
83+
acc_history.append(float(acc))
84+
loss_history.append(float(loss))
8585
print('train_loss: {:.4}|train_acc: {:.4}'.format(
8686
torch.mean(torch.Tensor(loss_history)),
8787
torch.mean(torch.Tensor(acc_history)),
@@ -99,7 +99,7 @@ def train():
9999
output = cnn(img)
100100

101101
acc = calculat_acc(output, target)
102-
acc_history.append(acc)
102+
acc_history.append(float(acc))
103103
loss_history.append(float(loss))
104104
print('test_loss: {:.4}|test_acc: {:.4}'.format(
105105
torch.mean(torch.Tensor(loss_history)),
@@ -110,4 +110,4 @@ def train():
110110

111111
if __name__=="__main__":
112112
train()
113-
pass
113+
pass

0 commit comments

Comments
 (0)