Skip to content

Commit 3b9a6cf

Browse files
Fix issue with 3d masks.
1 parent 3748e7e commit 3b9a6cf

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

comfy/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -853,19 +853,18 @@ def reshape_mask(input_mask, output_shape):
853853
dims = len(output_shape) - 2
854854

855855
if dims == 1:
856-
mask = input_mask
857856
scale_mode = "linear"
858857

859858
if dims == 2:
860-
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
859+
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
861860
scale_mode = "bilinear"
862861

863862
if dims == 3:
864863
if len(input_mask.shape) < 5:
865-
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
864+
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
866865
scale_mode = "trilinear"
867866

868-
mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
867+
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
869868
if mask.shape[1] < output_shape[1]:
870869
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
871870
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])

0 commit comments

Comments
 (0)