Skip to content

Commit b316104

Browse files
authored
Fix Hunyuan I2V for transformers>4.47.1 (huggingface#11293)
* update

* update
1 parent d3b2699 commit b316104

File tree

1 file changed

+64
-8
lines changed

1 file changed

+64
-8
lines changed

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py

+64-8
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,50 @@
100100
}
101101

102102

103+
def _expand_input_ids_with_image_tokens(
104+
text_input_ids,
105+
prompt_attention_mask,
106+
max_sequence_length,
107+
image_token_index,
108+
image_emb_len,
109+
image_emb_start,
110+
image_emb_end,
111+
pad_token_id,
112+
):
113+
special_image_token_mask = text_input_ids == image_token_index
114+
num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
115+
batch_indices, non_image_indices = torch.where(text_input_ids != image_token_index)
116+
117+
max_expanded_length = max_sequence_length + (num_special_image_tokens.max() * (image_emb_len - 1))
118+
new_token_positions = torch.cumsum((special_image_token_mask * (image_emb_len - 1) + 1), -1) - 1
119+
text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
120+
121+
expanded_input_ids = torch.full(
122+
(text_input_ids.shape[0], max_expanded_length),
123+
pad_token_id,
124+
dtype=text_input_ids.dtype,
125+
device=text_input_ids.device,
126+
)
127+
expanded_input_ids[batch_indices, text_to_overwrite] = text_input_ids[batch_indices, non_image_indices]
128+
expanded_input_ids[batch_indices, image_emb_start:image_emb_end] = image_token_index
129+
130+
expanded_attention_mask = torch.zeros(
131+
(text_input_ids.shape[0], max_expanded_length),
132+
dtype=prompt_attention_mask.dtype,
133+
device=prompt_attention_mask.device,
134+
)
135+
attn_batch_indices, attention_indices = torch.where(expanded_input_ids != pad_token_id)
136+
expanded_attention_mask[attn_batch_indices, attention_indices] = 1.0
137+
expanded_attention_mask = expanded_attention_mask.to(prompt_attention_mask.dtype)
138+
position_ids = (expanded_attention_mask.cumsum(-1) - 1).masked_fill_((expanded_attention_mask == 0), 1)
139+
140+
return {
141+
"input_ids": expanded_input_ids,
142+
"attention_mask": expanded_attention_mask,
143+
"position_ids": position_ids,
144+
}
145+
146+
103147
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104148
def retrieve_timesteps(
105149
scheduler,
@@ -251,6 +295,12 @@ def _get_llama_prompt_embeds(
251295
prompt = [prompt_template["template"].format(p) for p in prompt]
252296

253297
crop_start = prompt_template.get("crop_start", None)
298+
299+
image_emb_len = prompt_template.get("image_emb_len", 576)
300+
image_emb_start = prompt_template.get("image_emb_start", 5)
301+
image_emb_end = prompt_template.get("image_emb_end", 581)
302+
double_return_token_id = prompt_template.get("double_return_token_id", 271)
303+
254304
if crop_start is None:
255305
prompt_template_input = self.tokenizer(
256306
prompt_template["template"],
@@ -280,19 +330,25 @@ def _get_llama_prompt_embeds(
280330

281331
image_embeds = self.image_processor(image, return_tensors="pt").pixel_values.to(device)
282332

333+
image_token_index = self.text_encoder.config.image_token_index
334+
pad_token_id = self.text_encoder.config.pad_token_id
335+
expanded_inputs = _expand_input_ids_with_image_tokens(
336+
text_input_ids,
337+
prompt_attention_mask,
338+
max_sequence_length,
339+
image_token_index,
340+
image_emb_len,
341+
image_emb_start,
342+
image_emb_end,
343+
pad_token_id,
344+
)
283345
prompt_embeds = self.text_encoder(
284-
input_ids=text_input_ids,
285-
attention_mask=prompt_attention_mask,
286-
pixel_values=image_embeds,
346+
**expanded_inputs,
347+
pixel_value=image_embeds,
287348
output_hidden_states=True,
288349
).hidden_states[-(num_hidden_layers_to_skip + 1)]
289350
prompt_embeds = prompt_embeds.to(dtype=dtype)
290351

291-
image_emb_len = prompt_template.get("image_emb_len", 576)
292-
image_emb_start = prompt_template.get("image_emb_start", 5)
293-
image_emb_end = prompt_template.get("image_emb_end", 581)
294-
double_return_token_id = prompt_template.get("double_return_token_id", 271)
295-
296352
if crop_start is not None and crop_start > 0:
297353
text_crop_start = crop_start - 1 + image_emb_len
298354
batch_indices, last_double_return_token_indices = torch.where(text_input_ids == double_return_token_id)

0 commit comments

Comments
 (0)