
Commit fe6bcf2

Author: Nicolas Violante (committed)
[FIX] correct shapes in wrapper and dataset (needs to be checked); stub out image logging with pass; add mask and plucker to VanillaCFG
1 parent 6273420 commit fe6bcf2

File tree

5 files changed, +33 −18 lines


main.py

Lines changed: 6 additions & 5 deletions
@@ -456,20 +456,23 @@ def check_frequency(self, check_idx):
     @rank_zero_only
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
-            self.log_img(pl_module, batch, batch_idx, split="train")
+            # self.log_img(pl_module, batch, batch_idx, split="train")
+            pass

     @rank_zero_only
     def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         if self.log_before_first_step and pl_module.global_step == 0:
             print(f"{self.__class__.__name__}: logging before training")
-            self.log_img(pl_module, batch, batch_idx, split="train")
+            # self.log_img(pl_module, batch, batch_idx, split="train")
+            pass

     @rank_zero_only
     def on_validation_batch_end(
         self, trainer, pl_module, outputs, batch, batch_idx, *args, **kwargs
     ):
         if not self.disabled and pl_module.global_step > 0:
-            self.log_img(pl_module, batch, batch_idx, split="val")
+            # self.log_img(pl_module, batch, batch_idx, split="val")
+            pass
         if hasattr(pl_module, "calibrate_grad_norm"):
             if (
                 pl_module.calibrate_grad_norm and batch_idx % 25 == 0
@@ -831,8 +834,6 @@ def init_wandb(save_dir, opt, config, group_name, name_str):
         # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
         # calling these ourselves should not be necessary but it is.
         # lightning still takes care of proper multiprocessing though
-        print("DATAAAAAAA", data)
-        print("-"*100)
         data.prepare_data()
         # data.setup()
         print("#### Data #####")

sgm/data/dataset.py

Lines changed: 7 additions & 6 deletions
@@ -182,19 +182,19 @@ def __getitem__(self, idx):

         images_files = [images_files[i] for i in images_idxs]

-        frames = np.zeros((self.num_images, self.image_shape[0], self.image_shape[1], 3))
+        frames = np.zeros((self.num_images, self.target_shape[0], self.target_shape[1], 3))
         for i, img_file in enumerate(images_files):
             img_path = os.path.join(images_dir, img_file)
             image = cv.imread(img_path)
             image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
-            # image = cv.resize(image, self.target_shape, interpolation=cv.INTER_LINEAR) # TODO: Crops?
+            image = cv.resize(image, self.target_shape, interpolation=cv.INTER_LINEAR) # TODO: Crops?
             if self.transform:
                 frames[i] = image

         frames = frames.astype(np.float32) / 255.0
         frames = torch.from_numpy(frames).permute(0, 3, 1, 2) # Convert to (N, C, H, W)
         frames = frames * 2.0 - 1.0 # Normalize to [-1, 1]
-        # TODO: reisze to target shape
+

         # Load colmap data
         colmap_scene_path = os.path.join(
@@ -230,8 +230,8 @@ def __getitem__(self, idx):
             extrinsics_src=w2cs[0],
             extrinsics=w2cs,
             intrinsics=Ks.clone(),
-            target_size=(self.image_shape[0] // self.donwsample_factor,
-                         self.image_shape[1] // self.donwsample_factor),
+            target_size=(self.target_shape[0] // self.donwsample_factor,
+                         self.target_shape[1] // self.donwsample_factor),
         )

         concat = torch.cat(
@@ -265,6 +265,7 @@ def __init__(
         colmap_dir,
         batch_size,
         num_workers=0,
+        num_images=21,
         shuffle=True):
         super().__init__()

@@ -276,7 +277,7 @@ def __init__(
         self.train_dataset = DL3DVDataset(
             dataset_dir,
             colmap_dir,
-            num_images=21,
+            num_images=num_images,
         )

     def prepare_data(self):
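
The __getitem__ change re-enables the cv.resize call and sizes the frames buffer from target_shape. One thing worth double-checking: cv.resize expects its dsize argument as (width, height), while the buffer is allocated as (height, width), so the two only agree when target_shape is square or already stored in (width, height) order. A minimal, self-contained sketch of the same preprocessing pipeline with the shapes spelled out (load_frames and target_hw are illustrative names, not the dataset's API):

import cv2 as cv
import numpy as np
import torch


def load_frames(image_paths, target_hw=(256, 256)):
    """Resize frames, scale to [-1, 1], and return them as (N, C, H, W)."""
    h, w = target_hw
    frames = np.zeros((len(image_paths), h, w, 3), dtype=np.float32)
    for i, path in enumerate(image_paths):
        image = cv.imread(path)                        # BGR, uint8
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)   # to RGB
        # cv.resize takes dsize as (width, height), hence the flipped order
        frames[i] = cv.resize(image, (w, h), interpolation=cv.INTER_LINEAR)
    frames = frames / 255.0                            # [0, 1]
    frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # (N, C, H, W)
    return frames * 2.0 - 1.0                          # [-1, 1]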

sgm/modules/diffusionmodules/denoiser.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def forward(
         c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape))

         if "mask" in cond:
-            mask = cond.pop("mask")[...,None,None,None]
+            mask = cond.pop("mask")[...,None,None,None].to(dtype=input.dtype)
             input = input * (1 - mask) + input * mask

         return (
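
The fix casts the mask to the input's dtype before broadcasting it; besides matching precisions under mixed-precision training, the cast also sidesteps the error PyTorch raises for 1 - mask if the mask ever arrives as a bool tensor. Note that, as written, input * (1 - mask) + input * mask simplifies back to input, so the sketch below blends with a separate hypothetical reference tensor purely to show the intended broadcasting, assuming a latent of shape (B, F, C, H, W) and a per-frame mask of shape (B, F):

import torch

# Shapes are assumptions: a 5D latent (B, F, C, H, W) and a per-frame 0/1 mask (B, F).
B, F, C, H, W = 2, 21, 4, 32, 32
latent = torch.randn(B, F, C, H, W)
reference = torch.randn_like(latent)      # hypothetical tensor to keep where mask == 1
mask = torch.randint(0, 2, (B, F))        # integer mask, as a conditioner might emit

mask = mask[..., None, None, None].to(dtype=latent.dtype)   # (B, F, 1, 1, 1)
blended = latent * (1 - mask) + reference * mask            # keep latent where mask == 0
assert blended.shape == latent.shape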

sgm/modules/diffusionmodules/guiders.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@ def prepare_inputs(self, x, s, c, uc):
         c_out = dict()

         for k in c:
-            if k in ["vector", "crossattn", "concat"]:
+            if k in ["vector", "crossattn", "concat", "mask", "plucker"]:
                 c_out[k] = torch.cat((uc[k], c[k]), 0)
             else:
                 assert c[k] == uc[k]
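
VanillaCFG batches the unconditional and conditional passes together, so every tensor-valued conditioning entry has to be concatenated along the batch dimension; adding "mask" and "plucker" to the concat list keeps them out of the assert branch, which only works for scalar-like entries. A minimal sketch of that doubling pattern (prepare_inputs_sketch is an illustrative stand-in, not the class method itself):

import torch


def prepare_inputs_sketch(x, sigma, c, uc,
                          cat_keys=("vector", "crossattn", "concat", "mask", "plucker")):
    """Stack uncond and cond batches so the denoiser runs both in one forward pass."""
    c_out = {}
    for k in c:
        if k in cat_keys:
            c_out[k] = torch.cat((uc[k], c[k]), 0)   # uncond first, then cond
        else:
            assert c[k] == uc[k]                     # only safe for non-tensor entries
            c_out[k] = c[k]
    return torch.cat([x] * 2), torch.cat([sigma] * 2), c_out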

sgm/modules/diffusionmodules/wrappers.py

Lines changed: 18 additions & 5 deletions
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from packaging import version
+from einops import rearrange, repeat

 OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"

@@ -50,12 +51,24 @@ def forward(
     ) -> torch.Tensor:
         x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=2)

+
+        b = x.shape[0]
+        f = x.shape[1]
+        x = rearrange(x, "b f c h w -> (b f) c h w")
+        dense_y=rearrange(c["plucker"], "b f c h w -> (b f) c h w")
+
         #TODO: remove
-        c["crossattn"] = torch.zeros((x.shape[0], 1, 1024)).type_as(x).to(x.device)
-        return self.diffusion_model(
+        c = torch.zeros((b, 1, 1024)).type_as(x).to(x.device)
+        c = repeat(c, "b 1 c -> (b f) 1 c", f=f)
+        t = repeat(t, "b -> (b f)", f=f)
+
+        out = self.diffusion_model(
             x,
             t=t,
-            y=c["crossattn"],
-            dense_y=c["plucker"],
+            y=c, # c["crossattn"]
+            dense_y=dense_y,
+            num_frames=f,
             **kwargs,
-        )
+        )
+        out = rearrange(out, "(b f) c h w -> b f c h w", f=f)
+        return out
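
The wrapper now folds the frame axis into the batch axis before calling the UNet, flattens the plucker embedding the same way, repeats the timestep per frame, and unflattens the output at the end. A minimal sketch of that flatten/repeat/unflatten pattern, assuming x of shape (B, F, C, H, W), a plucker tensor of shape (B, F, C', H, W), t of shape (B,), and a model taking the same keyword arguments as the diff (y, dense_y, num_frames):

import torch
from einops import rearrange, repeat


def call_per_frame(model, x, plucker, t):
    """Flatten (batch, frame) for a frame-wise model call, then restore the frame axis."""
    b, f = x.shape[:2]
    x = rearrange(x, "b f c h w -> (b f) c h w")
    dense_y = rearrange(plucker, "b f c h w -> (b f) c h w")
    t = repeat(t, "b -> (b f)", f=f)                  # one timestep entry per frame
    y = torch.zeros((b * f, 1, 1024)).type_as(x)      # placeholder cross-attn context
    out = model(x, t=t, y=y, dense_y=dense_y, num_frames=f)
    return rearrange(out, "(b f) c h w -> b f c h w", f=f)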
