Skip to content

Commit 80af8f5

Browse files
committed
feat(ml):verify multi batch prompt match x
1 parent b0fb4df commit 80af8f5

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

models/modules/img2img_turbo/img2img_turbo.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,6 @@ def __init__(self, in_channels, out_channels, lora_rank_unet, lora_rank_vae):
192192
unet.enable_gradient_checkpointing()
193193

194194
def forward(self, x, prompt):
195-
196195
caption_tokens = self.tokenizer(
197196
prompt,
198197
max_length=self.tokenizer.model_max_length,
@@ -205,7 +204,7 @@ def forward(self, x, prompt):
205204
batch_size = caption_enc.shape[0]
206205
repeated_encs = [
207206
caption_enc[i].repeat(int(x.shape[0] / batch_size), 1, 1)
208-
for i in range(caption_enc.shape[0])
207+
for i in range(batch_size)
209208
]
210209

211210
# Concatenate the repeated encodings along the batch dimension

0 commit comments

Comments
 (0)