Skip to content

Commit 31f8c33

Browse files
committed
[ADD] loss weighting
1 parent fe6bcf2 commit 31f8c33

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

configs/example_training/seva-clipl_dl3dv.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ model:
111111
target: sgm.modules.diffusionmodules.loss.StandardDiffusionLoss
112112
params:
113113
loss_weighting_config:
114-
target: sgm.modules.diffusionmodules.loss_weighting.EpsWeighting
114+
target: sgm.modules.diffusionmodules.loss_weighting.SevaWeighting
115115
sigma_sampler_config:
116116
target: sgm.modules.diffusionmodules.sigma_sampling.DiscreteSampling
117117
params:

sgm/modules/diffusionmodules/loss.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def _forward(
8686
model_output = denoiser(
8787
network, noised_input, sigmas, cond, **additional_model_inputs
8888
)
89-
w = append_dims(self.loss_weighting(sigmas), input.ndim)
89+
w = append_dims(self.loss_weighting(sigmas, cond["mask"]), input.ndim)
9090
return self.get_loss(model_output, input, w)
9191

9292
def get_loss(self, model_output, target, w):
@@ -103,3 +103,19 @@ def get_loss(self, model_output, target, w):
103103
return loss
104104
else:
105105
raise NotImplementedError(f"Unknown loss type {self.loss_type}")
106+
107+
def interpolate_weights_batch(bools: torch.Tensor, max_weight=5.0) -> torch.Tensor:
108+
B, N = bools.shape
109+
indices = torch.arange(N, device=bools.device).unsqueeze(0).expand(B, N)
110+
weights = torch.full((B, N), max_weight, dtype=torch.float, device=bools.device)
111+
112+
for b in range(B):
113+
true_idx = indices[b][bools[b]]
114+
if len(true_idx) > 0:
115+
dists = torch.stack([torch.abs(indices[b] - t) for t in true_idx]).min(dim=0).values
116+
dists[bools[b]] = 0
117+
weights[b] = dists / dists.max() * max_weight
118+
else:
119+
weights[b] = max_weight
120+
121+
return weights

sgm/modules/diffusionmodules/loss_weighting.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,22 @@ def __init__(self):
3030
class EpsWeighting(DiffusionLossWeighting):
3131
def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
3232
return sigma**-2.0
33+
34+
class SevaWeighting(DiffusionLossWeighting):
35+
def __call__(self, sigma: torch.Tensor, mask, max_weight=5.0) -> torch.Tensor:
36+
bools = mask.to(torch.bool)
37+
batch_size, num_frames = bools.shape
38+
indices = torch.arange(num_frames, device=bools.device).unsqueeze(0).expand(batch_size, num_frames)
39+
weights = torch.full((batch_size, num_frames), max_weight, dtype=torch.float, device=bools.device)
40+
41+
for b in range(batch_size):
42+
true_idx = indices[b][bools[b]]
43+
if len(true_idx) > 0:
44+
dists = torch.stack([torch.abs(indices[b] - t) for t in true_idx]).min(dim=0).values
45+
dists[bools[b]] = 0
46+
weights[b] = dists / dists.max() * max_weight
47+
else:
48+
weights[b] = max_weight
49+
50+
return weights
51+

0 commit comments

Comments
 (0)