Skip to content

Commit 0b18253

Browse files
author
hanyoseob
committed
Update functions details
1 parent df5a374 commit 0b18253

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

train.py

+27-21
Original file line numberDiff line numberDiff line change
@@ -331,15 +331,20 @@ def test(self):
331331
# for i, data in enumerate(loader_train, 1):
332332
for i, (input, label) in enumerate(loader_test, 1):
333333

334-
input = input.to(device)
335-
label = label.to(device)
336334

337335
# forward netG
338-
input_stn = net_STN(input)
339-
output = net_CLS(input_stn)
340-
pred = output.max(1, keepdim=True)[1]
336+
if self.scope == 'stn':
337+
input_stn = net_STN(input)
338+
output = net_CLS(input_stn)
339+
pred = output.max(1, keepdim=True)[1]
341340

342-
loss_CLS = fn_CLS(output, label)
341+
loss_CLS = fn_CLS(output, label)
342+
343+
elif self.scope == 'cls':
344+
output = net_CLS(input)
345+
pred = output.max(1, keepdim=True)[1]
346+
347+
loss_CLS = fn_CLS(output, label)
343348

344349
# get losses
345350
loss_CLS_test += [loss_CLS.item()]
@@ -348,21 +353,22 @@ def test(self):
348353
print('TEST: BATCH %04d/%04d: CLS: %.4f ACC: %.4f' % (i, num_batch_test, mean(loss_CLS_test), 100 * mean(pred_CLS_test)))
349354

350355
## show output
351-
input = transform_inv(input)
352-
input_stn = transform_inv(input_stn)
353-
354-
for j in range(input.shape[0]):
355-
name = batch_size * (i - 1) + j
356-
fileset = {'name': name,
357-
'input': "%04d-input.png" % name,
358-
'input_stn': "%04d-input_stn.png" % name}
359-
360-
if nch_in == 3:
361-
plt.imsave(os.path.join(dir_result_save, fileset['input']), input[j, :, :, :].squeeze())
362-
plt.imsave(os.path.join(dir_result_save, fileset['input_stn']), input_stn[j, :, :, :].squeeze())
363-
elif nch_in == 1:
364-
plt.imsave(os.path.join(dir_result_save, fileset['input']), input[j, :, :, :].squeeze(), cmap='gray')
365-
plt.imsave(os.path.join(dir_result_save, fileset['input_stn']), input_stn[j, :, :, :].squeeze(), cmap='gray')
356+
if self.scope == 'stn':
357+
input = transform_inv(input)
358+
input_stn = transform_inv(input_stn)
359+
360+
for j in range(input.shape[0]):
361+
name = batch_size * (i - 1) + j
362+
fileset = {'name': name,
363+
'input': "%04d-input.png" % name,
364+
'input_stn': "%04d-input_stn.png" % name}
365+
366+
if nch_in == 3:
367+
plt.imsave(os.path.join(dir_result_save, fileset['input']), input[j, :, :, :].squeeze())
368+
plt.imsave(os.path.join(dir_result_save, fileset['input_stn']), input_stn[j, :, :, :].squeeze())
369+
elif nch_in == 1:
370+
plt.imsave(os.path.join(dir_result_save, fileset['input']), input[j, :, :, :].squeeze(), cmap='gray')
371+
plt.imsave(os.path.join(dir_result_save, fileset['input_stn']), input_stn[j, :, :, :].squeeze(), cmap='gray')
366372

367373
append_index(dir_result, fileset)
368374

0 commit comments

Comments
 (0)