diff --git a/example/BinaryDM/README.md b/example/BinaryDM/README.md new file mode 100644 index 00000000..a6ef1c52 --- /dev/null +++ b/example/BinaryDM/README.md @@ -0,0 +1,31 @@ +# BinaryDM in PaddlePaddle + +## 1. 简介 + +本示例介绍了一种权重二值化的扩散模型的训练方法。通过可学习多基二值化器和低秩表示模仿来增强二值扩散模型的表征能力并提高优化表现，能支持将扩散模型应用于极限资源任务场景中。 + +技术详情见论文 [BinaryDM: Accurate Weight Binarization for Efficient Diffusion Models](https://arxiv.org/pdf/2404.05662v4) + +![binarydm](./imgs/binarydm.png) + +## 2.训练 + +### 2.1 环境准备 + +- paddlepaddle>=2.0.1 (paddlepaddle-gpu>=2.0.1) + - visualdl + - lmdb + +### 2.2 启动训练 + +``` +python main_binarydm.py --config {DATASET}.yml --exp {PROJECT_PATH} --doc {MODEL_NAME} --ni +``` + +## 致谢 + +本实现源于下列开源仓库: + +- [https://github.com/Xingyu-Zheng/BinaryDM](https://github.com/Xingyu-Zheng/BinaryDM) (official implementation of BinaryDM). +- [https://openi.pcl.ac.cn/iMon/ddim-paddle](https://openi.pcl.ac.cn/iMon/ddim-paddle) (PaddlePaddle version for DDIM). +- [https://github.com/ermongroup/ddim](https://github.com/ermongroup/ddim) (code structure). 
diff --git a/example/BinaryDM/configs/bedroom.yml b/example/BinaryDM/configs/bedroom.yml new file mode 100644 index 00000000..7130768a --- /dev/null +++ b/example/BinaryDM/configs/bedroom.yml @@ -0,0 +1,50 @@ +data: + dataset: "LSUN" + category: "bedroom" + image_size: 256 + channels: 3 + logit_transform: false + uniform_dequantization: false + gaussian_dequantization: false + random_flip: true + rescaled: true + num_workers: 32 + +model: + type: "simple" + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 1, 2, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [16, ] + dropout: 0.0 + var_type: fixedsmall + ema_rate: 0.999 + ema: True + resamp_with_conv: True + +diffusion: + beta_schedule: linear + beta_start: 0.0001 + beta_end: 0.02 + num_diffusion_timesteps: 1000 + +training: + batch_size: 64 + n_epochs: 10000 + n_iters: 5000000 + snapshot_freq: 5000 + validation_freq: 2000 + +sampling: + batch_size: 32 + last_only: True + +optim: + weight_decay: 0.000 + optimizer: "Adam" + lr: 0.00002 + beta1: 0.9 + amsgrad: false + eps: 0.00000001 diff --git a/example/BinaryDM/configs/celeba.yml b/example/BinaryDM/configs/celeba.yml new file mode 100644 index 00000000..84be514e --- /dev/null +++ b/example/BinaryDM/configs/celeba.yml @@ -0,0 +1,50 @@ +data: + dataset: "CELEBA" + image_size: 64 + channels: 3 + logit_transform: false + uniform_dequantization: false + gaussian_dequantization: false + random_flip: true + rescaled: true + num_workers: 4 + +model: + type: "simple" + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 2, 4] + num_res_blocks: 2 + attn_resolutions: [16, ] + dropout: 0.1 + var_type: fixedlarge + ema_rate: 0.9999 + ema: True + resamp_with_conv: True + +diffusion: + beta_schedule: linear + beta_start: 0.0001 + beta_end: 0.02 + num_diffusion_timesteps: 1000 + +training: + batch_size: 128 + n_epochs: 10000 + n_iters: 5000000 + snapshot_freq: 5000 + validation_freq: 20000 + +sampling: + batch_size: 32 + last_only: True + +optim: + weight_decay: 0.000 + 
optimizer: "Adam" + lr: 0.0002 + beta1: 0.9 + amsgrad: false + eps: 0.00000001 + grad_clip: 1.0 diff --git a/example/BinaryDM/configs/church.yml b/example/BinaryDM/configs/church.yml new file mode 100644 index 00000000..d2f35277 --- /dev/null +++ b/example/BinaryDM/configs/church.yml @@ -0,0 +1,50 @@ +data: + dataset: "LSUN" + category: "church_outdoor" + image_size: 256 + channels: 3 + logit_transform: false + uniform_dequantization: false + gaussian_dequantization: false + random_flip: true + rescaled: true + num_workers: 32 + +model: + type: "simple" + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 1, 2, 2, 4, 4] + num_res_blocks: 2 + attn_resolutions: [16, ] + dropout: 0.0 + var_type: fixedsmall + ema_rate: 0.999 + ema: True + resamp_with_conv: True + +diffusion: + beta_schedule: linear + beta_start: 0.0001 + beta_end: 0.02 + num_diffusion_timesteps: 1000 + +training: + batch_size: 64 + n_epochs: 10000 + n_iters: 5000000 + snapshot_freq: 5000 + validation_freq: 2000 + +sampling: + batch_size: 32 + last_only: True + +optim: + weight_decay: 0.000 + optimizer: "Adam" + lr: 0.00002 + beta1: 0.9 + amsgrad: false + eps: 0.00000001 diff --git a/example/BinaryDM/configs/cifar10.yml b/example/BinaryDM/configs/cifar10.yml new file mode 100644 index 00000000..0e48f55f --- /dev/null +++ b/example/BinaryDM/configs/cifar10.yml @@ -0,0 +1,50 @@ +data: + dataset: "CIFAR10" + image_size: 32 + channels: 3 + logit_transform: false + uniform_dequantization: false + gaussian_dequantization: false + random_flip: true + rescaled: true + num_workers: 4 + +model: + type: "simple" + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 2] + num_res_blocks: 2 + attn_resolutions: [16, ] + dropout: 0.1 + var_type: fixedlarge + ema_rate: 0.9999 + ema: True + resamp_with_conv: True + +diffusion: + beta_schedule: linear + beta_start: 0.0001 + beta_end: 0.02 + num_diffusion_timesteps: 1000 + +training: + batch_size: 128 + n_epochs: 10000 + n_iters: 5000000 + snapshot_freq: 5000 + 
validation_freq: 2000 + +sampling: + batch_size: 64 + last_only: True + +optim: + weight_decay: 0.000 + optimizer: "Adam" + lr: 0.0002 + beta1: 0.9 + amsgrad: false + eps: 0.00000001 + grad_clip: 1.0 diff --git a/example/BinaryDM/configs/cifar10_improved.yml b/example/BinaryDM/configs/cifar10_improved.yml new file mode 100644 index 00000000..e06191e6 --- /dev/null +++ b/example/BinaryDM/configs/cifar10_improved.yml @@ -0,0 +1,51 @@ +data: + dataset: "CIFAR10" + image_size: 32 + channels: 3 + logit_transform: false + uniform_dequantization: false + gaussian_dequantization: false + random_flip: true + rescaled: true + num_workers: 4 + +model: + type: "simple" + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [1, 2, 2, 2] + num_res_blocks: 2 + attn_resolutions: [16, ] + dropout: 0.1 + var_type: fixedlarge + ema_rate: 0.9999 + ema: True + resamp_with_conv: True + use_scale_shift_norm: True + +diffusion: + beta_schedule: cosine + beta_start: null + beta_end: null + num_diffusion_timesteps: 1000 + +training: + batch_size: 128 + n_epochs: 10000 + n_iters: 5000000 + snapshot_freq: 5000 + validation_freq: 2000 + +sampling: + batch_size: 64 + last_only: True + +optim: + weight_decay: 0.000 + optimizer: "Adam" + lr: 0.0002 + beta1: 0.9 + amsgrad: false + eps: 0.00000001 + grad_clip: 1.0 diff --git a/example/BinaryDM/datasets/__init__.py b/example/BinaryDM/datasets/__init__.py new file mode 100644 index 00000000..ba5b6d17 --- /dev/null +++ b/example/BinaryDM/datasets/__init__.py @@ -0,0 +1,219 @@ +import os +import paddle +import numbers +import paddle.vision.transforms as transforms +import paddle.vision.transforms.functional as F +from paddle.vision.datasets import Cifar10 +from datasets.celeba import CelebA +from datasets.ffhq import FFHQ +from datasets.lsun import LSUN +from paddle.io import Subset +import numpy as np + + +class Crop(object): + def __init__(self, x1, x2, y1, y2): + self.x1 = x1 + self.x2 = x2 + self.y1 = y1 + self.y2 = y2 + + def __call__(self, img): + 
return F.crop(img, self.x1, self.y1, self.x2 - self.x1, self.y2 - self.y1) + + def __repr__(self): + return self.__class__.__name__ + "(x1={}, x2={}, y1={}, y2={})".format( + self.x1, self.x2, self.y1, self.y2 + ) + + +def get_dataset(args, config): + if config.data.random_flip is False: + tran_transform = test_transform = transforms.Compose( + [transforms.Resize([config.data.image_size]*2), transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0] + ) + else: + tran_transform = transforms.Compose( + [ + transforms.Resize([config.data.image_size]*2), + transforms.RandomHorizontalFlip(prob=0.5), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ) + test_transform = transforms.Compose( + [transforms.Resize([config.data.image_size]*2), transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0] + ) + + # args.data_path = '/home/xingyu-zheng/laboratory/data/cifar10/cifar-10-python.tar.gz' + args.data_path = "D:/Laboratory/data/cifar10/cifar-10-python.tar.gz" + if config.data.dataset == "CIFAR10": + dataset = Cifar10( + # os.path.join(args.exp, "datasets", "cifar10"), + data_file=args.data_path, + mode="train", + download=True, + transform=tran_transform, + ) + test_dataset = Cifar10( + # os.path.join(args.exp, "datasets", "cifar10_test"), + data_file=args.data_path, + mode="test", + download=True, + transform=test_transform, + ) + + elif config.data.dataset == "CELEBA": + cx = 89 + cy = 121 + x1 = cy - 64 + x2 = cy + 64 + y1 = cx - 64 + y2 = cx + 64 + if config.data.random_flip: + dataset = CelebA( + root=os.path.join(args.exp, "datasets", "celeba"), + split="train", + transform=transforms.Compose( + [ + Crop(x1, x2, y1, y2), + transforms.Resize([config.data.image_size]*2), + transforms.RandomHorizontalFlip(), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ), + download=True, + ) + else: + dataset = CelebA( 
+ root=os.path.join(args.exp, "datasets", "celeba"), + split="train", + transform=transforms.Compose( + [ + Crop(x1, x2, y1, y2), + transforms.Resize([config.data.image_size]*2), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ), + download=True, + ) + + test_dataset = CelebA( + root=os.path.join(args.exp, "datasets", "celeba"), + split="test", + transform=transforms.Compose( + [ + Crop(x1, x2, y1, y2), + transforms.Resize([config.data.image_size]*2), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ), + download=True, + ) + + elif config.data.dataset == "LSUN": + train_folder = "{}_train".format(config.data.category) + val_folder = "{}_val".format(config.data.category) + if config.data.random_flip: + dataset = LSUN( + root=os.path.join(args.exp, "datasets", "lsun"), + classes=[train_folder], + transform=transforms.Compose( + [ + transforms.Resize([config.data.image_size]*2), + transforms.CenterCrop((config.data.image_size,)*2), + transforms.RandomHorizontalFlip(prob=0.5), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ), + ) + else: + dataset = LSUN( + root=os.path.join(args.exp, "datasets", "lsun"), + classes=[train_folder], + transform=transforms.Compose( + [ + transforms.Resize([config.data.image_size]*2), + transforms.CenterCrop((config.data.image_size,)*2), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ), + ) + + test_dataset = LSUN( + root=os.path.join(args.exp, "datasets", "lsun"), + classes=[val_folder], + transform=transforms.Compose( + [ + transforms.Resize([config.data.image_size]*2), + transforms.CenterCrop((config.data.image_size,)*2), + transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0, + ] + ), + ) + + elif config.data.dataset == "FFHQ": + if config.data.random_flip: + dataset = FFHQ( + 
path=os.path.join(args.exp, "datasets", "FFHQ"), + transform=transforms.Compose( + [transforms.RandomHorizontalFlip(prob=0.5), transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0] + ), + resolution=config.data.image_size, + ) + else: + dataset = FFHQ( + path=os.path.join(args.exp, "datasets", "FFHQ"), + transform=transforms.Compose(transforms.Transpose(), lambda x: x if x.dtype != np.uint8 else x.astype('float32')/255.0), + resolution=config.data.image_size, + ) + + num_items = len(dataset) + indices = list(range(num_items)) + random_state = np.random.get_state() + np.random.seed(2019) + np.random.shuffle(indices) + np.random.set_state(random_state) + train_indices, test_indices = ( + indices[: int(num_items * 0.9)], + indices[int(num_items * 0.9) :], + ) + test_dataset = Subset(dataset, test_indices) + dataset = Subset(dataset, train_indices) + else: + dataset, test_dataset = None, None + + return dataset, test_dataset + + +def logit_transform(image, lam=1e-6): + image = lam + (1 - 2 * lam) * image + return paddle.log(image) - paddle.log1p(-image) + + +def data_transform(config, X): + if config.data.uniform_dequantization: + X = X / 256.0 * 255.0 + paddle.rand(X.shape) / 256.0 + if config.data.gaussian_dequantization: + X = X + paddle.randn(X.shape) * 0.01 + + if config.data.rescaled: + X = 2 * X - 1.0 + elif config.data.logit_transform: + X = logit_transform(X) + + if hasattr(config, "image_mean"): + return X - config.image_mean.unsqueeze(0) + + return X + + +def inverse_data_transform(config, X): + if hasattr(config, "image_mean"): + X = X + config.image_mean.unsqueeze(0) + + if config.data.logit_transform: + X = paddle.nn.functional.sigmoid(X) + elif config.data.rescaled: + X = (X + 1.0) / 2.0 + + return paddle.clip(X, 0.0, 1.0) diff --git a/example/BinaryDM/datasets/celeba.py b/example/BinaryDM/datasets/celeba.py new file mode 100644 index 00000000..f80afb2a --- /dev/null +++ b/example/BinaryDM/datasets/celeba.py @@ -0,0 
+1,163 @@ +import paddle +import os +import PIL +from .vision import VisionDataset +from .utils import download_file_from_google_drive, check_integrity + + +class CelebA(VisionDataset): + """`Large-scale CelebFaces Attributes (CelebA) Dataset `_ Dataset. + + Args: + root (string): Root directory where images are downloaded to. + split (string): One of {'train', 'valid', 'test'}. + Accordingly dataset is selected. + target_type (string or list, optional): Type of target to use, ``attr``, ``identity``, ``bbox``, + or ``landmarks``. Can also be a list to output a tuple with all specified target types. + The targets represent: + ``attr`` (np.array shape=(40,) dtype=int): binary (0, 1) labels for attributes + ``identity`` (int): label for each person (data points with the same identity are the same person) + ``bbox`` (np.array shape=(4,) dtype=int): bounding box (x, y, width, height) + ``landmarks`` (np.array shape=(10,) dtype=int): landmark points (lefteye_x, lefteye_y, righteye_x, + righteye_y, nose_x, nose_y, leftmouth_x, leftmouth_y, rightmouth_x, rightmouth_y) + Defaults to ``attr``. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.ToTensor`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + """ + + base_folder = "celeba" + # There currently does not appear to be a easy way to extract 7z in python (without introducing additional + # dependencies). The "in-the-wild" (not aligned+cropped) images are only in 7z, so they are not available + # right now. 
+ file_list = [ + # File ID MD5 Hash Filename + ("0B7EVK8r0v71pZjFTYXZWM3FlRnM", "00d2c5bc6d35e252742224ab0c1e8fcb", "img_align_celeba.zip"), + # ("0B7EVK8r0v71pbWNEUjJKdDQ3dGc", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_align_celeba_png.7z"), + # ("0B7EVK8r0v71peklHb0pGdDl6R28", "b6cd7e93bc7a96c2dc33f819aa3ac651", "img_celeba.7z"), + ("0B7EVK8r0v71pblRyaVFSWGxPY0U", "75e246fa4810816ffd6ee81facbd244c", "list_attr_celeba.txt"), + ("1_ee_0u7vcNLOfNLegJRHmolfH5ICW-XS", "32bd1bd63d3c78cd57e08160ec5ed1e2", "identity_CelebA.txt"), + ("0B7EVK8r0v71pbThiMVRxWXZ4dU0", "00566efa6fedff7a56946cd1c10f1c16", "list_bbox_celeba.txt"), + ("0B7EVK8r0v71pd0FJY3Blby1HUTQ", "cc24ecafdb5b50baae59b03474781f8c", "list_landmarks_align_celeba.txt"), + # ("0B7EVK8r0v71pTzJIdlJWdHczRlU", "063ee6ddb681f96bc9ca28c6febb9d1a", "list_landmarks_celeba.txt"), + ("0B7EVK8r0v71pY0NSMzRuSXJEVkk", "d32c9cbf5e040fd4025c592c306e6668", "list_eval_partition.txt"), + ] + + def __init__(self, root, + split="train", + target_type="attr", + transform=None, target_transform=None, + download=False): + import pandas + super(CelebA, self).__init__(root) + self.split = split + if isinstance(target_type, list): + self.target_type = target_type + else: + self.target_type = [target_type] + self.transform = transform + self.target_transform = target_transform + + if download: + self.download() + + if not self._check_integrity(): + raise RuntimeError('Dataset not found or corrupted.' + + ' You can use download=True to download it') + + self.transform = transform + self.target_transform = target_transform + + if split.lower() == "train": + split = 0 + elif split.lower() == "valid": + split = 1 + elif split.lower() == "test": + split = 2 + else: + raise ValueError('Wrong split entered! 
Please use split="train" ' + 'or split="valid" or split="test"') + + with open(os.path.join(self.root, self.base_folder, "list_eval_partition.txt"), "r") as f: + splits = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) + + with open(os.path.join(self.root, self.base_folder, "identity_CelebA.txt"), "r") as f: + self.identity = pandas.read_csv(f, delim_whitespace=True, header=None, index_col=0) + + with open(os.path.join(self.root, self.base_folder, "list_bbox_celeba.txt"), "r") as f: + self.bbox = pandas.read_csv(f, delim_whitespace=True, header=1, index_col=0) + + with open(os.path.join(self.root, self.base_folder, "list_landmarks_align_celeba.txt"), "r") as f: + self.landmarks_align = pandas.read_csv(f, delim_whitespace=True, header=1) + + with open(os.path.join(self.root, self.base_folder, "list_attr_celeba.txt"), "r") as f: + self.attr = pandas.read_csv(f, delim_whitespace=True, header=1) + + mask = (splits[1] == split) + self.filename = splits[mask].index.values + self.identity = paddle.to_tensor(self.identity[mask].values) + self.bbox = paddle.to_tensor(self.bbox[mask].values) + self.landmarks_align = paddle.to_tensor(self.landmarks_align[mask].values) + self.attr = paddle.to_tensor(self.attr[mask].values) + self.attr = (self.attr + 1) // 2 # map from {-1, 1} to {0, 1} + + def _check_integrity(self): + for (_, md5, filename) in self.file_list: + fpath = os.path.join(self.root, self.base_folder, filename) + _, ext = os.path.splitext(filename) + # Allow original archive to be deleted (zip and 7z) + # Only need the extracted images + if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5): + return False + + # Should check a hash of the images + return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba")) + + def download(self): + import zipfile + + if self._check_integrity(): + print('Files already downloaded and verified') + return + + for (file_id, md5, filename) in self.file_list: + 
download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5) + + with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f: + f.extractall(os.path.join(self.root, self.base_folder)) + + def __getitem__(self, index): + X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index])) + + target = [] + for t in self.target_type: + if t == "attr": + target.append(self.attr[index, :]) + elif t == "identity": + target.append(self.identity[index, 0]) + elif t == "bbox": + target.append(self.bbox[index, :]) + elif t == "landmarks": + target.append(self.landmarks_align[index, :]) + else: + raise ValueError("Target type \"{}\" is not recognized.".format(t)) + target = tuple(target) if len(target) > 1 else target[0] + + if self.transform is not None: + X = self.transform(X) + + if self.target_transform is not None: + target = self.target_transform(target) + + return X, target + + def __len__(self): + return len(self.attr) + + def extra_repr(self): + lines = ["Target type: {target_type}", "Split: {split}"] + return '\n'.join(lines).format(**self.__dict__) diff --git a/example/BinaryDM/datasets/ffhq.py b/example/BinaryDM/datasets/ffhq.py new file mode 100644 index 00000000..29e7aff3 --- /dev/null +++ b/example/BinaryDM/datasets/ffhq.py @@ -0,0 +1,41 @@ +from io import BytesIO + +# import lmdb +from PIL import Image +from paddle.io import Dataset + + +class FFHQ(Dataset): + def __init__(self, path, transform, resolution=8): + self.env = lmdb.open( + path, + max_readers=32, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + + if not self.env: + raise IOError('Cannot open lmdb dataset', path) + + with self.env.begin(write=False) as txn: + self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8')) + + self.resolution = resolution + self.transform = transform + + def __len__(self): + return self.length + + def __getitem__(self, index): 
+ with self.env.begin(write=False) as txn: + key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8') + img_bytes = txn.get(key) + + buffer = BytesIO(img_bytes) + img = Image.open(buffer) + img = self.transform(img) + target = 0 + + return img, target \ No newline at end of file diff --git a/example/BinaryDM/datasets/lsun.py b/example/BinaryDM/datasets/lsun.py new file mode 100644 index 00000000..5fa453c1 --- /dev/null +++ b/example/BinaryDM/datasets/lsun.py @@ -0,0 +1,174 @@ +from .vision import VisionDataset +from PIL import Image +import os +import os.path +import io +from collections.abc import Iterable +import pickle + + +class LSUNClass(VisionDataset): + def __init__(self, root, transform=None, target_transform=None): + import lmdb + + super(LSUNClass, self).__init__( + root, transform=transform, target_transform=target_transform + ) + + self.env = lmdb.open( + root, + max_readers=1, + readonly=True, + lock=False, + readahead=False, + meminit=False, + ) + with self.env.begin(write=False) as txn: + self.length = txn.stat()["entries"] + root_split = root.split("/") + cache_file = os.path.join("/".join(root_split[:-1]), f"_cache_{root_split[-1]}") + if os.path.isfile(cache_file): + self.keys = pickle.load(open(cache_file, "rb")) + else: + with self.env.begin(write=False) as txn: + self.keys = [key for key, _ in txn.cursor()] + pickle.dump(self.keys, open(cache_file, "wb")) + + def __getitem__(self, index): + img, target = None, None + env = self.env + with env.begin(write=False) as txn: + imgbuf = txn.get(self.keys[index]) + + buf = io.BytesIO() + buf.write(imgbuf) + buf.seek(0) + img = Image.open(buf).convert("RGB") + + if self.transform is not None: + img = self.transform(img) + + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target + + def __len__(self): + return self.length + + +class LSUN(VisionDataset): + """ + `LSUN `_ dataset. + + Args: + root (string): Root directory for the database files. 
+ classes (string or list): One of {'train', 'val', 'test'} or a list of + categories to load. e,g. ['bedroom_train', 'church_outdoor_train']. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + def __init__(self, root, classes="train", transform=None, target_transform=None): + super(LSUN, self).__init__( + root, transform=transform, target_transform=target_transform + ) + self.classes = self._verify_classes(classes) + + # for each class, create an LSUNClassDataset + self.dbs = [] + for c in self.classes: + self.dbs.append( + LSUNClass(root=root + "/" + c + "_lmdb", transform=transform) + ) + + self.indices = [] + count = 0 + for db in self.dbs: + count += len(db) + self.indices.append(count) + + self.length = count + + def _verify_classes(self, classes): + categories = [ + "bedroom", + "bridge", + "church_outdoor", + "classroom", + "conference_room", + "dining_room", + "kitchen", + "living_room", + "restaurant", + "tower", + ] + dset_opts = ["train", "val", "test"] + + try: + # verify_str_arg(classes, "classes", dset_opts) + if classes == "test": + classes = [classes] + else: + classes = [c + "_" + classes for c in categories] + except ValueError: + if not isinstance(classes, Iterable): + msg = ( + "Expected type str or Iterable for argument classes, " + "but got type {}." + ) + raise ValueError(msg.format(type(classes))) + + classes = list(classes) + msg_fmtstr = ( + "Expected type str for elements in argument classes, " + "but got type {}." + ) + for c in classes: + # verify_str_arg(c, custom_msg=msg_fmtstr.format(type(c))) + c_short = c.split("_") + category, dset_opt = "_".join(c_short[:-1]), c_short[-1] + + msg_fmtstr = "Unknown value '{}' for {}. Valid values are {{{}}}." 
+ # msg = msg_fmtstr.format( + # category, "LSUN class", iterable_to_str(categories) + # ) + # verify_str_arg(category, valid_values=categories, custom_msg=msg) + + # msg = msg_fmtstr.format(dset_opt, "postfix", iterable_to_str(dset_opts)) + # verify_str_arg(dset_opt, valid_values=dset_opts, custom_msg=msg) + + return classes + + def __getitem__(self, index): + """ + Args: + index (int): Index + + Returns: + tuple: Tuple (image, target) where target is the index of the target category. + """ + target = 0 + sub = 0 + for ind in self.indices: + if index < ind: + break + target += 1 + sub = ind + + db = self.dbs[target] + index = index - sub + + if self.target_transform is not None: + target = self.target_transform(target) + + img, _ = db[index] + return img, target + + def __len__(self): + return self.length + + def extra_repr(self): + return "Classes: {classes}".format(**self.__dict__) diff --git a/example/BinaryDM/datasets/utils.py b/example/BinaryDM/datasets/utils.py new file mode 100644 index 00000000..e9c82aa6 --- /dev/null +++ b/example/BinaryDM/datasets/utils.py @@ -0,0 +1,186 @@ +import os +import os.path +import hashlib +import errno +from tqdm import tqdm + + +def gen_bar_updater(): + pbar = tqdm(total=None) + + def bar_update(count, block_size, total_size): + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) + + return bar_update + + +def check_integrity(fpath, md5=None): + if md5 is None: + return True + if not os.path.isfile(fpath): + return False + md5o = hashlib.md5() + with open(fpath, 'rb') as f: + # read in 1MB chunks + for chunk in iter(lambda: f.read(1024 * 1024), b''): + md5o.update(chunk) + md5c = md5o.hexdigest() + if md5c != md5: + return False + return True + + +def makedir_exist_ok(dirpath): + """ + Python2 support for os.makedirs(.., exist_ok=True) + """ + try: + os.makedirs(dirpath) + except OSError as e: + if e.errno == errno.EEXIST: + pass + else: 
+ raise + + +def download_url(url, root, filename=None, md5=None): + """Download a file from a url and place it in root. + + Args: + url (str): URL to download file from + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the basename of the URL + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + from six.moves import urllib + + root = os.path.expanduser(root) + if not filename: + filename = os.path.basename(url) + fpath = os.path.join(root, filename) + + makedir_exist_ok(root) + + # downloads file + if os.path.isfile(fpath) and check_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + else: + try: + print('Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater() + ) + except OSError: + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + ' Downloading ' + url + ' to ' + fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater() + ) + + +def list_dir(root, prefix=False): + """List all directories at a given root + + Args: + root (str): Path to directory whose folders need to be listed + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the directories found + """ + root = os.path.expanduser(root) + directories = list( + filter( + lambda p: os.path.isdir(os.path.join(root, p)), + os.listdir(root) + ) + ) + + if prefix is True: + directories = [os.path.join(root, d) for d in directories] + + return directories + + +def list_files(root, suffix, prefix=False): + """List all files ending with a suffix at a given root + + Args: + root (str): Path to directory whose folders need to be listed + suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png'). 
+ It uses the Python "str.endswith" method and is passed directly + prefix (bool, optional): If true, prepends the path to each result, otherwise + only returns the name of the files found + """ + root = os.path.expanduser(root) + files = list( + filter( + lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix), + os.listdir(root) + ) + ) + + if prefix is True: + files = [os.path.join(root, d) for d in files] + + return files + + +def download_file_from_google_drive(file_id, root, filename=None, md5=None): + """Download a Google Drive file from and place it in root. + + Args: + file_id (str): id of file to be downloaded + root (str): Directory to place downloaded file in + filename (str, optional): Name to save the file under. If None, use the id of the file. + md5 (str, optional): MD5 checksum of the download. If None, do not check + """ + # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url + import requests + url = "https://docs.google.com/uc?export=download" + + root = os.path.expanduser(root) + if not filename: + filename = file_id + fpath = os.path.join(root, filename) + + makedir_exist_ok(root) + + if os.path.isfile(fpath) and check_integrity(fpath, md5): + print('Using downloaded and verified file: ' + fpath) + else: + session = requests.Session() + + response = session.get(url, params={'id': file_id}, stream=True) + token = _get_confirm_token(response) + + if token: + params = {'id': file_id, 'confirm': token} + response = session.get(url, params=params, stream=True) + + _save_response_content(response, fpath) + + +def _get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + + return None + + +def _save_response_content(response, destination, chunk_size=32768): + with open(destination, "wb") as f: + pbar = tqdm(total=None) + progress = 0 + for chunk in response.iter_content(chunk_size): + if chunk: # filter out 
keep-alive new chunks + f.write(chunk) + progress += len(chunk) + pbar.update(progress - pbar.n) + pbar.close() diff --git a/example/BinaryDM/datasets/vision.py b/example/BinaryDM/datasets/vision.py new file mode 100644 index 00000000..9bcb78e0 --- /dev/null +++ b/example/BinaryDM/datasets/vision.py @@ -0,0 +1,84 @@ +import os +import paddle +import paddle.io as data + + +class VisionDataset(data.Dataset): + _repr_indent = 4 + + def __init__(self, root, transforms=None, transform=None, target_transform=None): + if isinstance(root, str): + root = os.path.expanduser(root) + self.root = root + + has_transforms = transforms is not None + has_separate_transform = transform is not None or target_transform is not None + if has_transforms and has_separate_transform: + raise ValueError("Only transforms or transform/target_transform can " + "be passed as argument") + + # for backwards-compatibility + self.transform = transform + self.target_transform = target_transform + + if has_separate_transform: + transforms = StandardTransform(transform, target_transform) + self.transforms = transforms + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def __repr__(self): + head = "Dataset " + self.__class__.__name__ + body = ["Number of datapoints: {}".format(self.__len__())] + if self.root is not None: + body.append("Root location: {}".format(self.root)) + body += self.extra_repr().splitlines() + if hasattr(self, 'transform') and self.transform is not None: + body += self._format_transform_repr(self.transform, + "Transforms: ") + if hasattr(self, 'target_transform') and self.target_transform is not None: + body += self._format_transform_repr(self.target_transform, + "Target transforms: ") + lines = [head] + [" " * self._repr_indent + line for line in body] + return '\n'.join(lines) + + def _format_transform_repr(self, transform, head): + lines = transform.__repr__().splitlines() + return (["{}{}".format(head, lines[0])] 
+ + ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + + def extra_repr(self): + return "" + + +class StandardTransform(object): + def __init__(self, transform=None, target_transform=None): + self.transform = transform + self.target_transform = target_transform + + def __call__(self, input, target): + if self.transform is not None: + input = self.transform(input) + if self.target_transform is not None: + target = self.target_transform(target) + return input, target + + def _format_transform_repr(self, transform, head): + lines = transform.__repr__().splitlines() + return (["{}{}".format(head, lines[0])] + + ["{}{}".format(" " * len(head), line) for line in lines[1:]]) + + def __repr__(self): + body = [self.__class__.__name__] + if self.transform is not None: + body += self._format_transform_repr(self.transform, + "Transform: ") + if self.target_transform is not None: + body += self._format_transform_repr(self.target_transform, + "Target transform: ") + + return '\n'.join(body) diff --git a/example/BinaryDM/functions/__init__.py b/example/BinaryDM/functions/__init__.py new file mode 100644 index 00000000..f15ec18f --- /dev/null +++ b/example/BinaryDM/functions/__init__.py @@ -0,0 +1,17 @@ +import paddle.nn as nn +import paddle.optimizer as optim + + +def get_optimizer(config, parameters): + clip = nn.ClipGradByNorm(clip_norm=config.optim.grad_clip) + if config.optim.optimizer == 'Adam': + return optim.Adam(parameters=parameters, learning_rate=config.optim.lr, weight_decay=config.optim.weight_decay, + beta1=config.optim.beta1, beta2=0.999, + epsilon=config.optim.eps, grad_clip=clip) + elif config.optim.optimizer == 'RMSProp': + return optim.RMSprop(parameters=parameters, learning_rate=config.optim.lr, weight_decay=config.optim.weight_decay, grad_clip=clip) + elif config.optim.optimizer == 'SGD': + return optim.SGD(parameters=parameters, learning_rate=config.optim.lr, momentum=0.9, grad_clip=clip) + else: + raise NotImplementedError( + 'Optimizer {} not 
understood.'.format(config.optim.optimizer)) diff --git a/example/BinaryDM/functions/ckpt_util.py b/example/BinaryDM/functions/ckpt_util.py new file mode 100644 index 00000000..87b405c3 --- /dev/null +++ b/example/BinaryDM/functions/ckpt_util.py @@ -0,0 +1,89 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "cifar10": "https://heibox.uni-heidelberg.de/f/869980b53bf5416c8a28/?dl=1", + "ema_cifar10": "https://heibox.uni-heidelberg.de/f/2e4f01e2d9ee49bab1d5/?dl=1", + "lsun_bedroom": "https://heibox.uni-heidelberg.de/f/f179d4f21ebc4d43bbfe/?dl=1", + "ema_lsun_bedroom": "https://heibox.uni-heidelberg.de/f/b95206528f384185889b/?dl=1", + "lsun_cat": "https://heibox.uni-heidelberg.de/f/fac870bd988348eab88e/?dl=1", + "ema_lsun_cat": "https://heibox.uni-heidelberg.de/f/0701aac3aa69457bbe34/?dl=1", + "lsun_church": "https://heibox.uni-heidelberg.de/f/2711a6f712e34b06b9d8/?dl=1", + "ema_lsun_church": "https://heibox.uni-heidelberg.de/f/44ccb50ef3c6436db52e/?dl=1", +} +CKPT_MAP = { + "cifar10": "diffusion_cifar10_model/model-790000.ckpt", + "ema_cifar10": "ema_diffusion_cifar10_model/model-790000.ckpt", + "lsun_bedroom": "diffusion_lsun_bedroom_model/model-2388000.ckpt", + "ema_lsun_bedroom": "ema_diffusion_lsun_bedroom_model/model-2388000.ckpt", + "lsun_cat": "diffusion_lsun_cat_model/model-1761000.ckpt", + "ema_lsun_cat": "ema_diffusion_lsun_cat_model/model-1761000.ckpt", + "lsun_church": "diffusion_lsun_church_model/model-4432000.ckpt", + "ema_lsun_church": "ema_diffusion_lsun_church_model/model-4432000.ckpt", +} +MD5_MAP = { + "cifar10": "82ed3067fd1002f5cf4c339fb80c4669", + "ema_cifar10": "1fa350b952534ae442b1d5235cce5cd3", + "lsun_bedroom": "f70280ac0e08b8e696f42cb8e948ff1c", + "ema_lsun_bedroom": "1921fa46b66a3665e450e42f36c2720f", + "lsun_cat": "bbee0e7c3d7abfb6e2539eaf2fb9987b", + "ema_lsun_cat": "646f23f4821f2459b8bafc57fd824558", + "lsun_church": "eb619b8a5ab95ef80f94ce8a5488dae3", + "ema_lsun_church": 
"fdc68a23938c2397caba4a260bc2445f", +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root=None, check=False): + if 'church_outdoor' in name: + name = name.replace('church_outdoor', 'church') + assert name in URL_MAP + # Modify the path when necessary + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + root = ( + root + if root is not None + else os.path.join(cachedir, "diffusion_models_converted") + ) + path = os.path.join(root, CKPT_MAP[name]) + + if os.path.exists(path + ".pdl"): + return path + ".pdl" + + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + + import torch + import paddle + state_dict = torch.load(path, map_location="cpu") + state_dict_pp = {} + for key, value in state_dict.items(): + value = paddle.to_tensor(value.numpy()) + if ('temb_proj' in key or 'dense' in key) and len(value.shape) == 2: + value = value.transpose([1,0]) + state_dict_pp[key] = value + path = path + ".pdl" + paddle.save(state_dict_pp, path) + + return path diff --git a/example/BinaryDM/functions/denoising.py b/example/BinaryDM/functions/denoising.py new file mode 100644 index 00000000..c493a47f --- /dev/null +++ b/example/BinaryDM/functions/denoising.py @@ -0,0 +1,67 @@ +import paddle + + +def compute_alpha(beta, t): + beta = 
paddle.concat([paddle.zeros([1]), beta], 0) + a = (1 - beta).cumprod(0).index_select(t + 1, 0).reshape([-1, 1, 1, 1]) + return a + + +def generalized_steps(x, seq, model, b, **kwargs): + with paddle.no_grad(): + n = x.shape[0] + seq_next = [-1] + list(seq[:-1]) + x0_preds = [] + xs = [x] + for i, j in zip(reversed(seq), reversed(seq_next)): + t = (paddle.ones([n]) * i) + next_t = (paddle.ones([n]) * j) + at = compute_alpha(b, t.astype('int64')) + at_next = compute_alpha(b, next_t.astype('int64')) + xt = xs[-1] + et = model(xt, t) + x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt() + x0_preds.append(x0_t) + c1 = ( + kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt() + ) + c2 = ((1 - at_next) - c1 ** 2).sqrt() + xt_next = at_next.sqrt() * x0_t + c1 * paddle.randn(x.shape) + c2 * et + xs.append(xt_next) + + return xs, x0_preds + + +def ddpm_steps(x, seq, model, b, **kwargs): + with paddle.no_grad(): + n = x.shape[0] + seq_next = [-1] + list(seq[:-1]) + xs = [x] + x0_preds = [] + betas = b + for i, j in zip(reversed(seq), reversed(seq_next)): + t = (paddle.ones([n]) * i) + next_t = (paddle.ones([n]) * j) + at = compute_alpha(betas, t.astype('int64')) + atm1 = compute_alpha(betas, next_t.astype('int64')) + beta_t = 1 - at / atm1 + x = xs[-1] + + output = model(x, t.astype('float32')) + e = output + + x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e + x0_from_e = paddle.clip(x0_from_e, -1, 1) + x0_preds.append(x0_from_e) + mean_eps = ( + (atm1.sqrt() * beta_t) * x0_from_e + ((1 - beta_t).sqrt() * (1 - atm1)) * x + ) / (1.0 - at) + + mean = mean_eps + noise = paddle.randn(x.shape) + mask = 1 - (t == 0).astype('float32') + mask = mask.reshape([-1, 1, 1, 1]) + logvar = beta_t.log() + sample = mean + mask * paddle.exp(0.5 * logvar) * noise + xs.append(sample) + return xs, x0_preds diff --git a/example/BinaryDM/functions/losses.py b/example/BinaryDM/functions/losses.py new file mode 100644 index 00000000..4ce40546 --- /dev/null +++ 
b/example/BinaryDM/functions/losses.py @@ -0,0 +1,20 @@ +import paddle + + +def noise_estimation_loss(model, + x0: paddle.Tensor, + t: paddle.Tensor, + e: paddle.Tensor, + b: paddle.Tensor, keepdim=False): + a = (1-b).cumprod(0).index_select(t, 0).reshape((-1, 1, 1, 1)) + x = x0 * a.sqrt() + e * (1.0 - a).sqrt() + output = model(x, t.astype('float32')) + if keepdim: + return (e - output).square().sum((1, 2, 3)) + else: + return (e - output).square().sum((1, 2, 3)).mean(0) + + +loss_registry = { + 'simple': noise_estimation_loss, +} diff --git a/example/BinaryDM/imgs/binarydm.png b/example/BinaryDM/imgs/binarydm.png new file mode 100644 index 00000000..30fb053b Binary files /dev/null and b/example/BinaryDM/imgs/binarydm.png differ diff --git a/example/BinaryDM/main_binarydm.py b/example/BinaryDM/main_binarydm.py new file mode 100644 index 00000000..be0c7cb2 --- /dev/null +++ b/example/BinaryDM/main_binarydm.py @@ -0,0 +1,228 @@ +import argparse +import traceback +import shutil +import logging +import yaml +import sys +import os +import paddle +import numpy as np +import visualdl as vdl + +from runners.diffusion_binarydm import Diffusion + +paddle.set_printoptions(sci_mode=False) + + +def parse_args_and_config(): + parser = argparse.ArgumentParser(description=globals()["__doc__"]) + + parser.add_argument( + "--config", type=str, required=True, help="Path to the config file" + ) + parser.add_argument("--seed", type=int, default=1234, help="Random seed") + parser.add_argument( + "--exp", type=str, default="exp", help="Path for saving running related data." + ) + parser.add_argument( + "--doc", + type=str, + required=True, + help="A string for documentation purpose. 
" + "Will be the name of the log folder.", + ) + parser.add_argument( + "--comment", type=str, default="", help="A string for experiment comment" + ) + parser.add_argument( + "--verbose", + type=str, + default="info", + help="Verbose level: info | debug | warning | critical", + ) + parser.add_argument("--test", action="store_true", help="Whether to test the model") + parser.add_argument( + "--sample", + action="store_true", + help="Whether to produce samples from the model", + ) + parser.add_argument("--fid", action="store_true") + parser.add_argument("--interpolation", action="store_true") + parser.add_argument( + "--resume_training", action="store_true", help="Whether to resume training" + ) + parser.add_argument( + "-i", + "--image_folder", + type=str, + default="images", + help="The folder name of samples", + ) + parser.add_argument( + "--ni", + action="store_true", + help="No interaction. Suitable for Slurm Job launcher", + ) + parser.add_argument("--use_pretrained", action="store_true") + parser.add_argument( + "--sample_type", + type=str, + default="generalized", + help="sampling approach (generalized or ddpm_noisy)", + ) + parser.add_argument( + "--skip_type", + type=str, + default="uniform", + help="skip according to (uniform or quadratic)", + ) + parser.add_argument( + "--timesteps", type=int, default=1000, help="number of steps involved" + ) + parser.add_argument( + "--eta", + type=float, + default=0.0, + help="eta used to control the variances of sigma", + ) + parser.add_argument("--sequence", action="store_true") + + args = parser.parse_args() + args.log_path = os.path.join(args.exp, "logs", args.doc) + + # parse config file + with open(os.path.join("configs", args.config), "r") as f: + config = yaml.safe_load(f) + new_config = dict2namespace(config) + + vdl_path = os.path.join(args.exp, "visualdl", args.doc) + + if not args.test and not args.sample: + if not args.resume_training: + if os.path.exists(args.log_path): + overwrite = False + if args.ni: + 
overwrite = True + else: + response = input("Folder already exists. Overwrite? (Y/N)") + if response.upper() == "Y": + overwrite = True + + if overwrite: + shutil.rmtree(args.log_path) + shutil.rmtree(vdl_path) + os.makedirs(args.log_path) + if os.path.exists(vdl_path): + shutil.rmtree(vdl_path) + else: + print("Folder exists. Program halted.") + sys.exit(0) + else: + os.makedirs(args.log_path) + + with open(os.path.join(args.log_path, "config.yml"), "w") as f: + yaml.dump(new_config, f, default_flow_style=False) + + new_config.vdl_logger = vdl.LogWriter(log_dir=vdl_path) + # setup logger + level = getattr(logging, args.verbose.upper(), None) + if not isinstance(level, int): + raise ValueError("level {} not supported".format(args.verbose)) + + handler1 = logging.StreamHandler() + handler2 = logging.FileHandler(os.path.join(args.log_path, "stdout.txt")) + formatter = logging.Formatter( + "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" + ) + handler1.setFormatter(formatter) + handler2.setFormatter(formatter) + logger = logging.getLogger() + logger.addHandler(handler1) + logger.addHandler(handler2) + logger.setLevel(level) + + else: + level = getattr(logging, args.verbose.upper(), None) + if not isinstance(level, int): + raise ValueError("level {} not supported".format(args.verbose)) + + handler1 = logging.StreamHandler() + formatter = logging.Formatter( + "%(levelname)s - %(filename)s - %(asctime)s - %(message)s" + ) + handler1.setFormatter(formatter) + logger = logging.getLogger() + logger.addHandler(handler1) + logger.setLevel(level) + + if args.sample: + os.makedirs(os.path.join(args.exp, "image_samples"), exist_ok=True) + args.image_folder = os.path.join( + args.exp, "image_samples", args.image_folder + ) + if not os.path.exists(args.image_folder): + os.makedirs(args.image_folder) + else: + if not (args.fid or args.interpolation): + overwrite = False + if args.ni: + overwrite = True + else: + response = input( + f"Image folder {args.image_folder} 
already exists. Overwrite? (Y/N)" + ) + if response.upper() == "Y": + overwrite = True + + if overwrite: + shutil.rmtree(args.image_folder) + os.makedirs(args.image_folder) + else: + print("Output image folder exists. Program halted.") + sys.exit(0) + + # add device + device = paddle.get_device() + logging.info("Using device: {}".format(device)) + new_config.device = device + + # set random seed + paddle.seed(args.seed) + np.random.seed(args.seed) + + return args, new_config + + +def dict2namespace(config): + namespace = argparse.Namespace() + for key, value in config.items(): + if isinstance(value, dict): + new_value = dict2namespace(value) + else: + new_value = value + setattr(namespace, key, new_value) + return namespace + + +def main(): + args, config = parse_args_and_config() + logging.info("Writing log file to {}".format(args.log_path)) + logging.info("Exp instance id = {}".format(os.getpid())) + logging.info("Exp comment = {}".format(args.comment)) + + try: + runner = Diffusion(args, config) + if args.sample: + runner.sample() + elif args.test: + runner.test() + else: + runner.train() + except Exception: + logging.error(traceback.format_exc()) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/example/BinaryDM/models/diffusion.py b/example/BinaryDM/models/diffusion.py new file mode 100644 index 00000000..7c7fdf05 --- /dev/null +++ b/example/BinaryDM/models/diffusion.py @@ -0,0 +1,388 @@ +import math +import paddle +import paddle.nn as nn + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = paddle.exp(paddle.arange(half_dim, dtype='float32') * - emb) + emb = timesteps.astype('float32').unsqueeze(1) * emb.unsqueeze(0) + emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], 1) + if embedding_dim % 2 == 1: # zero pad + emb = paddle.nn.functional.pad(emb, [0, 1, 0, 0]) + return emb + + +def spatial_fold(input, fold): + if fold == 1: + return input + + batch, channel, height, width = input.shape + h_fold = height // fold + w_fold = width // fold + + return ( + input.reshape((batch, channel, h_fold, fold, w_fold, fold)) + .transpose((0, 1, 3, 5, 2, 4)) + .reshape((batch, -1, h_fold, w_fold)) + ) + + +def spatial_unfold(input, unfold): + if unfold == 1: + return input + + batch, channel, height, width = input.shape + h_unfold = height * unfold + w_unfold = width * unfold + + return ( + input.reshape((batch, -1, unfold, unfold, height, width)) + .transpose((0, 1, 4, 2, 5, 3)) + .reshape((batch, -1, h_unfold, w_unfold)) + ) + + +def nonlinearity(x): + # swish + return x*paddle.nn.functional.sigmoid(x) + + +def Normalize(in_channels): + return paddle.nn.GroupNorm(num_groups=32, num_channels=in_channels, epsilon=1e-6) + + +class Upsample(nn.Layer): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = paddle.nn.Conv2D(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = paddle.nn.functional.interpolate( + x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Layer): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in paddle conv, must do it ourselves + self.conv = paddle.nn.Conv2D(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def 
forward(self, x): + if self.with_conv: + pad = [0, 1, 0, 1] + x = paddle.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = paddle.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Layer): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512, use_scale_shift_norm=False): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = paddle.nn.Conv2D(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.temb_proj = paddle.nn.Linear(temb_channels, + out_channels * 2 if use_scale_shift_norm else out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = paddle.nn.Dropout(dropout) + self.conv2 = paddle.nn.Conv2D(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = paddle.nn.Conv2D(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = paddle.nn.Conv2D(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + self.use_scale_shift_norm = use_scale_shift_norm + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + emb = self.temb_proj(nonlinearity(temb)).unsqueeze(-1).unsqueeze(-1) + if self.use_scale_shift_norm: + shift, scale = emb.split(2, 1) + h = self.norm2(h) * (1 + scale) + shift + else: + h = h + emb + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Layer): + def __init__(self, 
in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = paddle.nn.Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = paddle.nn.Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = paddle.nn.Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = paddle.nn.Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape([b, c, h*w]) + q = q.transpose([0, 2, 1]) # b,hw,c + k = k.reshape([b, c, h*w]) # b,c,hw + w_ = paddle.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = paddle.nn.functional.softmax(w_, 2) + + # attend to values + v = v.reshape([b, c, h*w]) + w_ = w_.transpose([0, 2, 1]) # b,hw,hw (first hw of k, second of q) + # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = paddle.bmm(v, w_) + h_ = h_.reshape([b, c, h, w]) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + ch, out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) + num_res_blocks = config.model.num_res_blocks + attn_resolutions = config.model.attn_resolutions + dropout = config.model.dropout + in_channels = config.model.in_channels + resolution = config.data.image_size + resamp_with_conv = config.model.resamp_with_conv + num_timesteps = config.diffusion.num_diffusion_timesteps + use_scale_shift_norm = config.model.use_scale_shift_norm if 'use_scale_shift_norm' in config.model else False + fold = config.model.fold if 'fold' in config.model else 1 + cond_channels = config.model.cond_channels if 'cond_channels' in config.model else 0 + + if config.model.type == 'bayesian': + 
self.logvar = self.create_parameter([num_timesteps,], default_initializer=nn.initializer.Constant(0.0)) + + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.cond_channels = cond_channels + self.fold = fold + + # timestep embedding + self.temb = nn.Layer() + self.temb.dense = nn.LayerList([ + paddle.nn.Linear(self.ch, + self.temb_ch), + paddle.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = paddle.nn.Conv2D((in_channels + cond_channels)*fold**2, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+ch_mult + self.down = nn.LayerList() + block_in = None + for i_level in range(self.num_resolutions): + block = nn.LayerList() + attn = nn.LayerList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Layer() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Layer() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm) + + # upsampling + self.up = [] + for i_level in reversed(range(self.num_resolutions)): + block = nn.LayerList() + 
attn = nn.LayerList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Layer() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + self.up = nn.LayerList(self.up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = paddle.nn.Conv2D(block_in, + out_ch*fold**2, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t): + assert x.shape[2] == x.shape[3] == self.resolution + + # timestep embedding + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + # downsampling + x = spatial_fold(x, self.fold) + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + paddle.concat([h, hs.pop()], 1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = 
self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = spatial_unfold(h, self.fold) + return h diff --git a/example/BinaryDM/models/diffusion_binarydm.py b/example/BinaryDM/models/diffusion_binarydm.py new file mode 100644 index 00000000..b4c623a4 --- /dev/null +++ b/example/BinaryDM/models/diffusion_binarydm.py @@ -0,0 +1,389 @@ +import math +import paddle +import paddle.nn as nn +from models.utils_binarydm import BNNConv2d +Conv2D = BNNConv2d + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = paddle.exp(paddle.arange(half_dim, dtype='float32') * - emb) + emb = timesteps.astype('float32').unsqueeze(1) * emb.unsqueeze(0) + emb = paddle.concat([paddle.sin(emb), paddle.cos(emb)], 1) + if embedding_dim % 2 == 1: # zero pad + emb = paddle.nn.functional.pad(emb, [0, 1, 0, 0]) + return emb + + +def spatial_fold(input, fold): + if fold == 1: + return input + + batch, channel, height, width = input.shape + h_fold = height // fold + w_fold = width // fold + + return ( + input.reshape((batch, channel, h_fold, fold, w_fold, fold)) + .transpose((0, 1, 3, 5, 2, 4)) + .reshape((batch, -1, h_fold, w_fold)) + ) + + +def spatial_unfold(input, unfold): + if unfold == 1: + return input + + batch, channel, height, width = input.shape + h_unfold = height * unfold + w_unfold = width * unfold + + return ( + input.reshape((batch, -1, unfold, unfold, height, width)) + .transpose((0, 1, 4, 2, 5, 3)) + .reshape((batch, -1, h_unfold, w_unfold)) + ) + + +def nonlinearity(x): + # swish + return x*paddle.nn.functional.sigmoid(x) + + +def Normalize(in_channels): + return 
paddle.nn.GroupNorm(num_groups=32, num_channels=in_channels, epsilon=1e-6) + + +class Upsample(nn.Layer): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = Conv2D(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = paddle.nn.functional.interpolate( + x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Layer): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in paddle conv, must do it ourselves + self.conv = Conv2D(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = [0, 1, 0, 1] + x = paddle.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = paddle.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Layer): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512, use_scale_shift_norm=False): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = Conv2D(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + self.temb_proj = paddle.nn.Linear(temb_channels, + out_channels * 2 if use_scale_shift_norm else out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = paddle.nn.Dropout(dropout) + self.conv2 = Conv2D(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = nn.Conv2D(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + 
self.nin_shortcut = nn.Conv2D(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + self.use_scale_shift_norm = use_scale_shift_norm + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + emb = self.temb_proj(nonlinearity(temb)).unsqueeze(-1).unsqueeze(-1) + if self.use_scale_shift_norm: + shift, scale = emb.split(2, 1) + h = self.norm2(h) * (1 + scale) + shift + else: + h = h + emb + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Layer): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = Conv2D(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape([b, c, h*w]) + q = q.transpose([0, 2, 1]) # b,hw,c + k = k.reshape([b, c, h*w]) # b,c,hw + w_ = paddle.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = paddle.nn.functional.softmax(w_, 2) + + # attend to values + v = v.reshape([b, c, h*w]) + w_ = w_.transpose([0, 2, 1]) # b,hw,hw (first hw of k, second of q) + # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = paddle.bmm(v, w_) + h_ = h_.reshape([b, c, h, w]) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Layer): + def __init__(self, config): + super().__init__() + self.config = config + ch, 
out_ch, ch_mult = config.model.ch, config.model.out_ch, tuple(config.model.ch_mult) + num_res_blocks = config.model.num_res_blocks + attn_resolutions = config.model.attn_resolutions + dropout = config.model.dropout + in_channels = config.model.in_channels + resolution = config.data.image_size + resamp_with_conv = config.model.resamp_with_conv + num_timesteps = config.diffusion.num_diffusion_timesteps + use_scale_shift_norm = config.model.use_scale_shift_norm if 'use_scale_shift_norm' in config.model else False + fold = config.model.fold if 'fold' in config.model else 1 + cond_channels = config.model.cond_channels if 'cond_channels' in config.model else 0 + + if config.model.type == 'bayesian': + self.logvar = self.create_parameter([num_timesteps,], default_initializer=nn.initializer.Constant(0.0)) + + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.cond_channels = cond_channels + self.fold = fold + + # timestep embedding + self.temb = nn.Layer() + self.temb.dense = nn.LayerList([ + paddle.nn.Linear(self.ch, + self.temb_ch), + paddle.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = nn.Conv2D((in_channels + cond_channels)*fold**2, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+ch_mult + self.down = nn.LayerList() + block_in = None + for i_level in range(self.num_resolutions): + block = nn.LayerList() + attn = nn.LayerList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Layer() + down.block = block + down.attn 
= attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Layer() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm) + + # upsampling + self.up = [] + for i_level in reversed(range(self.num_resolutions)): + block = nn.LayerList() + attn = nn.LayerList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + use_scale_shift_norm=use_scale_shift_norm)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Layer() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + self.up = nn.LayerList(self.up) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = nn.Conv2D(block_in, + out_ch*fold**2, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t): + assert x.shape[2] == x.shape[3] == self.resolution + + # timestep embedding + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + + # downsampling + x = spatial_fold(x, self.fold) + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = 
self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + paddle.concat([h, hs.pop()], 1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + h = spatial_unfold(h, self.fold) + return h diff --git a/example/BinaryDM/models/ema.py b/example/BinaryDM/models/ema.py new file mode 100644 index 00000000..e90a60d5 --- /dev/null +++ b/example/BinaryDM/models/ema.py @@ -0,0 +1,51 @@ +import paddle + + +class EMAHelper(object): + def __init__(self, mu=0.999): + self.mu = mu + self.shadow = {} + + def register(self, module): + if isinstance(module, paddle.DataParallel): + module = module._layers + for name, param in module.named_parameters(): + if not param.stop_gradient: + self.shadow[name] = param.clone().detach() + + def update(self, module): + if isinstance(module, paddle.DataParallel): + module = module._layers + for name, param in module.named_parameters(): + if not param.stop_gradient: + self.shadow[name] = (( + 1. 
- self.mu) * param + self.mu * paddle.to_tensor(self.shadow[name])).detach() + + def ema(self, module): + if isinstance(module, paddle.DataParallel): + module = module._layers + for name, param in module.named_parameters(): + if not param.stop_gradient: + param.stop_gradient = True + param[:] = self.shadow[name] + param.stop_gradient = False + + def ema_copy(self, module): + if isinstance(module, paddle.DataParallel): + inner_module = module._layers + module_copy = type(inner_module)( + inner_module.config) + module_copy.set_state_dict(inner_module.state_dict()) + module_copy = paddle.DataParallel(module_copy) + else: + module_copy = type(module)(module.config) + module_copy.set_state_dict(module.state_dict()) + # module_copy = copy.deepcopy(module) + self.ema(module_copy) + return module_copy + + def state_dict(self): + return self.shadow + + def set_state_dict(self, state_dict): + self.shadow = state_dict diff --git a/example/BinaryDM/models/utils_binarydm.py b/example/BinaryDM/models/utils_binarydm.py new file mode 100644 index 00000000..886a4d9f --- /dev/null +++ b/example/BinaryDM/models/utils_binarydm.py @@ -0,0 +1,113 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np +import math + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1D(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2D(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3D(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + +class BNNConv2d(nn.Layer): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, bias=False, dilation=0, transposed=False, output_padding=None, groups=1, precision='bnn', order=2): + super(BNNConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.transposed = transposed + self.output_padding = output_padding + self.groups = groups + self.number_of_weights = in_channels * out_channels * kernel_size * kernel_size + self.shape = [out_channels, in_channels, kernel_size, kernel_size] + # tmp = paddle.rand(self.shape) * 0.001 + self.weight = paddle.create_parameter(shape=self.shape, dtype=paddle.float32) + # tmp = paddle.rand(out_channels) * 0.001 + self.bias = paddle.create_parameter(shape=[out_channels], dtype=paddle.float32) + + self.order = order + self.scaling_first_order = paddle.create_parameter(shape=[out_channels, 1, 1, 1], dtype=paddle.float32) + self.scaling_second_order = paddle.create_parameter(shape=[out_channels, 1, 1, 1], dtype=paddle.float32) + # paddle.create_parameter(shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp)) + # self.sw = None + self.init_scale = False + + self.precision = precision + self.bnn_mode = 'bnn' + + self.binary_act = False + + self.is_int = False + # tmp = paddle.rand([out_channels, 1, 1, 1]) * 0.001 + self.nbits = 8 + self.Qn = -2 ** (self.nbits - 1) + self.Qp = 2 ** (self.nbits - 1) - 1 + self.n_levels = 2 ** self.nbits + + if self.in_channels != self.out_channels: + self.shortcut = nn.Conv2D(self.in_channels, self.out_channels, kernel_size=1, stride=self.stride, padding=0) + tmp = 
paddle.ones([1]) * 0.3
+            self.shortcut_scale = paddle.create_parameter(shape=tmp.shape, dtype=tmp.dtype, default_initializer=nn.initializer.Assign(tmp))
+
+    def forward(self, x, bnn_mode='bnn'):
+        # Binarized convolution with straight-through estimator (STE) and a
+        # learnable residual shortcut. `order` selects 1- or 2-basis binarization.
+        x_raw = x
+
+        # Full-precision bypass when any of the precision switches asks for it.
+        if 'full' in [self.precision, self.bnn_mode, bnn_mode]:
+            return F.conv2d(x, self.weight, stride=self.stride, padding=self.padding, bias=self.bias)
+
+        bw = self.weight
+        if not self.init_scale:
+            # First-order scale: per-output-channel mean |w|.
+            real_weights = self.weight.reshape(self.shape)
+            scaling_factor = paddle.mean(paddle.mean(paddle.mean(abs(real_weights), axis=3, keepdim=True), axis=2, keepdim=True), axis=1, keepdim=True)
+            self.scaling_first_order = paddle.create_parameter(shape=scaling_factor.shape, dtype=scaling_factor.dtype, default_initializer=nn.initializer.Assign(scaling_factor))
+            if self.order == 1:
+                # FIX: for first-order binarization the second-order init block
+                # below never runs, so mark initialization done here; otherwise
+                # the learned scale is re-created (reset) on every forward pass.
+                self.init_scale = True
+
+        # STE: forward uses sign(w)*scale, backward flows through bw_fp.
+        bw_fp = bw * self.scaling_first_order
+        bw = (paddle.sign(bw) * self.scaling_first_order).detach() - bw_fp.detach() + bw_fp
+
+        if self.order == 1:
+            y = F.conv2d(x, bw, stride=self.stride, padding=self.padding, bias=self.bias)
+            if self.in_channels == self.out_channels:
+                if x_raw.shape[-1] < y.shape[-1]:
+                    shortcut = F.interpolate(x_raw, scale_factor=2, mode="nearest")
+                elif x_raw.shape[-1] > y.shape[-1]:
+                    shortcut = avg_pool_nd(2, kernel_size=self.stride, stride=self.stride)(x_raw)
+                else:
+                    shortcut = x_raw
+            else:
+                shortcut = self.shortcut(x_raw)
+            return y + shortcut * paddle.abs(self.shortcut_scale)
+
+        # Second-order basis: binarize the residual of the first-order fit.
+        first_res_bw = self.weight - bw
+
+        if not self.init_scale:
+            # FIX: compute the second-order scale as the per-channel mean |residual|,
+            # mirroring the first-order scale above. The original assigned the raw
+            # residual tensor itself (wrong shape for the [out,1,1,1] parameter) via
+            # a nonexistent `.data` attribute, and used `view`, which Paddle tensors
+            # do not provide (use `reshape`, as the first-order branch does).
+            real_first_res = first_res_bw.reshape(self.shape)
+            scaling_factor = paddle.mean(paddle.mean(paddle.mean(abs(real_first_res), axis=3, keepdim=True), axis=2, keepdim=True), axis=1, keepdim=True)
+            self.scaling_second_order = paddle.create_parameter(shape=scaling_factor.shape, dtype=scaling_factor.dtype, default_initializer=nn.initializer.Assign(scaling_factor))
+            self.init_scale = True
+
+        bw_fp = first_res_bw * self.scaling_second_order
+        bw = (paddle.sign(first_res_bw) * self.scaling_second_order).detach() - bw_fp.detach() + bw_fp
+
+        y = F.conv2d(x, bw, stride=self.stride, padding=self.padding, bias=self.bias)
+
+        if self.in_channels == 
self.out_channels: + if x_raw.shape[-1] < y.shape[-1]: + shortcut = F.interpolate(x_raw, scale_factor=2, mode="nearest") + elif x_raw.shape[-1] > y.shape[-1]: + shortcut = avg_pool_nd(2, kernel_size=self.stride, stride=self.stride)(x_raw) + else: + shortcut = x_raw + else: + shortcut = self.shortcut(x_raw) + return y + shortcut * paddle.abs(self.shortcut_scale) diff --git a/example/BinaryDM/run.sh b/example/BinaryDM/run.sh new file mode 100644 index 00000000..6d849304 --- /dev/null +++ b/example/BinaryDM/run.sh @@ -0,0 +1,2 @@ +CUDA_VISIBLE_DEVICES=0 python main.py --config cifar10.yml --exp ./ --doc cifar10-fp --timesteps 100 --ni 2>&1 | tee fp.log +CUDA_VISIBLE_DEVICES=0 python main_binarydm.py --config cifar10.yml --exp ./ --doc cifar10-binarydm --timesteps 100 --ni 2>&1 | tee binarydm.log \ No newline at end of file diff --git a/example/BinaryDM/runners/__init__.py b/example/BinaryDM/runners/__init__.py new file mode 100644 index 00000000..1658b789 --- /dev/null +++ b/example/BinaryDM/runners/__init__.py @@ -0,0 +1,48 @@ +""" +Patch missing operators and missing modules +""" +import paddle + + +if 'cumprod' not in paddle.__dict__: + import numpy as np + from functools import lru_cache + + @lru_cache() + def cumprod_mask(axis_length): + mask = np.ones([axis_length, axis_length]).astype('float32') + mask = np.tril(mask, k=0) + + return paddle.to_tensor(mask) + + def cumprod(x, axis=None): + if axis is None: + x = x.reshape([-1]) + axis = 0 + assert isinstance(axis, int) + + if axis < 0: + axis = len(x.shape) + axis + axis_length = x.shape[axis] + mask = cumprod_mask(axis_length).reshape([*list([1]*axis), -1, axis_length, *list([1]*(len(x.shape)-axis-1))]) + x = x.unsqueeze(axis) + x = x * mask.detach() + (paddle.ones_like(mask) * (1 - mask)).detach() + + return paddle.prod(x, axis=axis+1) + + paddle.cumprod = cumprod + paddle.Tensor.cumprod = lambda self, axis=None: cumprod(self, axis) + +if 'Subset' not in paddle.io.__dict__: + class Subset(paddle.io.Dataset): + 
def __init__(self, dataset, indices) -> None: + self.dataset = dataset + self.indices = indices + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + + paddle.io.Subset = Subset diff --git a/example/BinaryDM/runners/diffusion.py b/example/BinaryDM/runners/diffusion.py new file mode 100644 index 00000000..01c762f3 --- /dev/null +++ b/example/BinaryDM/runners/diffusion.py @@ -0,0 +1,381 @@ +import os +import logging +import time +import glob + +import numpy as np +import tqdm +import paddle +import paddle.io as data + +from models.diffusion import Model +from models.ema import EMAHelper +from functions import get_optimizer +from functions.losses import loss_registry +from datasets import get_dataset, data_transform, inverse_data_transform +from functions.ckpt_util import get_ckpt_path + +import numpy as np +from PIL import Image + + +def paddle2hwcuint8(x, clip=False): + if clip: + x = paddle.clip(x, -1, 1) + x = (x + 1.0) / 2.0 + return x + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, 
num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "sigmoid": + betas = np.linspace(-6, 6, num_diffusion_timesteps) + betas = sigmoid(betas) * (beta_end - beta_start) + beta_start + elif beta_schedule == "cosine": + betas = betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2, + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +class Diffusion(object): + def __init__(self, args, config, device=None): + self.args = args + self.config = config + if device is None: + device = paddle.get_device() + self.device = device + + self.model_var_type = config.model.var_type + betas = get_beta_schedule( + beta_schedule=config.diffusion.beta_schedule, + beta_start=config.diffusion.beta_start, + beta_end=config.diffusion.beta_end, + num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, + ) + betas = self.betas = paddle.to_tensor(betas).astype('float32') + self.num_timesteps = betas.shape[0] + + alphas = 1.0 - betas + alphas_cumprod = alphas.cumprod(0) + alphas_cumprod_prev = paddle.concat( + [paddle.ones([1]), alphas_cumprod[:-1]], 0 + ) + posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + if self.model_var_type == "fixedlarge": + self.logvar = betas.log() + # paddle.concat( + # [posterior_variance[1:2], betas[1:]], 0).log() + elif self.model_var_type == "fixedsmall": + self.logvar = posterior_variance.clip(min=1e-20).log() + + def train(self): + args, config = self.args, self.config + vdl_logger = self.config.vdl_logger + dataset, test_dataset = get_dataset(args, config) + train_loader = data.DataLoader( + dataset, + batch_size=config.training.batch_size, + shuffle=True, + num_workers=config.data.num_workers, + use_shared_memory=False, + ) + model = Model(config) + + model = model + model = paddle.DataParallel(model) + + optimizer = get_optimizer(self.config, model.parameters()) + + 
if self.config.model.ema: + ema_helper = EMAHelper(mu=self.config.model.ema_rate) + ema_helper.register(model) + else: + ema_helper = None + + start_epoch, step = 0, 0 + if self.args.resume_training: + states = paddle.load(os.path.join(self.args.log_path, "ckpt.pdl")) + model.set_state_dict({k.split("$model_")[-1]: v for k, v in states.items() if "$model_" in k}) + + optimizer.set_state_dict({k.split("$optimizer_")[-1]: v for k, v in states.items() if "$optimizer_" in k}) + optimizer._epsilon = self.config.optim.eps + start_epoch = states["$epoch"] + step = states["$step"] + if self.config.model.ema: + ema_helper.set_state_dict({k.split("$ema_")[-1]: v for k, v in states.items() if "$ema_" in k}) + + for epoch in range(start_epoch, self.config.training.n_epochs): + data_start = time.time() + data_time = 0 + for i, (x, y) in enumerate(train_loader): + n = x.shape[0] + data_time += time.time() - data_start + model.train() + step += 1 + + x = data_transform(self.config, x) + e = paddle.randn(x.shape) + b = self.betas + + # antithetic sampling + t = paddle.randint( + low=0, high=self.num_timesteps, shape=(n // 2 + 1,) + ) + t = paddle.concat([t, self.num_timesteps - t - 1], 0)[:n] + loss = loss_registry[config.model.type](model, x, t, e, b) + + vdl_logger.add_scalar("loss", loss, step=step) + + logging.info( + f"step: {step}, loss: {loss.numpy()}, data time: {data_time / (i+1)}" + ) + + optimizer.clear_grad() + loss.backward() + optimizer.step() + + if self.config.model.ema: + ema_helper.update(model) + + if step % self.config.training.snapshot_freq == 0 or step == 1: + states = dict( + **{"$model_"+k: v for k, v in model.state_dict().items()}, + **{"$optimizer_"+k: v for k, v in optimizer.state_dict().items()}, + **{"$epoch": epoch}, + **{"$step": step}, + ) + if self.config.model.ema: + states.update({"$ema_"+k: v for k, v in ema_helper.state_dict().items()}) + + paddle.save( + states, + os.path.join(self.args.log_path, "ckpt_{}.pdl".format(step)), + ) + 
paddle.save(states, os.path.join(self.args.log_path, "ckpt.pdl")) + + data_start = time.time() + + def sample(self): + model = Model(self.config) + + if not self.args.use_pretrained: + if getattr(self.config.sampling, "ckpt_id", None) is None: + states = paddle.load( + os.path.join(self.args.log_path, "ckpt.pdl") + ) + else: + states = paddle.load( + os.path.join( + self.args.log_path, f"ckpt_{self.config.sampling.ckpt_id}.pdl" + ) + ) + model = model + model = paddle.DataParallel(model) + model.set_state_dict({k.split("$model_")[-1]: v for k, v in states.items() if "$model_" in k}) + + if self.config.model.ema: + ema_helper = EMAHelper(mu=self.config.model.ema_rate) + ema_helper.register(model) + ema_helper.set_state_dict({k.split("$ema_")[-1]: v for k, v in states.items() if "$ema_" in k}) + ema_helper.ema(model) + else: + ema_helper = None + else: + # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion + if self.config.data.dataset == "CIFAR10": + name = "cifar10" + elif self.config.data.dataset == "LSUN": + name = f"lsun_{self.config.data.category}" + else: + raise ValueError + ckpt = get_ckpt_path(f"ema_{name}") + print("Loading checkpoint {}".format(ckpt)) + model.set_state_dict(paddle.load(ckpt)) + model = paddle.DataParallel(model) + + model.eval() + + if self.args.fid: + self.sample_fid(model) + elif self.args.interpolation: + self.sample_interpolation(model) + elif self.args.sequence: + self.sample_sequence(model) + else: + raise NotImplementedError("Sample procedeure not defined") + + def sample_fid(self, model): + config = self.config + img_id = len(glob.glob(f"{self.args.image_folder}/*")) + print(f"starting from image {img_id}") + total_n_samples = 50000 + n_rounds = (total_n_samples - img_id) // config.sampling.batch_size + + with paddle.no_grad(): + for _ in tqdm.tqdm( + range(n_rounds), desc="Generating image samples for FID evaluation." 
+            ):
+                n = config.sampling.batch_size
+                # FIX: paddle.randn takes the output shape as one list argument,
+                # not as positional dimensions (cf. sample_sequence below).
+                x = paddle.randn([
+                    n,
+                    config.data.channels,
+                    config.data.image_size,
+                    config.data.image_size,
+                ])
+
+                x = self.sample_image(x, model)
+                x = inverse_data_transform(config, x)
+
+                for i in range(n):
+                    Image.fromarray(np.uint8(x[i].numpy().transpose([1,2,0])*255)).save(
+                        os.path.join(self.args.image_folder, f"{img_id}.png")
+                    )
+                    img_id += 1
+
+    def sample_sequence(self, model):
+        config = self.config
+
+        x = paddle.randn([
+            8,
+            config.data.channels,
+            config.data.image_size,
+            config.data.image_size,
+        ])
+
+        # NOTE: This means that we are producing each predicted x0, not x_{t-1} at timestep t.
+        with paddle.no_grad():
+            _, x = self.sample_image(x, model, last=False)
+
+        x = [inverse_data_transform(config, y) for y in x]
+
+        for i in range(len(x)):
+            for j in range(x[i].shape[0]):
+                Image.fromarray(np.uint8(x[i][j].numpy().transpose([1,2,0])*255)).save(
+                    os.path.join(self.args.image_folder, f"{j}_{i}.png")
+                )
+
+    def sample_interpolation(self, model):
+        config = self.config
+
+        def slerp(z1, z2, alpha):
+            theta = paddle.acos(paddle.sum(z1 * z2) / (paddle.norm(z1) * paddle.norm(z2)))
+            return (
+                paddle.sin((1 - alpha) * theta) / paddle.sin(theta) * z1
+                + paddle.sin(alpha * theta) / paddle.sin(theta) * z2
+            )
+
+        # FIX: paddle.randn requires the shape as a single list argument.
+        z1 = paddle.randn([
+            1,
+            config.data.channels,
+            config.data.image_size,
+            config.data.image_size,
+        ])
+        z2 = paddle.randn([
+            1,
+            config.data.channels,
+            config.data.image_size,
+            config.data.image_size,
+        ])
+        alpha = paddle.arange(0.0, 1.01, 0.1)
+        z_ = []
+        for i in range(alpha.shape[0]):
+            z_.append(slerp(z1, z2, alpha[i]))
+
+        x = paddle.concat(z_, 0)
+        xs = []
+
+        # Hard coded here, modify to your preferences
+        with paddle.no_grad():
+            for i in range(0, x.shape[0], 8):
+                xs.append(self.sample_image(x[i : i + 8], model))
+        x = inverse_data_transform(config, paddle.concat(xs, 0))
+        for i in range(x.shape[0]):
+            
Image.fromarray(np.uint8(x[i].numpy().transpose([1,2,0])*255)).save(os.path.join(self.args.image_folder, f"{i}.png")) + + def sample_image(self, x, model, last=True): + try: + skip = self.args.skip + except Exception: + skip = 1 + + if self.args.sample_type == "generalized": + if self.args.skip_type == "uniform": + skip = self.num_timesteps // self.args.timesteps + seq = range(0, self.num_timesteps, skip) + elif self.args.skip_type == "quad": + seq = ( + np.linspace( + 0, np.sqrt(self.num_timesteps * 0.8), self.args.timesteps + ) + ** 2 + ) + seq = [int(s) for s in list(seq)] + else: + raise NotImplementedError + from functions.denoising import generalized_steps + + xs = generalized_steps(x, seq, model, self.betas, eta=self.args.eta) + x = xs + elif self.args.sample_type == "ddpm_noisy": + if self.args.skip_type == "uniform": + skip = self.num_timesteps // self.args.timesteps + seq = range(0, self.num_timesteps, skip) + elif self.args.skip_type == "quad": + seq = ( + np.linspace( + 0, np.sqrt(self.num_timesteps * 0.8), self.args.timesteps + ) + ** 2 + ) + seq = [int(s) for s in list(seq)] + else: + raise NotImplementedError + from functions.denoising import ddpm_steps + + x = ddpm_steps(x, seq, model, self.betas) + else: + raise NotImplementedError + if last: + x = x[0][-1] + return x + + def test(self): + pass diff --git a/example/BinaryDM/runners/diffusion_binarydm.py b/example/BinaryDM/runners/diffusion_binarydm.py new file mode 100644 index 00000000..73db957c --- /dev/null +++ b/example/BinaryDM/runners/diffusion_binarydm.py @@ -0,0 +1,383 @@ +import os +import logging +import time +import glob + +import numpy as np +import tqdm +import paddle +import paddle.io as data + +# import sys +# sys.path.append('./') +from models.diffusion_binarydm import Model +from models.ema import EMAHelper +from functions import get_optimizer +from functions.losses import loss_registry +from datasets import get_dataset, data_transform, inverse_data_transform +from 
functions.ckpt_util import get_ckpt_path + +import numpy as np +from PIL import Image + + +def paddle2hwcuint8(x, clip=False): + if clip: + x = paddle.clip(x, -1, 1) + x = (x + 1.0) / 2.0 + return x + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + def sigmoid(x): + return 1 / (np.exp(-x) + 1) + + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace( + beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + elif beta_schedule == "sigmoid": + betas = np.linspace(-6, 6, num_diffusion_timesteps) + betas = sigmoid(betas) * (beta_end - beta_start) + beta_start + elif beta_schedule == "cosine": + betas = betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: np.cos((t + 0.008) / 1.008 * np.pi / 2) ** 2, + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +class Diffusion(object): + def __init__(self, args, config, device=None): + self.args = args + self.config = config + if device is None: + device = paddle.get_device() + self.device = device + + self.model_var_type = config.model.var_type + betas = get_beta_schedule( + beta_schedule=config.diffusion.beta_schedule, + beta_start=config.diffusion.beta_start, + 
beta_end=config.diffusion.beta_end, + num_diffusion_timesteps=config.diffusion.num_diffusion_timesteps, + ) + betas = self.betas = paddle.to_tensor(betas).astype('float32') + self.num_timesteps = betas.shape[0] + + alphas = 1.0 - betas + alphas_cumprod = alphas.cumprod(0) + alphas_cumprod_prev = paddle.concat( + [paddle.ones([1]), alphas_cumprod[:-1]], 0 + ) + posterior_variance = ( + betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + ) + if self.model_var_type == "fixedlarge": + self.logvar = betas.log() + # paddle.concat( + # [posterior_variance[1:2], betas[1:]], 0).log() + elif self.model_var_type == "fixedsmall": + self.logvar = posterior_variance.clip(min=1e-20).log() + + def train(self): + args, config = self.args, self.config + vdl_logger = self.config.vdl_logger + dataset, test_dataset = get_dataset(args, config) + train_loader = data.DataLoader( + dataset, + batch_size=config.training.batch_size, + shuffle=True, + num_workers=config.data.num_workers, + use_shared_memory=False, + ) + model = Model(config) + + model = model + model = paddle.DataParallel(model) + + optimizer = get_optimizer(self.config, model.parameters()) + + if self.config.model.ema: + ema_helper = EMAHelper(mu=self.config.model.ema_rate) + ema_helper.register(model) + else: + ema_helper = None + + start_epoch, step = 0, 0 + if self.args.resume_training: + states = paddle.load(os.path.join(self.args.log_path, "ckpt.pdl")) + model.set_state_dict({k.split("$model_")[-1]: v for k, v in states.items() if "$model_" in k}) + + optimizer.set_state_dict({k.split("$optimizer_")[-1]: v for k, v in states.items() if "$optimizer_" in k}) + optimizer._epsilon = self.config.optim.eps + start_epoch = states["$epoch"] + step = states["$step"] + if self.config.model.ema: + ema_helper.set_state_dict({k.split("$ema_")[-1]: v for k, v in states.items() if "$ema_" in k}) + + for epoch in range(start_epoch, self.config.training.n_epochs): + data_start = time.time() + data_time = 0 + for i, (x, y) in 
enumerate(train_loader): + n = x.shape[0] + data_time += time.time() - data_start + model.train() + step += 1 + + x = data_transform(self.config, x) + e = paddle.randn(x.shape) + b = self.betas + + # antithetic sampling + t = paddle.randint( + low=0, high=self.num_timesteps, shape=(n // 2 + 1,) + ) + t = paddle.concat([t, self.num_timesteps - t - 1], 0)[:n] + loss = loss_registry[config.model.type](model, x, t, e, b) + + vdl_logger.add_scalar("loss", loss, step=step) + + logging.info( + f"step: {step}, loss: {loss.numpy()}, data time: {data_time / (i+1)}" + ) + + optimizer.clear_grad() + loss.backward() + optimizer.step() + + if self.config.model.ema: + ema_helper.update(model) + + if step % self.config.training.snapshot_freq == 0 or step == 1: + states = dict( + **{"$model_"+k: v for k, v in model.state_dict().items()}, + **{"$optimizer_"+k: v for k, v in optimizer.state_dict().items()}, + **{"$epoch": epoch}, + **{"$step": step}, + ) + if self.config.model.ema: + states.update({"$ema_"+k: v for k, v in ema_helper.state_dict().items()}) + + paddle.save( + states, + os.path.join(self.args.log_path, "ckpt_{}.pdl".format(step)), + ) + paddle.save(states, os.path.join(self.args.log_path, "ckpt.pdl")) + + data_start = time.time() + + def sample(self): + model = Model(self.config) + + if not self.args.use_pretrained: + if getattr(self.config.sampling, "ckpt_id", None) is None: + states = paddle.load( + os.path.join(self.args.log_path, "ckpt.pdl") + ) + else: + states = paddle.load( + os.path.join( + self.args.log_path, f"ckpt_{self.config.sampling.ckpt_id}.pdl" + ) + ) + model = model + model = paddle.DataParallel(model) + model.set_state_dict({k.split("$model_")[-1]: v for k, v in states.items() if "$model_" in k}) + + if self.config.model.ema: + ema_helper = EMAHelper(mu=self.config.model.ema_rate) + ema_helper.register(model) + ema_helper.set_state_dict({k.split("$ema_")[-1]: v for k, v in states.items() if "$ema_" in k}) + ema_helper.ema(model) + else: + ema_helper 
= None
+        else:
+            # This used the pretrained DDPM model, see https://github.com/pesser/pytorch_diffusion
+            if self.config.data.dataset == "CIFAR10":
+                name = "cifar10"
+            elif self.config.data.dataset == "LSUN":
+                name = f"lsun_{self.config.data.category}"
+            else:
+                raise ValueError
+            ckpt = get_ckpt_path(f"ema_{name}")
+            print("Loading checkpoint {}".format(ckpt))
+            model.set_state_dict(paddle.load(ckpt))
+            model = paddle.DataParallel(model)
+
+        model.eval()
+
+        if self.args.fid:
+            self.sample_fid(model)
+        elif self.args.interpolation:
+            self.sample_interpolation(model)
+        elif self.args.sequence:
+            self.sample_sequence(model)
+        else:
+            raise NotImplementedError("Sample procedeure not defined")
+
+    def sample_fid(self, model):
+        config = self.config
+        img_id = len(glob.glob(f"{self.args.image_folder}/*"))
+        print(f"starting from image {img_id}")
+        total_n_samples = 50000
+        n_rounds = (total_n_samples - img_id) // config.sampling.batch_size
+
+        with paddle.no_grad():
+            for _ in tqdm.tqdm(
+                range(n_rounds), desc="Generating image samples for FID evaluation."
+            ):
+                n = config.sampling.batch_size
+                # FIX: paddle.randn takes the output shape as one list argument,
+                # not as positional dimensions (cf. sample_sequence below).
+                x = paddle.randn([
+                    n,
+                    config.data.channels,
+                    config.data.image_size,
+                    config.data.image_size,
+                ])
+
+                x = self.sample_image(x, model)
+                x = inverse_data_transform(config, x)
+
+                for i in range(n):
+                    Image.fromarray(np.uint8(x[i].numpy().transpose([1,2,0])*255)).save(
+                        os.path.join(self.args.image_folder, f"{img_id}.png")
+                    )
+                    img_id += 1
+
+    def sample_sequence(self, model):
+        config = self.config
+
+        x = paddle.randn([
+            8,
+            config.data.channels,
+            config.data.image_size,
+            config.data.image_size,
+        ])
+
+        # NOTE: This means that we are producing each predicted x0, not x_{t-1} at timestep t.
+        with paddle.no_grad():
+            _, x = self.sample_image(x, model, last=False)
+
+        x = [inverse_data_transform(config, y) for y in x]
+
+        for i in range(len(x)):
+            for j in range(x[i].shape[0]):
+                Image.fromarray(np.uint8(x[i][j].numpy().transpose([1,2,0])*255)).save(
+                    os.path.join(self.args.image_folder, f"{j}_{i}.png")
+                )
+
+    def sample_interpolation(self, model):
+        config = self.config
+
+        def slerp(z1, z2, alpha):
+            theta = paddle.acos(paddle.sum(z1 * z2) / (paddle.norm(z1) * paddle.norm(z2)))
+            return (
+                paddle.sin((1 - alpha) * theta) / paddle.sin(theta) * z1
+                + paddle.sin(alpha * theta) / paddle.sin(theta) * z2
+            )
+
+        # FIX: paddle.randn requires the shape as a single list argument,
+        # not positional dimensions.
+        z1 = paddle.randn([
+            1,
+            config.data.channels,
+            config.data.image_size,
+            config.data.image_size,
+        ])
+        z2 = paddle.randn([
+            1,
+            config.data.channels,
+            config.data.image_size,
+            config.data.image_size,
+        ])
+        alpha = paddle.arange(0.0, 1.01, 0.1)
+        z_ = []
+        for i in range(alpha.shape[0]):
+            z_.append(slerp(z1, z2, alpha[i]))
+
+        x = paddle.concat(z_, 0)
+        xs = []
+
+        # Hard coded here, modify to your preferences
+        with paddle.no_grad():
+            for i in range(0, x.shape[0], 8):
+                xs.append(self.sample_image(x[i : i + 8], model))
+        x = inverse_data_transform(config, paddle.concat(xs, 0))
+        for i in range(x.shape[0]):
+            Image.fromarray(np.uint8(x[i].numpy().transpose([1,2,0])*255)).save(os.path.join(self.args.image_folder, f"{i}.png"))
+
+    def sample_image(self, x, model, last=True):
+        try:
+            skip = self.args.skip
+        except Exception:
+            skip = 1
+
+        if self.args.sample_type == "generalized":
+            if self.args.skip_type == "uniform":
+                skip = self.num_timesteps // self.args.timesteps
+                seq = range(0, self.num_timesteps, skip)
+            elif self.args.skip_type == "quad":
+                seq = (
+                    np.linspace(
+                        0, np.sqrt(self.num_timesteps * 0.8), self.args.timesteps
+                    )
+                    ** 2
+                )
+                seq = [int(s) for s in list(seq)]
+            else:
+                raise NotImplementedError
+            from functions.denoising import generalized_steps
+
+            xs = generalized_steps(x, seq, model, self.betas, 
eta=self.args.eta) + x = xs + elif self.args.sample_type == "ddpm_noisy": + if self.args.skip_type == "uniform": + skip = self.num_timesteps // self.args.timesteps + seq = range(0, self.num_timesteps, skip) + elif self.args.skip_type == "quad": + seq = ( + np.linspace( + 0, np.sqrt(self.num_timesteps * 0.8), self.args.timesteps + ) + ** 2 + ) + seq = [int(s) for s in list(seq)] + else: + raise NotImplementedError + from functions.denoising import ddpm_steps + + x = ddpm_steps(x, seq, model, self.betas) + else: + raise NotImplementedError + if last: + x = x[0][-1] + return x + + def test(self): + pass