Skip to content

Commit 2c59af7

Browse files
authored
Raise warning and round down if Wan num_frames is not 4k + 1 (#11167)
* update * raise warning and round to nearest multiple of scale factor
1 parent 75d7e5c commit 2c59af7

File tree

2 files changed

+14
-0
lines changed

2 files changed

+14
-0
lines changed

src/diffusers/pipelines/wan/pipeline_wan.py

+7
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,13 @@ def __call__(
458458
callback_on_step_end_tensor_inputs,
459459
)
460460

461+
if num_frames % self.vae_scale_factor_temporal != 1:
462+
logger.warning(
463+
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
464+
)
465+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
466+
num_frames = max(num_frames, 1)
467+
461468
self._guidance_scale = guidance_scale
462469
self._attention_kwargs = attention_kwargs
463470
self._current_timestep = None

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

+7
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,13 @@ def __call__(
559559
callback_on_step_end_tensor_inputs,
560560
)
561561

562+
if num_frames % self.vae_scale_factor_temporal != 1:
563+
logger.warning(
564+
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
565+
)
566+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
567+
num_frames = max(num_frames, 1)
568+
562569
self._guidance_scale = guidance_scale
563570
self._attention_kwargs = attention_kwargs
564571
self._current_timestep = None

0 commit comments

Comments
 (0)