import torch.nn.functional as F
import torch

- os.makedirs('images', exist_ok=True)
+ os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
- parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
- parser.add_argument('--batch_size', type=int, default=8, help='size of the batches')
- parser.add_argument('--dataset_name', type=str, default='img_align_celeba', help='name of the dataset')
- parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
- parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
- parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
- parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
- parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
- parser.add_argument('--img_size', type=int, default=128, help='size of each image dimension')
- parser.add_argument('--mask_size', type=int, default=32, help='size of random mask')
- parser.add_argument('--channels', type=int, default=3, help='number of image channels')
- parser.add_argument('--sample_interval', type=int, default=500, help='interval between image sampling')
+ parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
+ parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")
+ parser.add_argument("--dataset_name", type=str, default="img_align_celeba", help="name of the dataset")
+ parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
+ parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
+ parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
+ parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
+ parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
+ parser.add_argument("--img_size", type=int, default=128, help="size of each image dimension")
+ parser.add_argument("--mask_size", type=int, default=32, help="size of random mask")
+ parser.add_argument("--channels", type=int, default=3, help="number of image channels")
+ parser.add_argument("--sample_interval", type=int, default=500, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

cuda = True if torch.cuda.is_available() else False

- # Calculate output of image discriminator (PatchGAN)
- patch_h, patch_w = int(opt.img_size / 2 ** 3), int(opt.img_size / 2 ** 3)
- patch = (1, patch_h, patch_w)
-
- def weights_init_normal(m):
-     classname = m.__class__.__name__
-     if classname.find('Conv') != -1:
-         torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-     elif classname.find('BatchNorm2d') != -1:
-         torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-         torch.nn.init.constant_(m.bias.data, 0.0)
+ input_shape = (opt.channels, opt.img_size, opt.img_size)

# Loss function
adversarial_loss = torch.nn.MSELoss()

# Initialize generator and discriminator
- generator = Generator(channels=opt.channels)
- discriminator = Discriminator(channels=opt.channels)
+ generator = Generator(input_shape)
+ discriminator = Discriminator(input_shape)

if cuda:
    generator.cuda()
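The models module that defines Generator and Discriminator is not part of this diff, so the matching refactor is not visible here. A minimal sketch of a Discriminator that accepts the new input_shape and exposes the output_shape the training loop below relies on, assuming the same 2 ** 3 PatchGAN downsampling factor as the removed module-level patch computation (layer widths are illustrative, not necessarily the repository's actual architecture):

import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super().__init__()
        channels, height, width = input_shape
        # One real/fake score per image patch; the 2 ** 3 factor matches the
        # three stride-2 downsampling steps below and the removed patch computation.
        self.output_shape = (1, height // 2 ** 3, width // 2 ** 3)
        self.model = nn.Sequential(
            nn.Conv2d(channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 3, stride=1, padding=1),
        )

    def forward(self, img):
        return self.model(img)

With this shape bookkeeping inside the model, the training loop can size its adversarial targets from discriminator.output_shape instead of a hard-coded patch tuple, which is exactly what the later hunks switch to.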
@@ -67,28 +57,32 @@ def weights_init_normal(m):
discriminator.apply(weights_init_normal)

# Dataset loader
- transforms_ = [ transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
-                 transforms.ToTensor(),
-                 transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
- transforms_lr = [ transforms.Resize((opt.img_size // 4, opt.img_size // 4), Image.BICUBIC),
-                   transforms.ToTensor(),
-                   transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
- dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name,
-                                      transforms_x=transforms_, transforms_lr=transforms_lr),
-                         batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
+ transforms_ = [
+     transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+ transforms_lr = [
+     transforms.Resize((opt.img_size // 4, opt.img_size // 4), Image.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ]
+ dataloader = DataLoader(
+     ImageDataset("../../data/%s" % opt.dataset_name, transforms_x=transforms_, transforms_lr=transforms_lr),
+     batch_size=opt.batch_size,
+     shuffle=True,
+     num_workers=opt.n_cpu,
+ )

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

- # Adversarial ground truths
- valid = Variable(Tensor(np.ones(patch)), requires_grad=False)
- fake = Variable(Tensor(np.zeros(patch)), requires_grad=False)

def apply_random_mask(imgs):
-     idx = np.random.randint(0, opt.img_size - opt.mask_size, (imgs.shape[0], 2))
+     idx = np.random.randint(0, opt.img_size - opt.mask_size, (imgs.shape[0], 2))

    masked_imgs = imgs.clone()
    for i, (y1, x1) in enumerate(idx):
@@ -97,25 +91,26 @@ def apply_random_mask(imgs):

    return masked_imgs

+
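The body of the masking loop falls inside the region skipped by the hunk above. A minimal sketch of what it does, assuming each mask_size x mask_size square is simply overwritten with a constant placeholder value:

for i, (y1, x1) in enumerate(idx):
    # Bottom-right corner of the square cut-out
    y2, x2 = y1 + opt.mask_size, x1 + opt.mask_size
    # The fill value of 1 is an assumption; any constant placeholder works
    masked_imgs[i, :, y1:y2, x1:x2] = 1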
def save_sample(saved_samples):
    # Generate inpainted image
-     gen_imgs = generator(saved_samples['masked'], saved_samples['lowres'])
+     gen_imgs = generator(saved_samples["masked"], saved_samples["lowres"])
    # Save sample
-     sample = torch.cat((saved_samples['masked'].data, gen_imgs.data, saved_samples['imgs'].data), -2)
-     save_image(sample, 'images/%d.png' % batches_done, nrow=5, normalize=True)
+     sample = torch.cat((saved_samples["masked"].data, gen_imgs.data, saved_samples["imgs"].data), -2)
+     save_image(sample, "images/%d.png" % batches_done, nrow=5, normalize=True)


saved_samples = {}
for epoch in range(opt.n_epochs):
    for i, batch in enumerate(dataloader):
-         imgs = batch['x']
-         imgs_lr = batch['x_lr']
+         imgs = batch["x"]
+         imgs_lr = batch["x_lr"]

        masked_imgs = apply_random_mask(imgs)

        # Adversarial ground truths
-         valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
-         fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)
+         valid = Variable(Tensor(imgs.shape[0], *discriminator.output_shape).fill_(1.0), requires_grad=False)
+         fake = Variable(Tensor(imgs.shape[0], *discriminator.output_shape).fill_(0.0), requires_grad=False)

        if cuda:
            imgs = imgs.type(Tensor)
@@ -155,18 +150,20 @@ def save_sample(saved_samples):
        d_loss.backward()
        optimizer_D.step()

-         print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
-                                                                          d_loss.item(), g_loss.item()))
+         print(
+             "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
+             % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
+         )

        # Save first ten samples
        if not saved_samples:
-             saved_samples['imgs'] = real_imgs[:1].clone()
-             saved_samples['masked'] = masked_imgs[:1].clone()
-             saved_samples['lowres'] = imgs_lr[:1].clone()
-         elif saved_samples['imgs'].size(0) < 10:
-             saved_samples['imgs'] = torch.cat((saved_samples['imgs'], real_imgs[:1]), 0)
-             saved_samples['masked'] = torch.cat((saved_samples['masked'], masked_imgs[:1]), 0)
-             saved_samples['lowres'] = torch.cat((saved_samples['lowres'], imgs_lr[:1]), 0)
+             saved_samples["imgs"] = real_imgs[:1].clone()
+             saved_samples["masked"] = masked_imgs[:1].clone()
+             saved_samples["lowres"] = imgs_lr[:1].clone()
+         elif saved_samples["imgs"].size(0) < 10:
+             saved_samples["imgs"] = torch.cat((saved_samples["imgs"], real_imgs[:1]), 0)
+             saved_samples["masked"] = torch.cat((saved_samples["masked"], masked_imgs[:1]), 0)
+             saved_samples["lowres"] = torch.cat((saved_samples["lowres"], imgs_lr[:1]), 0)

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
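The generator and discriminator update steps sit between the hunks shown above, so only their tail (d_loss.backward(), optimizer_D.step()) is visible, and the final conditional is cut off. A minimal sketch of an LSGAN-style update built from the MSE criterion and the valid/fake targets defined earlier, assuming no additional pixel-wise term and that masked_imgs, imgs_lr, and real_imgs have already been moved to the right device in the if cuda: block:

        # ----- Generator update (sketch, not the commit's exact code) -----
        optimizer_G.zero_grad()
        # Inpaint from the masked image plus its low-resolution counterpart,
        # mirroring the call signature used in save_sample()
        gen_imgs = generator(masked_imgs, imgs_lr)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ----- Discriminator update (sketch) -----
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = 0.5 * (real_loss + fake_loss)
        d_loss.backward()
        optimizer_D.step()

        # The truncated conditional at the end of the diff presumably calls the sampler:
        if batches_done % opt.sample_interval == 0:
            save_sample(saved_samples)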