@@ -331,15 +331,20 @@ def test(self):
331
331
# for i, data in enumerate(loader_train, 1):
332
332
for i , (input , label ) in enumerate (loader_test , 1 ):
333
333
334
- input = input .to (device )
335
- label = label .to (device )
336
334
337
335
# forward netG
338
- input_stn = net_STN (input )
339
- output = net_CLS (input_stn )
340
- pred = output .max (1 , keepdim = True )[1 ]
336
+ if self .scope == 'stn' :
337
+ input_stn = net_STN (input )
338
+ output = net_CLS (input_stn )
339
+ pred = output .max (1 , keepdim = True )[1 ]
341
340
342
- loss_CLS = fn_CLS (output , label )
341
+ loss_CLS = fn_CLS (output , label )
342
+
343
+ elif self .scope == 'cls' :
344
+ output = net_CLS (input )
345
+ pred = output .max (1 , keepdim = True )[1 ]
346
+
347
+ loss_CLS = fn_CLS (output , label )
343
348
344
349
# get losses
345
350
loss_CLS_test += [loss_CLS .item ()]
@@ -348,21 +353,22 @@ def test(self):
348
353
print ('TEST: BATCH %04d/%04d: CLS: %.4f ACC: %.4f' % (i , num_batch_test , mean (loss_CLS_test ), 100 * mean (pred_CLS_test )))
349
354
350
355
## show output
351
- input = transform_inv (input )
352
- input_stn = transform_inv (input_stn )
353
-
354
- for j in range (input .shape [0 ]):
355
- name = batch_size * (i - 1 ) + j
356
- fileset = {'name' : name ,
357
- 'input' : "%04d-input.png" % name ,
358
- 'input_stn' : "%04d-input_stn.png" % name }
359
-
360
- if nch_in == 3 :
361
- plt .imsave (os .path .join (dir_result_save , fileset ['input' ]), input [j , :, :, :].squeeze ())
362
- plt .imsave (os .path .join (dir_result_save , fileset ['input_stn' ]), input_stn [j , :, :, :].squeeze ())
363
- elif nch_in == 1 :
364
- plt .imsave (os .path .join (dir_result_save , fileset ['input' ]), input [j , :, :, :].squeeze (), cmap = 'gray' )
365
- plt .imsave (os .path .join (dir_result_save , fileset ['input_stn' ]), input_stn [j , :, :, :].squeeze (), cmap = 'gray' )
356
+ if self .scope == 'stn' :
357
+ input = transform_inv (input )
358
+ input_stn = transform_inv (input_stn )
359
+
360
+ for j in range (input .shape [0 ]):
361
+ name = batch_size * (i - 1 ) + j
362
+ fileset = {'name' : name ,
363
+ 'input' : "%04d-input.png" % name ,
364
+ 'input_stn' : "%04d-input_stn.png" % name }
365
+
366
+ if nch_in == 3 :
367
+ plt .imsave (os .path .join (dir_result_save , fileset ['input' ]), input [j , :, :, :].squeeze ())
368
+ plt .imsave (os .path .join (dir_result_save , fileset ['input_stn' ]), input_stn [j , :, :, :].squeeze ())
369
+ elif nch_in == 1 :
370
+ plt .imsave (os .path .join (dir_result_save , fileset ['input' ]), input [j , :, :, :].squeeze (), cmap = 'gray' )
371
+ plt .imsave (os .path .join (dir_result_save , fileset ['input_stn' ]), input_stn [j , :, :, :].squeeze (), cmap = 'gray' )
366
372
367
373
append_index (dir_result , fileset )
368
374
0 commit comments