Skip to content

Commit e7b9a07

Browse files
asomozayiyixuxusayakpaul
authored
[SD3 LoRA] Fix list index out of range (#8584)
* fix * add check * key present is checked before * test case draft * aply suggestions * changed testing repo, back to old class * forgot docstring --------- Co-authored-by: YiYi Xu <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent 8eb1731 commit e7b9a07

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

src/diffusers/loaders/lora.py

+6
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
_get_model_file,
3131
convert_state_dict_to_diffusers,
3232
convert_state_dict_to_peft,
33+
convert_unet_state_dict_to_peft,
3334
delete_adapter_layers,
3435
get_adapter_name,
3536
get_peft_kwargs,
@@ -1543,6 +1544,11 @@ def load_lora_into_transformer(cls, state_dict, transformer, adapter_name=None,
15431544
}
15441545

15451546
if len(state_dict.keys()) > 0:
1547+
# check with first key if is not in peft format
1548+
first_key = next(iter(state_dict.keys()))
1549+
if "lora_A" not in first_key:
1550+
state_dict = convert_unet_state_dict_to_peft(state_dict)
1551+
15461552
if adapter_name in getattr(transformer, "peft_config", {}):
15471553
raise ValueError(
15481554
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."

tests/lora/test_lora_layers_sd3.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
SD3Transformer2DModel,
2828
StableDiffusion3Pipeline,
2929
)
30-
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, torch_device
30+
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
3131

3232

3333
if is_peft_available():
@@ -287,3 +287,24 @@ def test_simple_inference_with_transformer_fuse_unfuse(self):
287287
self.assertTrue(
288288
np.allclose(ouput_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
289289
)
290+
291+
@require_torch_gpu
292+
def test_sd3_lora(self):
293+
"""
294+
Test loading the loras that are saved with the diffusers and peft formats.
295+
Related PR: https://github.com/huggingface/diffusers/pull/8584
296+
"""
297+
components = self.get_dummy_components()
298+
299+
pipe = self.pipeline_class(**components)
300+
pipe = pipe.to(torch_device)
301+
pipe.set_progress_bar_config(disable=None)
302+
303+
lora_model_id = "hf-internal-testing/tiny-sd3-loras"
304+
305+
lora_filename = "lora_diffusers_format.safetensors"
306+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)
307+
pipe.unload_lora_weights()
308+
309+
lora_filename = "lora_peft_format.safetensors"
310+
pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

0 commit comments

Comments
 (0)