【Hackathon 8th No.23】Improved Training of Wasserstein GANs (Paper Reproduction) #1147

Open
wants to merge 47 commits into base: develop

Commits (47, all by robinbg):
39a81ae  Create __init__.py (Apr 27, 2025)
7e3f417  Create __init__.py (Apr 27, 2025)
0e508e5  Add files via upload (Apr 27, 2025)
1d8dcbb  Create __init__.py (Apr 27, 2025)
6c34944  Add files via upload (Apr 27, 2025)
28165b2  Create wgan_gp_toy.py (Apr 27, 2025)
82bc050  Add files via upload (Apr 27, 2025)
4f8f051  Update wgan_gp_cifar.py (Apr 27, 2025)
f46c359  Update wgan_gp_mnist.py (Apr 27, 2025)
af98e26  Update wgan_gp_toy.py (Apr 27, 2025)
877d02c  Update base_gan.py (Apr 27, 2025)
56f4fbd  Update wgan_gp_cifar.py (Apr 27, 2025)
479f23f  Update wgan_gp_mnist.py (Apr 27, 2025)
06f9823  Update wgan_gp_mnist.py (Apr 27, 2025)
4f90065  Update wgan_gp_toy.py (Apr 27, 2025)
9722961  Update base_gan.py (Apr 27, 2025)
ff4ec81  Update wgan.py (Apr 27, 2025)
cc6373c  Update wgan_gp.py (Apr 27, 2025)
623b601  Update losses.py (Apr 27, 2025)
7fbdfe2  Update metrics.py (Apr 27, 2025)
d3f6980  Update metrics.py (Apr 27, 2025)
649829d  Update wgan_gp_cifar.py (Apr 27, 2025)
e7bf531  Update wgan_gp_mnist.py (Apr 27, 2025)
f6a2655  Update wgan_gp_toy.py (Apr 27, 2025)
54d6c3c  Update wgan_gp_mnist.py (Apr 27, 2025)
dd67773  Update wgan.py (Apr 27, 2025)
65d1596  Update wgan_gp.py (Apr 27, 2025)
0d7b0b0  Update wgan_gp_cifar.py (Apr 27, 2025)
fa44634  Update wgan_gp_mnist.py (Apr 27, 2025)
581303a  Update wgan_gp_toy.py (Apr 27, 2025)
497450c  Update wgan_gp_mnist.py (Apr 27, 2025)
2c7595a  Update wgan_gp_cifar.py (Apr 27, 2025)
20e9df6  Update wgan_gp_mnist.py (Apr 27, 2025)
ae0855d  Update wgan_gp_toy.py (Apr 27, 2025)
92c8243  Update wgan_gp_cifar.py (Apr 27, 2025)
255e4d5  Update wgan_gp_mnist.py (Apr 27, 2025)
f1aad59  Update wgan_gp_mnist.py (Apr 27, 2025)
077b1ed  Update wgan_gp_toy.py (Apr 27, 2025)
bd7418f  Update wgan_gp_mnist.py (Apr 27, 2025)
3532bd3  Fix code style issues in WGAN-GP implementation (Apr 28, 2025)
a77787a  Fix f-string syntax errors in wgan_gp_cifar.py and wgan_gp_toy.py (Apr 28, 2025)
5c04b13  Fix code formatting issues with black to comply with PEP 8 style guide (Apr 28, 2025)
b2ab193  Update wgan_gp_cifar.py (Apr 30, 2025)
0bf3d49  Update wgan_gp_mnist.py (Apr 30, 2025)
21371a4  Update wgan_gp_cifar.py (May 12, 2025)
a21b565  Update wgan_gp_mnist.py (May 12, 2025)
a115907  Update wgan_gp.py (May 12, 2025)
5 changes: 5 additions & 0 deletions examples/wgan_gp/__init__.py
@@ -0,0 +1,5 @@
"""
WGAN-GP implementation for PaddleScience.

This module provides implementation of Wasserstein GAN with Gradient Penalty.
"""
177 changes: 177 additions & 0 deletions examples/wgan_gp/cases/wgan_gp_cifar.py
@@ -0,0 +1,177 @@
import os
import sys

import matplotlib.pyplot as plt
import paddle
import paddle.nn as nn
import paddle.vision.transforms as transforms

# Make the sibling `models` package importable when running this script directly.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)
from models.wgan_gp import WGAN_GP


class CIFAR10Generator(nn.Layer):
    """
    Generator network for CIFAR-10 dataset.
    """

    def __init__(self, noise_dim=100, output_channels=3):
        super(CIFAR10Generator, self).__init__()

        self.layers1 = nn.Sequential(
            nn.Linear(noise_dim, 512 * 4 * 4),
            nn.BatchNorm1D(512 * 4 * 4),
            nn.ReLU(),
        )
        self.layers2 = nn.Sequential(
            nn.Conv2DTranspose(512, 256, 4, 2, 1),
            nn.BatchNorm2D(256),
            nn.ReLU(),
            nn.Conv2DTranspose(256, 128, 4, 2, 1),
            nn.BatchNorm2D(128),
            nn.ReLU(),
            nn.Conv2DTranspose(128, output_channels, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.layers1(x)
        x = x.reshape([-1, 512, 4, 4])
        x = self.layers2(x)
        return x


class CIFAR10Discriminator(nn.Layer):
    """
    Discriminator network for CIFAR-10 dataset.
    """

    def __init__(self, input_channels=3):
        super(CIFAR10Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2D(input_channels, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2D(128, 256, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2D(256, 512, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(512 * 4 * 4, 1),
        )

    def forward(self, x):
        return self.model(x)


def main():
    """
    Main function to train WGAN-GP on CIFAR-10 dataset.
    """
    output_dir = "output/cifar10"
    os.makedirs(output_dir, exist_ok=True)

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    train_dataset = paddle.vision.datasets.Cifar10(
        mode="train",
        transform=transform,
        download=True,
    )

    generator = CIFAR10Generator(noise_dim=100, output_channels=3)
    discriminator = CIFAR10Discriminator(input_channels=3)

    wgan_gp = WGAN_GP(
        generator=generator,
        discriminator=discriminator,
        lambda_gp=10.0,
        critic_iters=5,
    )

    data_loader = paddle.io.DataLoader(
        train_dataset,
        batch_size=64,
        shuffle=True,
    )

    g_optimizer = paddle.optimizer.Adam(
        parameters=generator.parameters(),
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9,
    )

    d_optimizer = paddle.optimizer.Adam(
        parameters=discriminator.parameters(),
        learning_rate=1e-4,
        beta1=0.5,
        beta2=0.9,
    )

    history = {
        "g_loss": [],
        "d_loss": [],
    }

    iterations = 50000
    save_interval = 5000
    data_loader_iter = iter(data_loader)

    for iteration in range(iterations):
        try:
            real_data = next(data_loader_iter)
            if isinstance(real_data, (list, tuple)):
                # Extract images from (images, labels) tuple
                real_data = real_data[0]
        except StopIteration:
            data_loader_iter = iter(data_loader)
            real_data = next(data_loader_iter)
            if isinstance(real_data, (list, tuple)):
                # Extract images from (images, labels) tuple
                real_data = real_data[0]

        step_results = wgan_gp.train_step(real_data, g_optimizer, d_optimizer)

        history["g_loss"].append(step_results["g_loss"])
        history["d_loss"].append(step_results["d_loss"])

        if iteration % 100 == 0:
            print(
                f"Iteration {iteration}: g_loss = {step_results['g_loss']:.4f}, d_loss = {step_results['d_loss']:.4f}"
            )

        if iteration % save_interval == 0 or iteration == iterations - 1:
            with paddle.no_grad():
                samples = wgan_gp.generate(16)

            from utils.visualization import save_image_grid

            save_image_grid(samples, f"{output_dir}/samples_{iteration}.png")

            paddle.save(
                generator.state_dict(), f"{output_dir}/generator_{iteration}.pdparams"
            )
            paddle.save(
                discriminator.state_dict(),
                f"{output_dir}/discriminator_{iteration}.pdparams",
            )

    plt.figure(figsize=(10, 5))
    plt.plot(history["g_loss"], label="Generator Loss")
    plt.plot(history["d_loss"], label="Discriminator Loss")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{output_dir}/loss_curves.png")
    plt.close()


if __name__ == "__main__":
    main()
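
The script above drives optimization through wgan_gp.train_step(real_data, g_optimizer, d_optimizer), whose body lives in models/wgan_gp.py and is not shown in this diff. The following free-standing sketch shows what one such step typically looks like under the paper's recipe, i.e. critic_iters critic updates followed by a single generator update, reusing the gradient_penalty sketch above; the parameter noise_dim and all internal names are assumptions rather than the PR's actual implementation.

# Illustrative sketch only; the real WGAN_GP.train_step may differ in detail.
# Reuses the gradient_penalty sketch shown earlier.
import paddle


def wgan_gp_train_step(
    generator,
    discriminator,
    real_data,
    g_optimizer,
    d_optimizer,
    noise_dim=100,
    critic_iters=5,
    lambda_gp=10.0,
):
    batch_size = real_data.shape[0]

    # 1) Train the critic for `critic_iters` steps on fresh fake samples.
    for _ in range(critic_iters):
        noise = paddle.randn([batch_size, noise_dim])
        fake_data = generator(noise).detach()

        d_loss = (
            paddle.mean(discriminator(fake_data))
            - paddle.mean(discriminator(real_data))
            + gradient_penalty(discriminator, real_data, fake_data, lambda_gp)
        )
        d_loss.backward()
        d_optimizer.step()
        d_optimizer.clear_grad()

    # 2) One generator step: push the critic's scores on generated samples up.
    noise = paddle.randn([batch_size, noise_dim])
    g_loss = -paddle.mean(discriminator(generator(noise)))
    g_loss.backward()
    g_optimizer.step()
    g_optimizer.clear_grad()

    return {"g_loss": float(g_loss), "d_loss": float(d_loss)}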
119 changes: 119 additions & 0 deletions examples/wgan_gp/cases/wgan_gp_mnist.py
@@ -0,0 +1,119 @@
import os
import sys

import matplotlib.pyplot as plt
import paddle
import paddle.nn as nn
import paddle.vision.transforms as transforms

# Make the sibling `models` package importable when running this script directly.
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)
from models.wgan_gp import WGAN_GP


class MNISTGenerator(nn.Layer):
"""
Generator network for MNIST dataset.
"""
def __init__(self, noise_dim=100, output_channels=3):
super(CIFAR10Generator, self).__init__()

self.layers1 = nn.Sequential(
nn.Linear(noise_dim, 128 * 7 * 7),
nn.BatchNorm1D(128 * 7 * 7),
nn.ReLU(),
)
self.layers2 = nn.Sequential(
nn.Conv2DTranspose(128, 64, 4, 2, 1),
nn.BatchNorm2D(64),
nn.ReLU(),
nn.Conv2DTranspose(64, output_channels, 4, 2, 1),
nn.Tanh(),
)

def forward(self, x):
x = self.layers1(x)
x = x.reshape([-1, 128, 7, 7])
x = self.layers2(x)
return x



class MNISTDiscriminator(nn.Layer):
    """
    Discriminator network for MNIST dataset.
    """

    def __init__(self, input_channels=1):
        super(MNISTDiscriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Conv2D(input_channels, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2D(64, 128, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
        )

    def forward(self, x):
        return self.model(x)


def main():
    """
    Main function to train WGAN-GP on MNIST dataset.
    """
    output_dir = "output/mnist"
    os.makedirs(output_dir, exist_ok=True)

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    train_dataset = paddle.vision.datasets.MNIST(
        mode="train",
        transform=transform,
        download=True,
    )

    generator = MNISTGenerator(noise_dim=100, output_channels=1)
    discriminator = MNISTDiscriminator(input_channels=1)

    wgan_gp = WGAN_GP(
        generator=generator,
        discriminator=discriminator,
        lambda_gp=10.0,
        critic_iters=5,
    )

    history = wgan_gp.train(
        train_dataset,
        batch_size=64,
        iterations=20000,
        g_learning_rate=1e-4,
        d_learning_rate=1e-4,
        save_interval=1000,
        save_path=output_dir,
    )

    plt.figure(figsize=(10, 5))
    plt.plot(history["g_loss"], label="Generator Loss")
    plt.plot(history["d_loss"], label="Discriminator Loss")
    plt.xlabel("Iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True)
    plt.savefig(f"{output_dir}/loss_curves.png")
    plt.close()

    with paddle.no_grad():
        samples = wgan_gp.generate(16)

    from utils.visualization import save_image_grid

    save_image_grid(samples, f"{output_dir}/final_samples.png")


if __name__ == "__main__":
    main()
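
Unlike the CIFAR-10 case, which writes its own training loop, this script delegates to wgan_gp.train(...), again defined in models/wgan_gp.py outside this excerpt. Judging from the arguments passed here and from the manual loop in wgan_gp_cifar.py, a wrapper of roughly the following shape would be consistent with both scripts; the attribute names wgan_gp.generator / wgan_gp.discriminator and everything else below are assumptions, not the PR's code.

# Rough sketch of a train() convenience wrapper, assuming the attributes named
# below exist on the WGAN_GP object; the actual method may differ.
import paddle


def train(
    wgan_gp,
    dataset,
    batch_size=64,
    iterations=20000,
    g_learning_rate=1e-4,
    d_learning_rate=1e-4,
    save_interval=1000,
    save_path="output",
):
    loader = paddle.io.DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Same Adam hyperparameters as the CIFAR-10 case (beta1=0.5, beta2=0.9).
    g_opt = paddle.optimizer.Adam(
        parameters=wgan_gp.generator.parameters(),
        learning_rate=g_learning_rate,
        beta1=0.5,
        beta2=0.9,
    )
    d_opt = paddle.optimizer.Adam(
        parameters=wgan_gp.discriminator.parameters(),
        learning_rate=d_learning_rate,
        beta1=0.5,
        beta2=0.9,
    )

    history = {"g_loss": [], "d_loss": []}
    data_iter = iter(loader)
    for it in range(iterations):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)
        # DataLoader yields (images, labels); keep only the images.
        real = batch[0] if isinstance(batch, (list, tuple)) else batch

        results = wgan_gp.train_step(real, g_opt, d_opt)
        history["g_loss"].append(results["g_loss"])
        history["d_loss"].append(results["d_loss"])

        if it % save_interval == 0 or it == iterations - 1:
            paddle.save(
                wgan_gp.generator.state_dict(),
                f"{save_path}/generator_{it}.pdparams",
            )
    return history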