Skip to content

Commit 72ba807

Browse files
committed
Add a new GAN stability measure (a zero-centered gradient penalty applied to fake images as well as real ones), from researchers at Cornell University and Brown University
1 parent 5ff7b57 commit 72ba807

File tree

3 files changed

+30
-10
lines changed

3 files changed

+30
-10
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,3 +490,12 @@ Thank you to Matthew Mann for his inspiring [simple port](https://github.com/man
490490
primaryClass = {cs.CV}
491491
}
492492
```
493+
494+
```bibtex
495+
@inproceedings{Huang2025TheGI,
496+
title = {The GAN is dead; long live the GAN! A Modern GAN Baseline},
497+
author = {Yiwen Huang and Aaron Gokaslan and Volodymyr Kuleshov and James Tompkin},
498+
year = {2025},
499+
url = {https://api.semanticscholar.org/CorpusID:275405495}
500+
}
501+
```

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,14 @@
2626
],
2727
install_requires=[
2828
'aim',
29-
'einops>=0.7.0',
29+
'einops>=0.8.0',
3030
'contrastive_learner>=0.1.0',
3131
'fire',
3232
'kornia>=0.5.4',
3333
'numpy',
3434
'retry',
3535
'tqdm',
36-
'torch',
36+
'torch>=2.2',
3737
'torchvision',
3838
'pillow',
3939
'vector-quantize-pytorch==0.1.0'

stylegan2_pytorch/stylegan2_pytorch.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,14 @@ def loss_backwards(fp16, loss, optimizer, loss_id, **kwargs):
232232
else:
233233
loss.backward(**kwargs)
234234

235-
def gradient_penalty(images, output, weight = 10):
235+
def gradient_penalty(images, output, weight = 10, center = 0.):
236236
batch_size = images.shape[0]
237237
gradients = torch_grad(outputs=output, inputs=images,
238238
grad_outputs=torch.ones(output.size(), device=images.device),
239239
create_graph=True, retain_graph=True, only_inputs=True)[0]
240240

241241
gradients = gradients.reshape(batch_size, -1)
242-
return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
242+
return weight * ((gradients.norm(2, dim=1) - center) ** 2).mean()
243243

244244
def calc_pl_lengths(styles, images):
245245
device = images.device
@@ -396,15 +396,23 @@ def __init__(self, D, image_size):
396396
super().__init__()
397397
self.D = D
398398

399-
def forward(self, images, prob = 0., types = [], detach = False):
399+
def forward(self, images, prob = 0., types = [], detach = False, return_aug_images = False, input_requires_grad = False):
400400
if random() < prob:
401401
images = random_hflip(images, prob=0.5)
402402
images = DiffAugment(images, types=types)
403403

404404
if detach:
405405
images = images.detach()
406406

407-
return self.D(images)
407+
if input_requires_grad:
408+
images.requires_grad_()
409+
410+
logits = self.D(images)
411+
412+
if not return_aug_images:
413+
return logits
414+
415+
return images, logits
408416

409417
# stylegan2 classes
410418

@@ -1030,10 +1038,13 @@ def train(self):
10301038
w_styles = styles_def_to_tensor(w_space)
10311039

10321040
generated_images = G(w_styles, noise)
1033-
fake_output, fake_q_loss = D_aug(generated_images.clone().detach(), detach = True, **aug_kwargs)
1041+
generated_images, (fake_output, fake_q_loss) = D_aug(generated_images.clone().detach(), return_aug_images = True, input_requires_grad = apply_gradient_penalty, detach = True, **aug_kwargs)
10341042

10351043
image_batch = next(self.loader).cuda(self.rank)
1036-
image_batch.requires_grad_()
1044+
1045+
if apply_gradient_penalty:
1046+
image_batch.requires_grad_()
1047+
10371048
real_output, real_q_loss = D_aug(image_batch, **aug_kwargs)
10381049

10391050
real_output_loss = real_output
@@ -1053,7 +1064,7 @@ def train(self):
10531064
disc_loss = disc_loss + quantize_loss
10541065

10551066
if apply_gradient_penalty:
1056-
gp = gradient_penalty(image_batch, real_output)
1067+
gp = gradient_penalty(image_batch, real_output) + gradient_penalty(generated_images, fake_output)
10571068
self.last_gp_loss = gp.clone().detach().item()
10581069
self.track(self.last_gp_loss, 'GP')
10591070
disc_loss = disc_loss + gp
@@ -1382,7 +1393,7 @@ def load(self, num = -1):
13821393

13831394
self.steps = name * self.save_every
13841395

1385-
load_data = torch.load(self.model_name(name))
1396+
load_data = torch.load(self.model_name(name), weights_only = True)
13861397

13871398
if 'version' in load_data:
13881399
print(f"loading from version {load_data['version']}")

0 commit comments

Comments
 (0)