Commit ce0e5b5

Keep patchGAN output shape in discriminators
1 parent 03a9dde

File tree

13 files changed: +275 -231 lines changed

Diff for: implementations/bicyclegan/bicyclegan.py

+2 -2
@@ -46,13 +46,13 @@

 cuda = True if torch.cuda.is_available() else False

-img_shape = (opt.channels, opt.img_height, opt.img_width)
+input_shape = (opt.channels, opt.img_height, opt.img_width)

 # Loss functions
 mae_loss = torch.nn.L1Loss()

 # Initialize generator, encoder and discriminators
-generator = Generator(opt.latent_dim, img_shape)
+generator = Generator(opt.latent_dim, input_shape)
 encoder = Encoder(opt.latent_dim)
 D_VAE = MultiDiscriminator()
 D_LR = MultiDiscriminator()
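
The rename from img_shape to input_shape matches the other implementations touched by this commit, where each discriminator derives its PatchGAN output size from the input dimensions rather than from a module-level constant. A related trick, shown here as an illustrative sketch (not code from this diff): adversarial targets can be sized directly from the discriminator's output, which works for any downsampling depth.

import torch

# Illustrative helper, not part of the commit: build a target tensor with
# the same shape as whatever the PatchGAN discriminator produced.
def adversarial_target(d_out, real):
    return torch.full_like(d_out, 1.0 if real else 0.0)

d_out = torch.rand(4, 1, 16, 16)         # stand-in discriminator output
valid = adversarial_target(d_out, True)  # ones, shape (4, 1, 16, 16)
fake = adversarial_target(d_out, False)  # zeros, same shape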

Diff for: implementations/ccgan/ccgan.py

+52 -55
@@ -18,44 +18,34 @@
 import torch.nn.functional as F
 import torch

-os.makedirs('images', exist_ok=True)
+os.makedirs("images", exist_ok=True)

 parser = argparse.ArgumentParser()
-parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
-parser.add_argument('--batch_size', type=int, default=8, help='size of the batches')
-parser.add_argument('--dataset_name', type=str, default='img_align_celeba', help='name of the dataset')
-parser.add_argument('--lr', type=float, default=0.0002, help='adam: learning rate')
-parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
-parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of first order momentum of gradient')
-parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
-parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
-parser.add_argument('--img_size', type=int, default=128, help='size of each image dimension')
-parser.add_argument('--mask_size', type=int, default=32, help='size of random mask')
-parser.add_argument('--channels', type=int, default=3, help='number of image channels')
-parser.add_argument('--sample_interval', type=int, default=500, help='interval between image sampling')
+parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
+parser.add_argument("--batch_size", type=int, default=8, help="size of the batches")
+parser.add_argument("--dataset_name", type=str, default="img_align_celeba", help="name of the dataset")
+parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
+parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
+parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
+parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
+parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
+parser.add_argument("--img_size", type=int, default=128, help="size of each image dimension")
+parser.add_argument("--mask_size", type=int, default=32, help="size of random mask")
+parser.add_argument("--channels", type=int, default=3, help="number of image channels")
+parser.add_argument("--sample_interval", type=int, default=500, help="interval between image sampling")
 opt = parser.parse_args()
 print(opt)

 cuda = True if torch.cuda.is_available() else False

-# Calculate output of image discriminator (PatchGAN)
-patch_h, patch_w = int(opt.img_size / 2**3), int(opt.img_size / 2**3)
-patch = (1, patch_h, patch_w)
-
-def weights_init_normal(m):
-    classname = m.__class__.__name__
-    if classname.find('Conv') != -1:
-        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
-    elif classname.find('BatchNorm2d') != -1:
-        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
-        torch.nn.init.constant_(m.bias.data, 0.0)
+input_shape = (opt.channels, opt.img_size, opt.img_size)

 # Loss function
 adversarial_loss = torch.nn.MSELoss()

 # Initialize generator and discriminator
-generator = Generator(channels=opt.channels)
-discriminator = Discriminator(channels=opt.channels)
+generator = Generator(input_shape)
+discriminator = Discriminator(input_shape)

 if cuda:
     generator.cuda()
@@ -67,28 +57,32 @@ def weights_init_normal(m):
 discriminator.apply(weights_init_normal)

 # Dataset loader
-transforms_ = [ transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
-                transforms.ToTensor(),
-                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
-transforms_lr = [ transforms.Resize((opt.img_size//4, opt.img_size//4), Image.BICUBIC),
-                  transforms.ToTensor(),
-                  transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
-dataloader = DataLoader(ImageDataset("../../data/%s" % opt.dataset_name,
-                        transforms_x=transforms_, transforms_lr=transforms_lr),
-                        batch_size=opt.batch_size, shuffle=True, num_workers=opt.n_cpu)
+transforms_ = [
+    transforms.Resize((opt.img_size, opt.img_size), Image.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+]
+transforms_lr = [
+    transforms.Resize((opt.img_size // 4, opt.img_size // 4), Image.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+]
+dataloader = DataLoader(
+    ImageDataset("../../data/%s" % opt.dataset_name, transforms_x=transforms_, transforms_lr=transforms_lr),
+    batch_size=opt.batch_size,
+    shuffle=True,
+    num_workers=opt.n_cpu,
+)

 # Optimizers
 optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
 optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

 Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

-# Adversarial ground truths
-valid = Variable(Tensor(np.ones(patch)), requires_grad=False)
-fake = Variable(Tensor(np.zeros(patch)), requires_grad=False)

 def apply_random_mask(imgs):
-    idx = np.random.randint(0, opt.img_size-opt.mask_size, (imgs.shape[0], 2))
+    idx = np.random.randint(0, opt.img_size - opt.mask_size, (imgs.shape[0], 2))

     masked_imgs = imgs.clone()
     for i, (y1, x1) in enumerate(idx):
@@ -97,25 +91,26 @@ def apply_random_mask(imgs):

     return masked_imgs

+
 def save_sample(saved_samples):
     # Generate inpainted image
-    gen_imgs = generator(saved_samples['masked'], saved_samples['lowres'])
+    gen_imgs = generator(saved_samples["masked"], saved_samples["lowres"])
     # Save sample
-    sample = torch.cat((saved_samples['masked'].data, gen_imgs.data, saved_samples['imgs'].data), -2)
-    save_image(sample,'images/%d.png' % batches_done, nrow=5, normalize=True)
+    sample = torch.cat((saved_samples["masked"].data, gen_imgs.data, saved_samples["imgs"].data), -2)
+    save_image(sample, "images/%d.png" % batches_done, nrow=5, normalize=True)


 saved_samples = {}
 for epoch in range(opt.n_epochs):
     for i, batch in enumerate(dataloader):
-        imgs = batch['x']
-        imgs_lr = batch['x_lr']
+        imgs = batch["x"]
+        imgs_lr = batch["x_lr"]

         masked_imgs = apply_random_mask(imgs)

         # Adversarial ground truths
-        valid = Variable(Tensor(imgs.shape[0], *patch).fill_(1.0), requires_grad=False)
-        fake = Variable(Tensor(imgs.shape[0], *patch).fill_(0.0), requires_grad=False)
+        valid = Variable(Tensor(imgs.shape[0], *discriminator.output_shape).fill_(1.0), requires_grad=False)
+        fake = Variable(Tensor(imgs.shape[0], *discriminator.output_shape).fill_(0.0), requires_grad=False)

         if cuda:
             imgs = imgs.type(Tensor)
@@ -155,18 +150,20 @@ def save_sample(saved_samples):
         d_loss.backward()
         optimizer_D.step()

-        print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
-                                                                          d_loss.item(), g_loss.item()))
+        print(
+            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
+            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
+        )

         # Save first ten samples
         if not saved_samples:
-            saved_samples['imgs'] = real_imgs[:1].clone()
-            saved_samples['masked'] = masked_imgs[:1].clone()
-            saved_samples['lowres'] = imgs_lr[:1].clone()
-        elif saved_samples['imgs'].size(0) < 10:
-            saved_samples['imgs'] = torch.cat((saved_samples['imgs'], real_imgs[:1]), 0)
-            saved_samples['masked'] = torch.cat((saved_samples['masked'], masked_imgs[:1]), 0)
-            saved_samples['lowres'] = torch.cat((saved_samples['lowres'], imgs_lr[:1]), 0)
+            saved_samples["imgs"] = real_imgs[:1].clone()
+            saved_samples["masked"] = masked_imgs[:1].clone()
+            saved_samples["lowres"] = imgs_lr[:1].clone()
+        elif saved_samples["imgs"].size(0) < 10:
+            saved_samples["imgs"] = torch.cat((saved_samples["imgs"], real_imgs[:1]), 0)
+            saved_samples["masked"] = torch.cat((saved_samples["masked"], masked_imgs[:1]), 0)
+            saved_samples["lowres"] = torch.cat((saved_samples["lowres"], imgs_lr[:1]), 0)

         batches_done = epoch * len(dataloader) + i
         if batches_done % opt.sample_interval == 0:
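
A quick sanity check of the new output_shape attribute (a minimal sketch, assuming the default 128x128 RGB configuration and the Discriminator from this implementation's models.py, diffed below): the three stride-2 blocks downsample by 2 ** 3, and if the final convolution is stride 1, the stored shape should agree with a real forward pass.

import torch
from models import Discriminator  # assumes the ccgan models module is on the path

input_shape = (3, 128, 128)  # (opt.channels, opt.img_size, opt.img_size)
discriminator = Discriminator(input_shape)

out = discriminator(torch.randn(2, *input_shape))
# 128 / 2 ** 3 = 16, so both shapes should be (1, 16, 16)
assert tuple(out.shape[1:]) == discriminator.output_shape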

Diff for: implementations/ccgan/models.py

+21 -17
@@ -6,6 +6,7 @@
 # U-NET
 ##############################

+
 class UNetDown(nn.Module):
     def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
         super(UNetDown, self).__init__()
@@ -21,12 +22,15 @@ def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
     def forward(self, x):
         return self.model(x)

+
 class UNetUp(nn.Module):
     def __init__(self, in_size, out_size, dropout=0.0):
         super(UNetUp, self).__init__()
-        model = [ nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, bias=False),
-                  nn.BatchNorm2d(out_size, 0.8),
-                  nn.ReLU(inplace=True)]
+        model = [
+            nn.ConvTranspose2d(in_size, out_size, 4, stride=2, padding=1, bias=False),
+            nn.BatchNorm2d(out_size, 0.8),
+            nn.ReLU(inplace=True),
+        ]
         if dropout:
             model.append(nn.Dropout(dropout))

@@ -35,16 +39,16 @@
     def forward(self, x, skip_input):
         x = self.model(x)
         out = torch.cat((x, skip_input), 1)
-        #out = torch.add(x, skip_input)
         return out

+
 class Generator(nn.Module):
-    def __init__(self, channels=3):
+    def __init__(self, input_shape):
         super(Generator, self).__init__()
-
+        channels, _, _ = input_shape
         self.down1 = UNetDown(channels, 64, normalize=False)
         self.down2 = UNetDown(64, 128)
-        self.down3 = UNetDown(128+channels, 256, dropout=0.5)
+        self.down3 = UNetDown(128 + channels, 256, dropout=0.5)
         self.down4 = UNetDown(256, 512, dropout=0.5)
         self.down5 = UNetDown(512, 512, dropout=0.5)
         self.down6 = UNetDown(512, 512, dropout=0.5)
@@ -53,12 +57,9 @@ def __init__(self, channels=3):
         self.up2 = UNetUp(1024, 512, dropout=0.5)
         self.up3 = UNetUp(1024, 256, dropout=0.5)
         self.up4 = UNetUp(512, 128)
-        self.up5 = UNetUp(256+channels, 64)
+        self.up5 = UNetUp(256 + channels, 64)

-
-        final = [ nn.Upsample(scale_factor=2),
-                  nn.Conv2d(128, channels, 3, 1, 1),
-                  nn.Tanh() ]
+        final = [nn.Upsample(scale_factor=2), nn.Conv2d(128, channels, 3, 1, 1), nn.Tanh()]
         self.final = nn.Sequential(*final)

     def forward(self, x, x_lr):
@@ -78,10 +79,16 @@ def forward(self, x, x_lr):

         return self.final(u5)

+
 class Discriminator(nn.Module):
-    def __init__(self, channels=3):
+    def __init__(self, input_shape):
         super(Discriminator, self).__init__()

+        channels, height, width = input_shape
+        # Calculate output of image discriminator (PatchGAN)
+        patch_h, patch_w = int(height / 2 ** 3), int(width / 2 ** 3)
+        self.output_shape = (1, patch_h, patch_w)
+
         def discriminator_block(in_filters, out_filters, stride, normalize):
             """Returns layers of each discriminator block"""
             layers = [nn.Conv2d(in_filters, out_filters, 3, stride, 1)]
@@ -92,10 +99,7 @@ def discriminator_block(in_filters, out_filters, stride, normalize):

         layers = []
         in_filters = channels
-        for out_filters, stride, normalize in [ (64, 2, False),
-                                                (128, 2, True),
-                                                (256, 2, True),
-                                                (512, 1, True)]:
+        for out_filters, stride, normalize in [(64, 2, False), (128, 2, True), (256, 2, True), (512, 1, True)]:
             layers.extend(discriminator_block(in_filters, out_filters, stride, normalize))
             in_filters = out_filters

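For reference, the arithmetic now living in Discriminator.__init__, restated as a hypothetical standalone helper: each stride-2 convolution halves height and width, so the three stride-2 blocks divide both by 2 ** 3, while the final stride-1 block leaves them unchanged.

def patchgan_output_shape(input_shape, n_strided=3):
    # One PatchGAN logit per (2 ** n_strided)-pixel square of the input
    _, height, width = input_shape
    return (1, height // 2 ** n_strided, width // 2 ** n_strided)

assert patchgan_output_shape((3, 128, 128)) == (1, 16, 16)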
Diff for: implementations/cgan/cgan.py

-2
@@ -93,7 +93,6 @@ def forward(self, img, labels):

 # Loss functions
 adversarial_loss = torch.nn.MSELoss()
-auxiliary_loss = torch.nn.CrossEntropyLoss()

 # Initialize generator and discriminator
 generator = Generator()
@@ -103,7 +102,6 @@ def forward(self, img, labels):
     generator.cuda()
     discriminator.cuda()
     adversarial_loss.cuda()
-    auxiliary_loss.cuda()

 # Configure data loader
 os.makedirs("../../data/mnist", exist_ok=True)
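
CGAN conditions its networks through label embeddings fed into the generator and discriminator, so a classification term never enters the objective; the removed CrossEntropyLoss belongs to ACGAN's auxiliary classifier and was unused here. A minimal sketch of the loss that remains, with stand-in tensors for the discriminator output and target:

import torch

adversarial_loss = torch.nn.MSELoss()  # the LSGAN-style criterion kept above

validity = torch.rand(8, 1)  # stand-in for discriminator(gen_imgs, gen_labels)
valid = torch.ones(8, 1)     # target for "real"
g_loss = adversarial_loss(validity, valid)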
