Skip to content

Commit d6f4774

Browse files
authored
Add latents_mean and latents_std to SDXLLongPromptWeightingPipeline (#11034)
1 parent: eb50def · commit: d6f4774

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

examples/community/lpw_stable_diffusion_xl.py

+17 -2
Original file line number | Diff line number | Diff line change
@@ -1773,7 +1773,7 @@ def denoising_value_valid(dnv):
17731773
f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
17741774
f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
17751775
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1776-
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1776+
f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
17771777
" `pipeline.unet` or your `mask_image` or `image` input."
17781778
)
17791779
elif num_channels_unet != 4:
@@ -1924,7 +1924,22 @@ def denoising_value_valid(dnv):
19241924
self.upcast_vae()
19251925
latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
19261926

1927-
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1927+
# unscale/denormalize the latents
1928+
# denormalize with the mean and std if available and not None
1929+
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1930+
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1931+
if has_latents_mean and has_latents_std:
1932+
latents_mean = (
1933+
torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1934+
)
1935+
latents_std = (
1936+
torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1937+
)
1938+
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1939+
else:
1940+
latents = latents / self.vae.config.scaling_factor
1941+
1942+
image = self.vae.decode(latents, return_dict=False)[0]
19281943

19291944
# cast back to fp16 if needed
19301945
if needs_upcasting:

0 commit comments

Comments (0)