Skip to content

Commit 33d10af

Browse files
chengzeyi, yiyixuxu, and hlky
authored
Fix Wan I2V Quality (#11087)
* fix_wan_i2v_quality

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <[email protected]>

* Update src/diffusers/pipelines/wan/pipeline_wan_i2v.py

Co-authored-by: YiYi Xu <[email protected]>

* Update pipeline_wan_i2v.py

---------

Co-authored-by: YiYi Xu <[email protected]>
Co-authored-by: hlky <[email protected]>
1 parent 1001425 commit 33d10af

File tree

1 file changed

+7
-20
lines changed

1 file changed

+7
-20
lines changed

src/diffusers/pipelines/wan/pipeline_wan_i2v.py

+7-20
Original file line numberDiff line numberDiff line change
@@ -108,31 +108,16 @@ def prompt_clean(text):
108108
return text
109109

110110

111+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract a latent tensor from a VAE encoder output.

    Args:
        encoder_output: Result of ``vae.encode(...)``; expected to expose either a
            ``latent_dist`` distribution object or a plain ``latents`` tensor.
        generator: Optional RNG passed to ``latent_dist.sample`` for reproducible draws.
        sample_mode: ``"sample"`` draws stochastically from the distribution;
            ``"argmax"`` returns its deterministic mode.

    Raises:
        AttributeError: If ``encoder_output`` exposes neither ``latent_dist`` nor ``latents``.
    """
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        # Deterministic alternative to sampling: the mode of the latent distribution.
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")
138123

@@ -412,13 +397,15 @@ def prepare_latents(
412397

413398
if isinstance(generator, list):
414399
latent_condition = [
415-
retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
400+
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
416401
]
417402
latent_condition = torch.cat(latent_condition)
418403
else:
419-
latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
404+
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
420405
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
421406

407+
latent_condition = (latent_condition - latents_mean) * latents_std
408+
422409
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
423410
mask_lat_size[:, :, list(range(1, num_frames))] = 0
424411
first_frame_mask = mask_lat_size[:, :, 0:1]

0 commit comments

Comments (0)