Skip to content

Commit 846fec5

Browse files
committed
[ADD] better names in loss weight
1 parent d3c6532 commit 846fec5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

sgm/modules/diffusionmodules/loss_weighting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def __call__(self, sigma: torch.Tensor) -> torch.Tensor:
3434
class SevaWeighting(DiffusionLossWeighting):
3535
def __call__(self, sigma: torch.Tensor, mask, max_weight=5.0) -> torch.Tensor:
3636
bools = mask.to(torch.bool)
37-
batch_size, num_frames = bools.shape
38-
indices = torch.arange(num_frames, device=bools.device).unsqueeze(0).expand(batch_size, num_frames)
39-
weights = torch.full((batch_size, num_frames), max_weight, dtype=torch.float, device=bools.device)
37+
batch_size, N = bools.shape
38+
indices = torch.arange(N, device=bools.device).unsqueeze(0).expand(batch_size, N)
39+
weights = torch.full((batch_size, N), max_weight, dtype=torch.float, device=bools.device)
4040

4141
for b in range(batch_size):
4242
true_idx = indices[b][bools[b]]

0 commit comments

Comments
 (0)