-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcontext_encoder_test.py
71 lines (49 loc) · 1.99 KB
/
context_encoder_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import sys
import glob
from tqdm import tqdm
from PIL import Image
import torch
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variable
from generator import generator as Generator
from torchvision import transforms
from data_loader import image_loader
# Inpainting geometry: inputs are resized to (img_x, img_y) and the generator
# fills a square hole of mask_x x mask_y pixels (see the slicing in
# save_sample).
mask_x = mask_y = 64
img_x = img_y = 128

# Test-time preprocessing: resize, convert to a tensor, then normalise every
# RGB channel with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
_channel_stats = (.5, .5, .5)
transforms_ = [
    transforms.Resize((img_x, img_y)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_stats, _channel_stats),
]
# Command-line arguments: CHECKPOINT TEST_DIR OUTPUT_DIR.
# test_path and output_path are used as string prefixes (file names / globs
# are appended directly), so directory arguments need a trailing separator.
if len(sys.argv) < 4:
    sys.exit('usage: %s CHECKPOINT TEST_DIR/ OUTPUT_DIR/' % sys.argv[0])
check_point_path = sys.argv[1]
test_path = sys.argv[2]
output_path = sys.argv[3]

# Build the pattern once so the file count and the dataset always agree.
test_pattern = test_path + '*.png'
num_files = len(glob.glob(test_pattern))
if num_files == 0:
    # DataLoader rejects batch_size=0; fail early with a clear message instead.
    sys.exit('no .png files found matching %r' % test_pattern)

# batch_size equals the number of test images, so a single batch covers the
# whole test set.
test_data_loader = DataLoader(
    image_loader(path=test_pattern, transforms_=transforms_, mode='test'),
    batch_size=num_files,
    shuffle=False,
    num_workers=10,
)
# Restore the trained generator. map_location='cpu' lets a checkpoint that was
# saved from a GPU be loaded on a CPU-only machine; the model is moved to the
# GPU below when one is available.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
generator = Generator()
checkpoint = torch.load(check_point_path, map_location='cpu')
generator.load_state_dict(checkpoint['generator_state_dict'])
# Inference mode: makes BatchNorm/Dropout layers (if the Generator defined in
# generator.py has any) use their evaluation behaviour instead of training.
generator.eval()

# Tensor type that input batches are cast to before being fed to the model.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
if torch.cuda.is_available():
    generator.cuda()
def save_sample():
    """Inpaint the test batch and save one comparison image per input.

    Takes the first batch from the module-level ``test_data_loader``, replaces
    the masked square of every image with the generator's output, and writes a
    vertical strip [masked / filled / original] for each image to
    ``output_path + 'generated_imgs-<k>.jpg'`` with k starting at 1.
    """
    # output_path is used as a filename prefix, so only its directory part is
    # created (if missing); otherwise save_image would fail after inference.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    imgs, masked_imgs, crop_ind = next(iter(test_data_loader))
    imgs = imgs.type(Tensor)
    masked_imgs = masked_imgs.type(Tensor)
    # Hard-coded for centre crops: the hole is assumed to be at the same
    # offset in every image and to start at the same row and column, so one
    # scalar is used for both axes. (crop_ind layout comes from image_loader —
    # TODO confirm.)
    crop_ind = crop_ind[0].item()

    # Pure inference: skip building an autograd graph, which matters because
    # the whole test set is processed as one batch.
    with torch.no_grad():
        generated_patches = generator(masked_imgs)
        filled_imgs = masked_imgs.clone()
        filled_imgs[:, :, crop_ind:crop_ind + mask_x, crop_ind:crop_ind + mask_y] = generated_patches

    # Concatenate along dim -2 (height, for the N, C, H, W batch indexed
    # above) so each saved image shows masked / filled / original top-to-bottom.
    samples = torch.cat((masked_imgs, filled_imgs, imgs), -2)
    for i, sample in enumerate(tqdm(samples)):
        save_image(sample, output_path + 'generated_imgs-%d.jpg' % (i + 1), nrow=6, normalize=True)
# Entry-point guard: DataLoader workers (num_workers > 0) started with the
# "spawn" method (default on Windows and macOS) re-import this script; without
# the guard each worker would call save_sample() again and try to start its
# own workers, which fails during process bootstrapping.
if __name__ == '__main__':
    save_sample()