@@ -108,31 +108,16 @@ def prompt_clean(text):
     return text
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor,
-    latents_mean: torch.Tensor,
-    latents_std: torch.Tensor,
-    generator: Optional[torch.Generator] = None,
-    sample_mode: str = "sample",
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
-        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
-        encoder_output.latent_dist.logvar = torch.clamp(
-            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
-        )
-        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
-        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
-        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
-        encoder_output.latent_dist.logvar = torch.clamp(
-            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
-        )
-        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
-        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return (encoder_output.latents - latents_mean) * latents_std
+        return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
 
@@ -412,13 +397,15 @@ def prepare_latents(
 
         if isinstance(generator, list):
             latent_condition = [
-                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+                retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
             ]
             latent_condition = torch.cat(latent_condition)
         else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
         latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
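
Net effect of the change: retrieve_latents goes back to the stock diffusers helper (the copy annotated in the hunk above), which only samples or takes the mode of the encoder output, and the affine normalization (x - latents_mean) * latents_std is applied once in prepare_latents after encoding instead of by mutating the latent distribution inside the helper. A minimal, self-contained sketch of that split is below; the dummy encode-output class and the tensor shapes/values are illustrative only, and in the pipeline latents_mean and latents_std are presumably built elsewhere in prepare_latents (not shown in this diff).

import torch
from typing import Optional


def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    # Same logic as the helper added in the diff above.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


class DummyEncodeOutput:
    # Illustrative stand-in for the VAE encode output; exposes only .latents.
    def __init__(self, latents: torch.Tensor):
        self.latents = latents


encoder_output = DummyEncodeOutput(torch.randn(1, 16, 1, 8, 8))
latents_mean = torch.zeros(1, 16, 1, 1, 1)  # illustrative values, not the real config stats
latents_std = torch.ones(1, 16, 1, 1, 1)

latent_condition = retrieve_latents(encoder_output, sample_mode="argmax")
# Normalization now happens once, on the returned latents, rather than inside retrieve_latents.
latent_condition = (latent_condition - latents_mean) * latents_std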