diff --git a/examples/wgan_gp/__init__.py b/examples/wgan_gp/__init__.py
new file mode 100644
index 000000000..3b3749607
--- /dev/null
+++ b/examples/wgan_gp/__init__.py
@@ -0,0 +1,5 @@
+"""
+WGAN-GP implementation for PaddleScience.
+
+This module provides an implementation of Wasserstein GAN with Gradient Penalty.
+"""
diff --git a/examples/wgan_gp/cases/wgan_gp_cifar.py b/examples/wgan_gp/cases/wgan_gp_cifar.py
new file mode 100644
index 000000000..931bec6e5
--- /dev/null
+++ b/examples/wgan_gp/cases/wgan_gp_cifar.py
@@ -0,0 +1,177 @@
+import os
+import sys
+
+import matplotlib.pyplot as plt
+import paddle
+import paddle.nn as nn
+import paddle.vision.transforms as transforms
+
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_DIR)
+from models.wgan_gp import WGAN_GP
+
+
+class CIFAR10Generator(nn.Layer):
+    """
+    Generator network for CIFAR-10 dataset.
+    """
+
+    def __init__(self, noise_dim=100, output_channels=3):
+        super(CIFAR10Generator, self).__init__()
+        self.noise_dim = noise_dim
+
+        self.layers1 = nn.Sequential(
+            nn.Linear(noise_dim, 512 * 4 * 4),
+            nn.BatchNorm1D(512 * 4 * 4),
+            nn.ReLU(),
+        )
+        self.layers2 = nn.Sequential(
+            nn.Conv2DTranspose(512, 256, 4, 2, 1),
+            nn.BatchNorm2D(256),
+            nn.ReLU(),
+            nn.Conv2DTranspose(256, 128, 4, 2, 1),
+            nn.BatchNorm2D(128),
+            nn.ReLU(),
+            nn.Conv2DTranspose(128, output_channels, 4, 2, 1),
+            nn.Tanh(),
+        )
+
+    def forward(self, x):
+        x = self.layers1(x)
+        x = x.reshape([-1, 512, 4, 4])
+        x = self.layers2(x)
+        return x
+
+
+class CIFAR10Discriminator(nn.Layer):
+    """
+    Discriminator network for CIFAR-10 dataset.
+    """
+
+    def __init__(self, input_channels=3):
+        super(CIFAR10Discriminator, self).__init__()
+
+        self.model = nn.Sequential(
+            nn.Conv2D(input_channels, 128, 4, 2, 1),
+            nn.LeakyReLU(0.2),
+            nn.Conv2D(128, 256, 4, 2, 1),
+            nn.LeakyReLU(0.2),
+            nn.Conv2D(256, 512, 4, 2, 1),
+            nn.LeakyReLU(0.2),
+            nn.Flatten(),
+            nn.Linear(512 * 4 * 4, 1),
+        )
+
+    def forward(self, x):
+        return self.model(x)
+
+
+def main():
+    """
+    Main function to train WGAN-GP on the CIFAR-10 dataset.
+    """
+    output_dir = "output/cifar10"
+    os.makedirs(output_dir, exist_ok=True)
+
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+        ]
+    )
+
+    train_dataset = paddle.vision.datasets.Cifar10(
+        mode="train",
+        transform=transform,
+        download=True,
+    )
+
+    generator = CIFAR10Generator(noise_dim=100, output_channels=3)
+    discriminator = CIFAR10Discriminator(input_channels=3)
+
+    wgan_gp = WGAN_GP(
+        generator=generator,
+        discriminator=discriminator,
+        lambda_gp=10.0,
+        critic_iters=5,
+    )
+
+    data_loader = paddle.io.DataLoader(
+        train_dataset,
+        batch_size=64,
+        shuffle=True,
+    )
+
+    g_optimizer = paddle.optimizer.Adam(
+        parameters=generator.parameters(),
+        learning_rate=1e-4,
+        beta1=0.5,
+        beta2=0.9,
+    )
+
+    d_optimizer = paddle.optimizer.Adam(
+        parameters=discriminator.parameters(),
+        learning_rate=1e-4,
+        beta1=0.5,
+        beta2=0.9,
+    )
+
+    history = {
+        "g_loss": [],
+        "d_loss": [],
+    }
+
+    iterations = 50000
+    save_interval = 5000
+    data_loader_iter = iter(data_loader)
+
+    for iteration in range(iterations):
+        try:
+            real_data = next(data_loader_iter)
+        except StopIteration:
+            data_loader_iter = iter(data_loader)
+            real_data = next(data_loader_iter)
+        if isinstance(real_data, (list, tuple)):
+            # Extract images from the (images, labels) batch
+            real_data = real_data[0]
+
+        step_results = wgan_gp.train_step(real_data, g_optimizer, d_optimizer)
+
+        history["g_loss"].append(step_results["g_loss"])
+        history["d_loss"].append(step_results["d_loss"])
+
+        if iteration % 100 == 0:
+            print(
+                f"Iteration {iteration}: g_loss = {step_results['g_loss']:.4f}, "
+                f"d_loss = {step_results['d_loss']:.4f}"
+            )
+
+        if iteration % save_interval == 0 or iteration == iterations - 1:
+            with paddle.no_grad():
+                samples = wgan_gp.generate(16)
+
+            from utils.visualization import save_image_grid
+
+            save_image_grid(samples, f"{output_dir}/samples_{iteration}.png")
+
+            paddle.save(
+                generator.state_dict(), f"{output_dir}/generator_{iteration}.pdparams"
+            )
+            paddle.save(
+                discriminator.state_dict(),
+                f"{output_dir}/discriminator_{iteration}.pdparams",
+            )
+
+    plt.figure(figsize=(10, 5))
+    plt.plot(history["g_loss"], label="Generator Loss")
+    plt.plot(history["d_loss"], label="Discriminator Loss")
+    plt.xlabel("Iterations")
+    plt.ylabel("Loss")
+    plt.legend()
+    plt.grid(True)
+    plt.savefig(f"{output_dir}/loss_curves.png")
+    plt.close()
+
+
+if __name__ == "__main__":
+    main()
+ """ + output_dir = "output/cifar10" + os.makedirs(output_dir, exist_ok=True) + + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + ] + ) + + train_dataset = paddle.vision.datasets.Cifar10( + mode="train", + transform=transform, + download=True, + ) + + generator = CIFAR10Generator(noise_dim=100, output_channels=3) + discriminator = CIFAR10Discriminator(input_channels=3) + + wgan_gp = WGAN_GP( + generator=generator, + discriminator=discriminator, + lambda_gp=10.0, + critic_iters=5, + ) + + data_loader = paddle.io.DataLoader( + train_dataset, + batch_size=64, + shuffle=True, + ) + + g_optimizer = paddle.optimizer.Adam( + parameters=generator.parameters(), + learning_rate=1e-4, + beta1=0.5, + beta2=0.9, + ) + + d_optimizer = paddle.optimizer.Adam( + parameters=discriminator.parameters(), + learning_rate=1e-4, + beta1=0.5, + beta2=0.9, + ) + + history = { + "g_loss": [], + "d_loss": [], + } + + iterations = 50000 + save_interval = 5000 + data_loader_iter = iter(data_loader) + + for iteration in range(iterations): + try: + real_data = next(data_loader_iter) + if isinstance(real_data, (list, tuple)): + # Extract images from (images, labels) tuple + real_data = real_data[0] + except StopIteration: + data_loader_iter = iter(data_loader) + real_data = next(data_loader_iter) + if isinstance(real_data, (list, tuple)): + # Extract images from (images, labels) tuple + real_data = real_data[0] + + step_results = wgan_gp.train_step(real_data, g_optimizer, d_optimizer) + + history["g_loss"].append(step_results["g_loss"]) + history["d_loss"].append(step_results["d_loss"]) + + if iteration % 100 == 0: + print( + f"Iteration {iteration}: g_loss = {step_results['g_loss']:.4f}, d_loss = {step_results['d_loss']:.4f}" + ) + + if iteration % save_interval == 0 or iteration == iterations - 1: + with paddle.no_grad(): + samples = wgan_gp.generate(16) + + from utils.visualization import save_image_grid + + save_image_grid(samples, f"{output_dir}/samples_{iteration}.png") + + paddle.save( + generator.state_dict(), f"{output_dir}/generator_{iteration}.pdparams" + ) + paddle.save( + discriminator.state_dict(), + f"{output_dir}/discriminator_{iteration}.pdparams", + ) + + plt.figure(figsize=(10, 5)) + plt.plot(history["g_loss"], label="Generator Loss") + plt.plot(history["d_loss"], label="Discriminator Loss") + plt.xlabel("Iterations") + plt.ylabel("Loss") + plt.legend() + plt.grid(True) + plt.savefig(f"{output_dir}/loss_curves.png") + plt.close() + + +if __name__ == "__main__": + main() diff --git a/examples/wgan_gp/cases/wgan_gp_mnist.py b/examples/wgan_gp/cases/wgan_gp_mnist.py new file mode 100644 index 000000000..8ebe00dd2 --- /dev/null +++ b/examples/wgan_gp/cases/wgan_gp_mnist.py @@ -0,0 +1,119 @@ +import os + +import matplotlib.pyplot as plt +import paddle +import paddle.nn as nn +import paddle.vision.transforms as transforms +ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(ROOT_DIR) +from models.wgan_gp import WGAN_GP + +class MNISTGenerator(nn.Layer): + """ + Generator network for MNIST dataset. 
+ """ + def __init__(self, noise_dim=100, output_channels=3): + super(CIFAR10Generator, self).__init__() + + self.layers1 = nn.Sequential( + nn.Linear(noise_dim, 128 * 7 * 7), + nn.BatchNorm1D(128 * 7 * 7), + nn.ReLU(), + ) + self.layers2 = nn.Sequential( + nn.Conv2DTranspose(128, 64, 4, 2, 1), + nn.BatchNorm2D(64), + nn.ReLU(), + nn.Conv2DTranspose(64, output_channels, 4, 2, 1), + nn.Tanh(), + ) + + def forward(self, x): + x = self.layers1(x) + x = x.reshape([-1, 128, 7, 7]) + x = self.layers2(x) + return x + + + +class MNISTDiscriminator(nn.Layer): + """ + Discriminator network for MNIST dataset. + """ + + def __init__(self, input_channels=1): + super(MNISTDiscriminator, self).__init__() + + self.model = nn.Sequential( + nn.Conv2D(input_channels, 64, 4, 2, 1), + nn.LeakyReLU(0.2), + nn.Conv2D(64, 128, 4, 2, 1), + nn.LeakyReLU(0.2), + nn.Flatten(), + nn.Linear(128 * 7 * 7, 1), + ) + + def forward(self, x): + return self.model(x) + + +def main(): + """ + Main function to train WGAN-GP on MNIST dataset. + """ + output_dir = "output/mnist" + os.makedirs(output_dir, exist_ok=True) + + transform = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + train_dataset = paddle.vision.datasets.MNIST( + mode="train", + transform=transform, + download=True, + ) + + generator = MNISTGenerator(noise_dim=100, output_channels=1) + discriminator = MNISTDiscriminator(input_channels=1) + + wgan_gp = WGAN_GP( + generator=generator, + discriminator=discriminator, + lambda_gp=10.0, + critic_iters=5, + ) + + history = wgan_gp.train( + train_dataset, + batch_size=64, + iterations=20000, + g_learning_rate=1e-4, + d_learning_rate=1e-4, + save_interval=1000, + save_path=output_dir, + ) + + plt.figure(figsize=(10, 5)) + plt.plot(history["g_loss"], label="Generator Loss") + plt.plot(history["d_loss"], label="Discriminator Loss") + plt.xlabel("Iterations") + plt.ylabel("Loss") + plt.legend() + plt.grid(True) + plt.savefig(f"{output_dir}/loss_curves.png") + plt.close() + + with paddle.no_grad(): + samples = wgan_gp.generate(16) + + from utils.visualization import save_image_grid + + save_image_grid(samples, f"{output_dir}/final_samples.png") + + +if __name__ == "__main__": + main() diff --git a/examples/wgan_gp/cases/wgan_gp_toy.py b/examples/wgan_gp/cases/wgan_gp_toy.py new file mode 100644 index 000000000..efdbd0df8 --- /dev/null +++ b/examples/wgan_gp/cases/wgan_gp_toy.py @@ -0,0 +1,215 @@ +import os + +import matplotlib.pyplot as plt +import numpy as np +import paddle +import paddle.nn as nn + +from ..models.wgan_gp import WGAN_GP + + +class ToyGenerator(nn.Layer): + """ + Generator network for toy datasets. + """ + + def __init__(self, noise_dim=2, output_dim=2, hidden_dim=128): + super(ToyGenerator, self).__init__() + + self.model = nn.Sequential( + nn.Linear(noise_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, output_dim), + ) + + def forward(self, x): + return self.model(x) + + +class ToyDiscriminator(nn.Layer): + """ + Discriminator network for toy datasets. 
+ """ + + def __init__(self, input_dim=2, hidden_dim=128): + super(ToyDiscriminator, self).__init__() + + self.model = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, 1), + ) + + def forward(self, x): + return self.model(x) + + +class GaussianMixture(paddle.io.Dataset): + """ + Gaussian mixture dataset for toy experiments. + """ + + def __init__(self, n_samples=10000, n_components=8, scale=2.0, std=0.2): + super(GaussianMixture, self).__init__() + + angles = np.linspace(0, 2 * np.pi, n_components, endpoint=False) + centers = scale * np.column_stack((np.cos(angles), np.sin(angles))) + + samples_per_component = n_samples // n_components + self.data = [] + + for center in centers: + samples = np.random.normal( + loc=center, scale=std, size=(samples_per_component, 2) + ) + self.data.append(samples) + + self.data = np.vstack(self.data).astype(np.float32) + np.random.shuffle(self.data) + + self.data = paddle.to_tensor(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx] + + +def visualize_samples(real_samples, fake_samples, save_path=None): + """ + Visualize real and generated samples. + + Args: + real_samples: Real data samples + fake_samples: Generated data samples + save_path: Path to save the visualization (default: None) + """ + plt.figure(figsize=(12, 6)) + + plt.subplot(1, 2, 1) + plt.scatter(real_samples[:, 0], real_samples[:, 1], alpha=0.5) + plt.title("Real Samples") + plt.xlim(-3, 3) + plt.ylim(-3, 3) + + plt.subplot(1, 2, 2) + plt.scatter(fake_samples[:, 0], fake_samples[:, 1], alpha=0.5) + plt.title("Generated Samples") + plt.xlim(-3, 3) + plt.ylim(-3, 3) + + plt.tight_layout() + + if save_path: + plt.savefig(save_path) + else: + plt.show() + + plt.close() + + +def main(): + """ + Main function to train WGAN-GP on toy dataset. 
+ """ + output_dir = "output/toy" + os.makedirs(output_dir, exist_ok=True) + + dataset = GaussianMixture(n_samples=10000, n_components=8) + + generator = ToyGenerator(noise_dim=2, output_dim=2, hidden_dim=128) + discriminator = ToyDiscriminator(input_dim=2, hidden_dim=128) + + wgan_gp = WGAN_GP( + generator=generator, + discriminator=discriminator, + lambda_gp=10.0, + critic_iters=5, + ) + + data_loader = paddle.io.DataLoader( + dataset, + batch_size=64, + shuffle=True, + ) + + g_optimizer = paddle.optimizer.Adam( + parameters=generator.parameters(), + learning_rate=1e-4, + beta1=0.5, + beta2=0.9, + ) + + d_optimizer = paddle.optimizer.Adam( + parameters=discriminator.parameters(), + learning_rate=1e-4, + beta1=0.5, + beta2=0.9, + ) + + history = { + "g_loss": [], + "d_loss": [], + } + + iterations = 10000 + save_interval = 1000 + data_loader_iter = iter(data_loader) + + for iteration in range(iterations): + try: + real_data = next(data_loader_iter) + except StopIteration: + data_loader_iter = iter(data_loader) + real_data = next(data_loader_iter) + + step_results = wgan_gp.train_step(real_data, g_optimizer, d_optimizer) + + history["g_loss"].append(step_results["g_loss"]) + history["d_loss"].append(step_results["d_loss"]) + + if iteration % 100 == 0: + print( + f"Iteration {iteration}: g_loss = {step_results['g_loss']:.4f}, d_loss = {step_results['d_loss']:.4f}" + ) + + if iteration % save_interval == 0 or iteration == iterations - 1: + with paddle.no_grad(): + fake_samples = wgan_gp.generate(1000, noise_dim=2) + + visualize_samples( + dataset.data.numpy(), + fake_samples.numpy(), + save_path=f"{output_dir}/samples_{iteration}.png", + ) + + paddle.save( + generator.state_dict(), f"{output_dir}/generator_{iteration}.pdparams" + ) + paddle.save( + discriminator.state_dict(), + f"{output_dir}/discriminator_{iteration}.pdparams", + ) + + plt.figure(figsize=(10, 5)) + plt.plot(history["g_loss"], label="Generator Loss") + plt.plot(history["d_loss"], label="Discriminator Loss") + plt.xlabel("Iterations") + plt.ylabel("Loss") + plt.legend() + plt.grid(True) + plt.savefig(f"{output_dir}/loss_curves.png") + plt.close() + + +if __name__ == "__main__": + main() diff --git a/examples/wgan_gp/models/__init__.py b/examples/wgan_gp/models/__init__.py new file mode 100644 index 000000000..a470ff5e3 --- /dev/null +++ b/examples/wgan_gp/models/__init__.py @@ -0,0 +1,5 @@ +""" +WGAN-GP models module. + +This module contains the implementation of GAN models including WGAN-GP. +""" diff --git a/examples/wgan_gp/models/base_gan.py b/examples/wgan_gp/models/base_gan.py new file mode 100644 index 000000000..d795febf6 --- /dev/null +++ b/examples/wgan_gp/models/base_gan.py @@ -0,0 +1,160 @@ +import abc + +import paddle + + +class BaseGAN(abc.ABC): + """ + Base class for GAN implementations. + + This abstract class defines the common interface for all GAN variants. + """ + + def __init__(self, generator, discriminator): + """ + Initialize the GAN with generator and discriminator networks. + + Args: + generator: Generator network + discriminator: Discriminator network + """ + self.generator = generator + self.discriminator = discriminator + + @abc.abstractmethod + def generator_loss(self, fake_output): + """ + Calculate the generator loss. + + Args: + fake_output: Discriminator output for fake samples + + Returns: + Generator loss value + """ + pass + + @abc.abstractmethod + def discriminator_loss(self, real_output, fake_output): + """ + Calculate the discriminator loss. 
+ + Args: + real_output: Discriminator output for real samples + fake_output: Discriminator output for fake samples + + Returns: + Discriminator loss value + """ + pass + + @abc.abstractmethod + def train_step(self, real_data, g_optimizer, d_optimizer): + """ + Perform a single training step. + + Args: + real_data: Batch of real data + g_optimizer: Generator optimizer + d_optimizer: Discriminator optimizer + + Returns: + Dictionary of loss values and metrics + """ + pass + + def generate(self, num_samples, noise_dim=100): + """ + Generate samples using the generator. + + Args: + num_samples: Number of samples to generate + noise_dim: Dimension of the noise vector (default: 100) + + Returns: + Generated samples + """ + noise = paddle.randn([num_samples, noise_dim]) + return self.generator(noise) + + def train( + self, + train_data, + batch_size=64, + iterations=10000, + g_learning_rate=1e-4, + d_learning_rate=1e-4, + save_interval=1000, + save_path=None, + ): + """ + Train the GAN model. + + Args: + train_data: Training dataset + batch_size: Batch size for training (default: 64) + iterations: Number of training iterations (default: 10000) + g_learning_rate: Generator learning rate (default: 1e-4) + d_learning_rate: Discriminator learning rate (default: 1e-4) + save_interval: Interval for saving samples and model (default: 1000) + save_path: Path to save samples and model (default: None) + + Returns: + Dictionary of training history + """ + data_loader = paddle.io.DataLoader( + train_data, + batch_size=batch_size, + shuffle=True, + ) + + g_optimizer = paddle.optimizer.Adam( + parameters=self.generator.parameters(), + learning_rate=g_learning_rate, + beta1=0.5, + beta2=0.9, + ) + + d_optimizer = paddle.optimizer.Adam( + parameters=self.discriminator.parameters(), + learning_rate=d_learning_rate, + beta1=0.5, + beta2=0.9, + ) + + history = { + "g_loss": [], + "d_loss": [], + } + + data_loader_iter = iter(data_loader) + + for iteration in range(iterations): + try: + real_data = next(data_loader_iter) + except StopIteration: + data_loader_iter = iter(data_loader) + real_data = next(data_loader_iter) + + step_results = self.train_step(real_data, g_optimizer, d_optimizer) + + history["g_loss"].append(step_results["g_loss"]) + history["d_loss"].append(step_results["d_loss"]) + + if save_path is not None and iteration % save_interval == 0: + samples = self.generate(16) + + from utils.visualization import save_image_grid + + save_image_grid(samples, f"{save_path}/samples_{iteration}.png") + + paddle.save( + self.generator.state_dict(), + f"{save_path}/generator_{iteration}.pdparams", + ) + paddle.save( + self.discriminator.state_dict(), + f"{save_path}/discriminator_{iteration}.pdparams", + ) + + return history diff --git a/examples/wgan_gp/models/wgan.py b/examples/wgan_gp/models/wgan.py new file mode 100644 index 000000000..a27fc52c8 --- /dev/null +++ b/examples/wgan_gp/models/wgan.py @@ -0,0 +1,106 @@ +import paddle + +from .base_gan import BaseGAN + + +class WGAN(BaseGAN): + """ + Wasserstein GAN implementation. + + This class implements the Wasserstein GAN as described in the paper + "Wasserstein GAN" by Arjovsky et al. + """ + + def __init__(self, generator, discriminator, clip_value=0.01): + """ + Initialize the WGAN with generator and discriminator networks. 
+ + Args: + generator: Generator network + discriminator: Discriminator network + clip_value: Value for weight clipping (default: 0.01) + """ + super(WGAN, self).__init__(generator, discriminator) + self.clip_value = clip_value + + def generator_loss(self, fake_output): + """ + Calculate the generator loss. + + Args: + fake_output: Discriminator output for fake samples + + Returns: + Generator loss value + """ + return -paddle.mean(fake_output) + + def discriminator_loss(self, real_output, fake_output): + """ + Calculate the discriminator loss. + + Args: + real_output: Discriminator output for real samples + fake_output: Discriminator output for fake samples + + Returns: + Discriminator loss value + """ + return paddle.mean(fake_output) - paddle.mean(real_output) + + def _clip_weights(self): + """ + Clip discriminator weights to enforce Lipschitz constraint. + """ + for param in self.discriminator.parameters(): + param.set_value(paddle.clip(param, -self.clip_value, self.clip_value)) + + def train_step(self, real_data, g_optimizer, d_optimizer, critic_iters=5): + """ + Perform a single training step. + + Args: + real_data: Batch of real data + g_optimizer: Generator optimizer + d_optimizer: Discriminator optimizer + critic_iters: Number of discriminator updates per generator update (default: 5) + + Returns: + Dictionary of loss values and metrics + """ + batch_size = real_data.shape[0] + noise_dim = 100 # Default noise dimension + + d_loss_sum = 0 + for _ in range(critic_iters): + noise = paddle.randn([batch_size, noise_dim]) + fake_data = self.generator(noise) + + real_output = self.discriminator(real_data) + fake_output = self.discriminator(fake_data) + + d_loss = self.discriminator_loss(real_output, fake_output) + d_loss_sum += d_loss.item() + + d_optimizer.clear_grad() + d_loss.backward() + d_optimizer.step() + + self._clip_weights() + + d_loss_avg = d_loss_sum / critic_iters + + noise = paddle.randn([batch_size, noise_dim]) + fake_data = self.generator(noise) + fake_output = self.discriminator(fake_data) + + g_loss = self.generator_loss(fake_output) + + g_optimizer.clear_grad() + g_loss.backward() + g_optimizer.step() + + return { + "g_loss": g_loss.item(), + "d_loss": d_loss_avg, + } diff --git a/examples/wgan_gp/models/wgan_gp.py b/examples/wgan_gp/models/wgan_gp.py new file mode 100644 index 000000000..18a585f0b --- /dev/null +++ b/examples/wgan_gp/models/wgan_gp.py @@ -0,0 +1,141 @@ +import paddle + +from .base_gan import BaseGAN + + +class WGAN_GP(BaseGAN): + """ + Wasserstein GAN with Gradient Penalty implementation. + + This class implements the Wasserstein GAN with Gradient Penalty as described + in the paper "Improved Training of Wasserstein GANs" by Gulrajani et al. + """ + + def __init__(self, generator, discriminator, lambda_gp=10.0, critic_iters=5): + """ + Initialize the WGAN-GP with generator and discriminator networks. + + Args: + generator: Generator network + discriminator: Discriminator network + lambda_gp: Gradient penalty coefficient (default: 10.0) + critic_iters: Number of discriminator updates per generator update (default: 5) + """ + super(WGAN_GP, self).__init__(generator, discriminator) + self.lambda_gp = lambda_gp + self.critic_iters = critic_iters + + def generator_loss(self, fake_output): + """ + Calculate the generator loss. 
+ + Args: + fake_output: Discriminator output for fake samples + + Returns: + Generator loss value + """ + return -paddle.mean(fake_output) + + def discriminator_loss(self, real_output, fake_output, gradient_penalty): + """ + Calculate the discriminator loss with gradient penalty. + + Args: + real_output: Discriminator output for real samples + fake_output: Discriminator output for fake samples + gradient_penalty: Gradient penalty value + + Returns: + Discriminator loss value + """ + return ( + paddle.mean(fake_output) + - paddle.mean(real_output) + + self.lambda_gp * gradient_penalty + ) + + def gradient_penalty(self, real_samples, fake_samples): + """ + Calculate the gradient penalty. + + Args: + real_samples: Real data samples + fake_samples: Generated data samples + + Returns: + Gradient penalty value + """ + batch_size = real_samples.shape[0] + + alpha = paddle.rand(shape=[batch_size, 1, 1, 1]) + + interpolates = real_samples + alpha * (fake_samples - real_samples) + interpolates.stop_gradient = False + + disc_interpolates = self.discriminator(interpolates) + + gradients = paddle.grad( + outputs=disc_interpolates, + inputs=interpolates, + grad_outputs=paddle.ones_like(disc_interpolates), + create_graph=False, + retain_graph=True, + )[0] + + gradients_norm = paddle.sqrt( + paddle.sum(paddle.square(gradients), axis=[1, 2, 3]) + ) + + gradient_penalty = paddle.mean(paddle.square(gradients_norm - 1.0)) + + return gradient_penalty + + def train_step(self, real_data, g_optimizer, d_optimizer): + """ + Perform a single training step. + + Args: + real_data: Batch of real data + g_optimizer: Generator optimizer + d_optimizer: Discriminator optimizer + + Returns: + Dictionary of loss values and metrics + """ + batch_size = real_data.shape[0] + noise_dim = 100 # Default noise dimension + + d_loss_sum = 0 + for _ in range(self.critic_iters): + noise = paddle.randn([batch_size, noise_dim]) + fake_data = self.generator(noise) + + real_output = self.discriminator(real_data) + fake_output = self.discriminator(fake_data) + + gp = self.gradient_penalty(real_data, fake_data) + + d_loss = self.discriminator_loss(real_output, fake_output, gp) + d_loss_sum += d_loss.item() + + d_optimizer.clear_grad() + d_loss.backward() + d_optimizer.step() + + d_loss_avg = d_loss_sum / self.critic_iters + + noise = paddle.randn([batch_size, noise_dim]) + fake_data = self.generator(noise) + fake_output = self.discriminator(fake_data) + + g_loss = self.generator_loss(fake_output) + + g_optimizer.clear_grad() + g_loss.backward() + g_optimizer.step() + + return { + "g_loss": g_loss.item(), + "d_loss": d_loss_avg, + } diff --git a/examples/wgan_gp/utils/__init__.py b/examples/wgan_gp/utils/__init__.py new file mode 100644 index 000000000..065f9866c --- /dev/null +++ b/examples/wgan_gp/utils/__init__.py @@ -0,0 +1,5 @@ +""" +PaddleScience utilities module. + +This module provides utility functions for PaddleScience models. +""" diff --git a/examples/wgan_gp/utils/losses.py b/examples/wgan_gp/utils/losses.py new file mode 100644 index 000000000..37baae5ed --- /dev/null +++ b/examples/wgan_gp/utils/losses.py @@ -0,0 +1,34 @@ +import paddle + + +def generator_loss(fake_output): + """ + WGAN-GP generator loss function. + + Args: + fake_output: Discriminator output for fake samples + + Returns: + Generator loss value + """ + return -paddle.mean(fake_output) + + +def discriminator_loss(real_output, fake_output, gradient_penalty, lambda_gp=10.0): + """ + WGAN-GP discriminator loss function with gradient penalty. 
+ + Args: + real_output: Discriminator output for real samples + fake_output: Discriminator output for fake samples + gradient_penalty: Gradient penalty value + lambda_gp: Gradient penalty coefficient (default: 10.0) + + Returns: + Discriminator loss value + """ + return ( + paddle.mean(fake_output) + - paddle.mean(real_output) + + lambda_gp * gradient_penalty + ) diff --git a/examples/wgan_gp/utils/metrics.py b/examples/wgan_gp/utils/metrics.py new file mode 100644 index 000000000..de71ee1d1 --- /dev/null +++ b/examples/wgan_gp/utils/metrics.py @@ -0,0 +1,70 @@ +import matplotlib.pyplot as plt +import numpy as np +import paddle + + +def save_image_grid(images, path, nrow=8, padding=2, normalize=True): + """ + Save a grid of images to a file. + + Args: + images: Tensor of images to display + path: Path to save the image grid + nrow: Number of images per row (default: 8) + padding: Padding between images (default: 2) + normalize: Whether to normalize images to [0, 1] (default: True) + """ + if isinstance(images, paddle.Tensor): + images = images.numpy() + + if normalize: + images = (images - images.min()) / (images.max() - images.min() + 1e-8) + + nmaps = images.shape[0] + xmaps = min(nrow, nmaps) + ymaps = int(np.ceil(float(nmaps) / xmaps)) + height, width = int(images.shape[1] + padding), int(images.shape[2] + padding) + + grid = np.zeros( + (height * ymaps + padding, width * xmaps + padding, 3), dtype=np.uint8 + ) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + image = images[k] + if image.shape[-1] == 1: + image = np.repeat(image, 3, axis=-1) + image = (image * 255).astype(np.uint8) + grid[ + y * height + padding : (y + 1) * height, + x * width + padding : (x + 1) * width, + ] = image + k += 1 + + plt.figure(figsize=(10, 10)) + plt.imshow(grid) + plt.axis("off") + plt.savefig(path, bbox_inches="tight") + plt.close() + + +def plot_loss_curves(g_losses, d_losses, path): + """ + Plot generator and discriminator loss curves. + + Args: + g_losses: List of generator losses + d_losses: List of discriminator losses + path: Path to save the plot + """ + plt.figure(figsize=(10, 5)) + plt.plot(g_losses, label="Generator Loss") + plt.plot(d_losses, label="Discriminator Loss") + plt.xlabel("Iterations") + plt.ylabel("Loss") + plt.legend() + plt.grid(True) + plt.savefig(path) + plt.close() diff --git a/examples/wgan_gp/utils/visualization.py b/examples/wgan_gp/utils/visualization.py new file mode 100644 index 000000000..de71ee1d1 --- /dev/null +++ b/examples/wgan_gp/utils/visualization.py @@ -0,0 +1,70 @@ +import matplotlib.pyplot as plt +import numpy as np +import paddle + + +def save_image_grid(images, path, nrow=8, padding=2, normalize=True): + """ + Save a grid of images to a file. 
+ + Args: + images: Tensor of images to display + path: Path to save the image grid + nrow: Number of images per row (default: 8) + padding: Padding between images (default: 2) + normalize: Whether to normalize images to [0, 1] (default: True) + """ + if isinstance(images, paddle.Tensor): + images = images.numpy() + + if normalize: + images = (images - images.min()) / (images.max() - images.min() + 1e-8) + + nmaps = images.shape[0] + xmaps = min(nrow, nmaps) + ymaps = int(np.ceil(float(nmaps) / xmaps)) + height, width = int(images.shape[1] + padding), int(images.shape[2] + padding) + + grid = np.zeros( + (height * ymaps + padding, width * xmaps + padding, 3), dtype=np.uint8 + ) + k = 0 + for y in range(ymaps): + for x in range(xmaps): + if k >= nmaps: + break + image = images[k] + if image.shape[-1] == 1: + image = np.repeat(image, 3, axis=-1) + image = (image * 255).astype(np.uint8) + grid[ + y * height + padding : (y + 1) * height, + x * width + padding : (x + 1) * width, + ] = image + k += 1 + + plt.figure(figsize=(10, 10)) + plt.imshow(grid) + plt.axis("off") + plt.savefig(path, bbox_inches="tight") + plt.close() + + +def plot_loss_curves(g_losses, d_losses, path): + """ + Plot generator and discriminator loss curves. + + Args: + g_losses: List of generator losses + d_losses: List of discriminator losses + path: Path to save the plot + """ + plt.figure(figsize=(10, 5)) + plt.plot(g_losses, label="Generator Loss") + plt.plot(d_losses, label="Discriminator Loss") + plt.xlabel("Iterations") + plt.ylabel("Loss") + plt.legend() + plt.grid(True) + plt.savefig(path) + plt.close()