runner_l.py
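"""Lightning training runner for I2SB (Image-to-Image Schrodinger Bridge).

Wraps the I2SB network, diffusion schedule, and corruption operators in a
single LightningModule (`Runner`) and trains it with manual optimization so
that gradients can be accumulated over microbatches. The `__main__` block
parses the training options and launches a DDP `lightning.Trainer` run with
W&B logging and checkpointing.
"""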
import os
import sys
import argparse
import pickle
import numpy as np
from pathlib import Path
from line_profiler import profile
import torch
import torch.nn.functional as F
import lightning as L
from corruption import build_corruption
from dataset import imagenet
from dataset.LMDB2ImageFolder import Dset
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar
from lightning.pytorch.loggers import WandbLogger
from i2sb.network import Image256Net
from i2sb.diffusion import Diffusion
from i2sb.runner import make_beta_schedule, build_optimizer_sched
from evaluation import build_resnet50
from logger import Logger
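
# update_ema applies the standard EMA recursion in place: for every parameter,
# p_ema <- decay * p_ema + (1 - decay) * p_net, so the EMA copy tracks a smoothed
# version of the online network.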
def update_ema(model_ema, model_net, decay=0.9999):
    param_ema = dict(model_ema.named_parameters())
    param_net = dict(model_net.named_parameters())
    for k in param_ema.keys():
        param_ema[k].data.mul_(decay).add_(param_net[k].data, alpha=1 - decay)

class Runner(L.LightningModule):
    def __init__(self, log, opt, save_opt=True):
        super().__init__()

        if save_opt:
            opt_pkl_path = opt.ckpt_path / "options.pkl"
            with open(opt_pkl_path, "wb") as f:
                pickle.dump(opt, f)
            log.info("Saved options pickle to {}!".format(opt_pkl_path))

        # Symmetric beta schedule: the second half mirrors the first, as in I2SB.
        betas = make_beta_schedule(n_timestep=opt.interval, linear_end=opt.beta_max / opt.interval)
        betas = np.concatenate([betas[:opt.interval//2],
                                np.flip(betas[:opt.interval//2])])

        self.opt = opt
        self.SysLog = log
        self.dataset = Dset(opt.dataset_dir, opt.image_size)
        self.diffusion = Diffusion(betas)  # TODO: no 'device' arg here
        log.info(f"[Diffusion] Built I2SB diffusion: steps={len(betas)}!")

        noise_levels = torch.linspace(opt.t0, opt.T, opt.interval)
        self.net = Image256Net(log, noise_levels=noise_levels, use_fp16=self.opt.use_fp16, cond=opt.cond_x1)
        # self.ema = ExponentialMovingAverage(self.net.parameters(), decay=opt.ema)
        self.ema = Image256Net(log, noise_levels=noise_levels, use_fp16=self.opt.use_fp16, cond=opt.cond_x1)
        self.ema.load_state_dict(self.net.state_dict())  # start the EMA copy from the online weights
        self.resnet = build_resnet50()
        self.corrupt_method = build_corruption(self.opt, log)

        # Number of microbatches to accumulate before each optimizer step.
        self.backprop_frequency = opt.batch_size // opt.microbatch // opt.n_gpu_per_node
        self.automatic_optimization = False

        if self.opt.load:
            checkpoint = torch.load(self.opt.load, map_location='cpu')
            self.net.load_state_dict(checkpoint['net'])
            log.info(f"[Net] Loaded network ckpt: {opt.load}!")
            self.ema.load_state_dict(checkpoint['ema'])
            log.info(f"[Net] Loaded ema ckpt: {opt.load}!")

    def train_dataloader(self):
        train_loader = DataLoader(self.dataset,
                                  batch_size=self.opt.microbatch,
                                  shuffle=True,
                                  num_workers=16,
                                  pin_memory=True,
                                  drop_last=True)
        return train_loader

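    # Gradient accumulation with manual optimization: each dataloader batch is a
    # microbatch, gradients are accumulated across `backprop_frequency` microbatches,
    # and the optimizer (plus the EMA copy) steps once per accumulation window, so the
    # effective batch size is roughly microbatch * backprop_frequency * n_gpu_per_node
    # = batch_size.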
    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        x0, x1, mask, y, cond = self.sample_batch(batch, self.corrupt_method)

        step = torch.randint(0, self.opt.interval, (x0.shape[0],))
        xt = self.diffusion.q_sample(step, x0, x1, ot_ode=self.opt.ot_ode)
        label = self.compute_label(step, x0, x1)

        pred = self.net(xt, step, cond=cond)
        if mask is not None:
            pred = mask * pred
            label = mask * label

        loss = F.mse_loss(pred, label)
        self.manual_backward(loss)

        # Step (and update the EMA copy) once per accumulation window, then clear the
        # accumulated gradients; zeroing on every microbatch would discard them.
        if (batch_idx % self.backprop_frequency) == 0:
            optimizer.step()
            optimizer.zero_grad()
            update_ema(self.ema, self.net, decay=self.opt.ema)
            # if sched is not None: sched.step()

        # --- logging and ckpt saving --- #
        if batch_idx % 10 == 0:
            self.log("loss", loss.detach())

    # def validation_step
    # def validation_dataloader
    def configure_optimizers(self):
        optimizer, sched = build_optimizer_sched(self.opt, self.net, self.SysLog)  # FIXME: how to drive the scheduler under Lightning manual optimization?
        return optimizer

    # TODO: does Lightning have its own setup for sampling?
    def sample_batch(self, batch, corrupt_method):
        if self.opt.corrupt == 'mixture':
            clean_img, corrupt_img, y = batch
            mask = None
        elif self.opt.corrupt == 'inpaint':
            clean_img, y = batch
            with torch.no_grad():
                corrupt_img, mask = corrupt_method(clean_img)
        else:
            clean_img, y = batch
            with torch.no_grad():
                corrupt_img = corrupt_method(clean_img)
            mask = None

        y = y.detach().to(self.device)
        x0 = clean_img.detach().to(self.device)
        x1 = corrupt_img.detach().to(self.device)
        if mask is not None:
            mask = mask.detach().to(self.device)
            x1 = (1. - mask) * x1 + mask * torch.randn_like(x1)
        cond = x1.detach() if self.opt.cond_x1 else None

        if self.opt.add_x1_noise:
            x1 = x1 + torch.randn_like(x1)

        return x0, x1, mask, y, cond

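    # I2SB parametrization used by the two helpers below: the network is trained to
    # predict (xt - x0) / std_fwd(step), so a clean image can be recovered from the
    # prediction via x0_pred = xt - std_fwd(step) * net_out.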
    def compute_label(self, step, x0, xt):
        std_fwd = self.diffusion.get_std_fwd(step, xdim=x0.shape[1:])
        std_fwd = std_fwd.type_as(x0)
        label = (xt - x0) / std_fwd  # FIXME: GPU vs CPU placement
        return label.detach()

    def compute_pred_x0(self, step, xt, net_out, clip_noise=False):
        std_fwd = self.diffusion.get_std_fwd(step, xdim=xt.shape[1:])
        std_fwd = std_fwd.type_as(xt)
        pred_x0 = xt - std_fwd * net_out
        if clip_noise: pred_x0.clamp_(-1., 1.)
        return pred_x0

def create_training_options():
    # --------------- basic ---------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", default=True)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--name", type=str, default="test", help="experiment ID")
    parser.add_argument("--ckpt", type=str, default=None, help="name of the checkpoint to resume from")
    parser.add_argument("--gpu", type=int, default=None, help="set only if you wish to run on a particular device")
    parser.add_argument("--n-gpu-per-node", type=int, default=1, help="number of GPUs on each node")
    parser.add_argument("--master-address", type=str, default='localhost', help="address of the master node")
    parser.add_argument("--node-rank", type=int, default=0, help="index of this node")
    parser.add_argument("--num-proc-node", type=int, default=1, help="number of nodes in a multi-node environment")
    # parser.add_argument("--amp", action="store_true")

    # --------------- SB model ---------------
    parser.add_argument("--image-size", type=int, default=256)
    parser.add_argument("--corrupt", type=str, default=None, help="restoration task")
    parser.add_argument("--t0", type=float, default=1e-4, help="sigma start time in network parametrization")
    parser.add_argument("--T", type=float, default=1., help="sigma end time in network parametrization")
    parser.add_argument("--interval", type=int, default=1000, help="number of intervals (diffusion steps)")
    parser.add_argument("--beta-max", type=float, default=0.3, help="max diffusion for the diffusion model")
    # parser.add_argument("--beta-min", type=float, default=0.1)
    parser.add_argument("--ot-ode", action="store_true", help="use OT-ODE model")
    parser.add_argument("--clip-denoise", action="store_true", help="clamp predicted image to [-1,1] at each sampling step")
    # optional configs for conditional network
    parser.add_argument("--cond-x1", action="store_true", help="condition the network on degraded images")
    parser.add_argument("--add-x1-noise", action="store_true", help="add noise to the degraded input x1")

    # --------------- optimizer and loss ---------------
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--microbatch", type=int, default=2, help="accumulate gradients over microbatches until the full batch size is reached")
    parser.add_argument("--num-itr", type=int, default=1000000, help="number of training iterations")
    parser.add_argument("--lr", type=float, default=5e-5, help="learning rate")
    parser.add_argument("--lr-gamma", type=float, default=0.99, help="learning rate decay ratio")
    parser.add_argument("--lr-step", type=int, default=1000, help="learning rate decay step size")
    parser.add_argument("--l2-norm", type=float, default=0.0)
    parser.add_argument("--ema", type=float, default=0.99)

    # --------------- path and logging ---------------
    parser.add_argument("--dataset-dir", type=Path, default="/dataset", help="path to LMDB dataset")
    parser.add_argument("--log-dir", type=Path, default=".log", help="path to log std outputs and writer data")
    parser.add_argument("--log-writer", type=str, default=None, help="log writer: can be tensorboard, wandb, or None")
    parser.add_argument("--wandb-api-key", type=str, default=None, help="unique API key of your W&B account; see https://wandb.ai/authorize")
    parser.add_argument("--wandb-user", type=str, default=None, help="user name of your W&B account")

    opt = parser.parse_args()

    RESULT_DIR = Path("results")
    os.makedirs(opt.log_dir, exist_ok=True)
    opt.ckpt_path = RESULT_DIR / opt.name
    os.makedirs(opt.ckpt_path, exist_ok=True)
    opt.use_fp16 = False

    if opt.ckpt is not None:
        ckpt_file = RESULT_DIR / opt.ckpt / "latest.pt"
        assert ckpt_file.exists()
        opt.load = ckpt_file
    else:
        opt.load = None

    return opt

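# Example invocation (illustrative only; the experiment name and dataset path are placeholders):
#   python runner_l.py --corrupt inpaint --name my-exp --dataset-dir /path/to/lmdb \
#       --batch-size 256 --microbatch 2 --n-gpu-per-node 1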
if __name__ == '__main__':
    opt = create_training_options()
    log = Logger(log_dir=opt.log_dir)

    log.info("=======================================================")
    log.info(" Image-to-Image Schrodinger Bridge")
    log.info("=======================================================")
    log.info("Command used:\n{}".format(" ".join(sys.argv)))
    log.info(f"Experiment ID: {opt.name}")

    diffusion_model = Runner(log, opt)
    wandb_logger = WandbLogger(project=opt.name,
                               log_model=False,
                               group=opt.ckpt)
    bar = TQDMProgressBar(refresh_rate=diffusion_model.backprop_frequency)

    if opt.train:
        checkpoint_callback = ModelCheckpoint(dirpath=opt.ckpt_path,
                                              save_last=True,
                                              save_top_k=-1,
                                              )
        trainer = L.Trainer(accelerator='auto',
                            devices=opt.n_gpu_per_node,
                            num_nodes=opt.num_proc_node,
                            max_steps=opt.num_itr,  # this refers to optimizer steps
                            callbacks=[checkpoint_callback, bar],
                            logger=wandb_logger,
                            strategy='ddp_find_unused_parameters_true',
                            # fast_dev_run=True,
                            )
        trainer.fit(diffusion_model)