Commit 75d7e5c

Fix LatteTransformer3DModel dtype mismatch with enable_temporal_attentions (#11139)
Parent: 617c208

File tree

1 file changed: +1 −1 lines changed


src/diffusers/models/transformers/latte_transformer_3d.py

+1 −1
```diff
@@ -273,7 +273,7 @@ def forward(
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])

         if i == 0 and num_frame > 1:
-            hidden_states = hidden_states + self.temp_pos_embed
+            hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)

         if torch.is_grad_enabled() and self.gradient_checkpointing:
             hidden_states = self._gradient_checkpointing_func(
```
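The one-line change above addresses a standard PyTorch type-promotion pitfall: adding a float32 buffer (here, the temporal positional embedding) to half-precision activations silently upcasts the result to float32, which then mismatches downstream layers that expect fp16. The following is a minimal standalone sketch of that behavior, not the diffusers code itself; the tensor shapes and names are illustrative only.

```python
import torch

# fp16 activations, as when running the pipeline in half precision
hidden_states = torch.randn(2, 16, 8, dtype=torch.float16)
# fp32 positional-embedding buffer, as before the fix
temp_pos_embed = torch.randn(16, 8, dtype=torch.float32)

# Without the fix: type promotion upcasts the sum to float32.
promoted = hidden_states + temp_pos_embed
print(promoted.dtype)  # torch.float32

# With the fix: cast the buffer to the activation dtype first,
# so the result stays in half precision.
fixed = hidden_states + temp_pos_embed.to(hidden_states.dtype)
print(fixed.dtype)  # torch.float16
```

Casting the buffer at the point of use (rather than converting it once at model load) keeps the fix local to the temporal-attention path touched by `enable_temporal_attentions`.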
