Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions invokeai/backend/model_manager/load/model_loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
is_state_dict_likely_in_flux_xlabs_format,
lora_model_from_flux_xlabs_state_dict,
)
from invokeai.backend.patches.lora_conversions.peft_adapter_utils import normalize_peft_adapter_names
from invokeai.backend.patches.lora_conversions.qwen_image_lora_conversion_utils import (
lora_model_from_qwen_image_state_dict,
)
Expand Down Expand Up @@ -105,6 +106,10 @@ def _load_model(
# To revisit later to determine if they're needed/useful.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")}

# Normalize PEFT named-adapter keys (e.g. `lora_A.default.weight` → `lora_A.weight`)
# so the downstream format detectors and converters see canonical PEFT keys.
state_dict = normalize_peft_adapter_names(state_dict)

# At the time of writing, we support the OMI standard for base models Flux and SDXL
if config.format == ModelFormat.OMI and self._model_base in [
BaseModelType.StableDiffusionXL,
Expand Down
7 changes: 7 additions & 0 deletions invokeai/backend/model_manager/model_on_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,13 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict:
raise ValueError(f"Unrecognized model extension: {path.suffix}")

state_dict = checkpoint.get("state_dict", checkpoint)

# Normalize PEFT named-adapter keys (e.g. `lora_A.default.weight` → `lora_A.weight`).
# Pattern is LoRA-specific, so this is a no-op for non-LoRA state dicts.
from invokeai.backend.patches.lora_conversions.peft_adapter_utils import normalize_peft_adapter_names

state_dict = normalize_peft_adapter_names(state_dict)

self._state_dict_cache[path] = state_dict
return state_dict

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, to

# --- Flux2 Klein diffusers key patterns (fused QKV+MLP, ff.linear_in) ---
# These use Flux2Transformer2DModel naming which differs from Flux.1.
for prefix in ["transformer.", "base_model.model."]:
# An empty prefix is supported because some trainers (e.g. PEFT-style LoRAs from
# Modelscope/MuseAI Klein 9B finetunes) save keys at the top level without any
# `transformer.` or `base_model.model.` wrapper.
for prefix in ["transformer.", "base_model.model.", ""]:
has_single = any(
k.startswith(f"{prefix}single_transformer_blocks.") and "to_qkv_mlp_proj" in k for k in state_dict
)
Expand Down
72 changes: 72 additions & 0 deletions invokeai/backend/patches/lora_conversions/peft_adapter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Utilities for handling PEFT named-adapter LoRA state dicts.

PEFT (HuggingFace Parameter-Efficient Fine-Tuning) supports multiple named adapters per model.
When saved, the adapter name is encoded in the weight key:

Standard PEFT: foo.bar.lora_A.weight
Named-adapter PEFT: foo.bar.lora_A.<adapter_name>.weight

The most common adapter name is "default", produced automatically by `model.add_adapter()`
without an explicit name. Some training tools (e.g. Diffusers' PEFT integration with
multi-adapter support, certain LoRA fine-tuning scripts) save in this format even with a
single adapter.

InvokeAI's downstream LoRA detection and conversion code expects the standard PEFT suffix
(`lora_A.weight` / `lora_B.weight`). This module normalizes named-adapter state dicts to
that form so the rest of the pipeline can handle them transparently.
"""

import re
from typing import Any

# Match a named-adapter PEFT key ending: .lora_A.<adapter_name>.weight or .lora_B.<adapter_name>.weight.
# The adapter name is a single dot-free component (PEFT identifiers do not contain dots).
_NAMED_ADAPTER_RE = re.compile(r"\.lora_([AB])\.([^.]+)\.weight$")


def _extract_adapter_names(state_dict: dict[str | int, Any]) -> set[str]:
"""Return the set of distinct PEFT adapter names found in the state dict.

A "named adapter" key is one matching `.lora_A.<name>.weight` or `.lora_B.<name>.weight`.
Keys in the standard PEFT form (`.lora_A.weight` / `.lora_B.weight`) do not contribute.
"""
names: set[str] = set()
for key in state_dict:
if not isinstance(key, str):
continue
m = _NAMED_ADAPTER_RE.search(key)
if m:
names.add(m.group(2))
return names


def has_peft_named_adapter_keys(state_dict: dict[str | int, Any]) -> bool:
"""Check whether the state dict contains any PEFT named-adapter keys."""
return bool(_extract_adapter_names(state_dict))


def normalize_peft_adapter_names(state_dict: dict[str | int, Any]) -> dict[str | int, Any]:
"""Return a state dict with PEFT named-adapter suffixes stripped to the standard form.

Transforms:
foo.bar.lora_A.<adapter_name>.weight → foo.bar.lora_A.weight
foo.bar.lora_B.<adapter_name>.weight → foo.bar.lora_B.weight

Only applied when the state dict contains exactly one distinct adapter name. If the
file holds multiple adapters, the keys are left untouched (renaming would collide and
multi-adapter LoRAs are not currently supported by InvokeAI).

If no named-adapter keys are present, the input dict is returned unchanged.
"""
adapter_names = _extract_adapter_names(state_dict)
if len(adapter_names) != 1:
return state_dict

normalized: dict[str | int, Any] = {}
for key, value in state_dict.items():
if isinstance(key, str):
new_key = _NAMED_ADAPTER_RE.sub(r".lora_\1.weight", key)
normalized[new_key] = value
else:
normalized[key] = value
return normalized
90 changes: 90 additions & 0 deletions tests/backend/patches/lora_conversions/test_peft_adapter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import torch

from invokeai.backend.patches.lora_conversions.peft_adapter_utils import (
has_peft_named_adapter_keys,
normalize_peft_adapter_names,
)


def _t() -> torch.Tensor:
return torch.zeros(1)


def test_no_op_when_no_named_adapter_keys():
"""State dicts without named-adapter keys are returned unchanged."""
sd = {
"transformer_blocks.0.attn.to_q.lora_A.weight": _t(),
"transformer_blocks.0.attn.to_q.lora_B.weight": _t(),
"transformer_blocks.0.attn.to_k.lora_down.weight": _t(),
"single_blocks.0.lokr_w1": _t(),
}
assert not has_peft_named_adapter_keys(sd)
assert normalize_peft_adapter_names(sd) is sd


def test_strips_default_adapter_name():
"""The common `default` adapter name gets stripped from lora_A/lora_B keys."""
sd = {
"transformer_blocks.0.attn.to_q.lora_A.default.weight": _t(),
"transformer_blocks.0.attn.to_q.lora_B.default.weight": _t(),
"transformer_blocks.0.attn.to_k.lora_A.default.weight": _t(),
"transformer_blocks.0.attn.to_k.lora_B.default.weight": _t(),
}
assert has_peft_named_adapter_keys(sd)

result = normalize_peft_adapter_names(sd)
assert set(result.keys()) == {
"transformer_blocks.0.attn.to_q.lora_A.weight",
"transformer_blocks.0.attn.to_q.lora_B.weight",
"transformer_blocks.0.attn.to_k.lora_A.weight",
"transformer_blocks.0.attn.to_k.lora_B.weight",
}


def test_strips_custom_adapter_name():
"""Non-default adapter names are also stripped, as long as only one is present."""
sd = {
"single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.my_adapter.weight": _t(),
"single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_B.my_adapter.weight": _t(),
}
result = normalize_peft_adapter_names(sd)
assert set(result.keys()) == {
"single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.weight",
"single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_B.weight",
}


def test_leaves_multi_adapter_state_dict_untouched():
"""If multiple distinct adapter names are present, renaming would collide, so don't."""
sd = {
"transformer_blocks.0.attn.to_q.lora_A.default.weight": _t(),
"transformer_blocks.0.attn.to_q.lora_A.other.weight": _t(),
}
assert has_peft_named_adapter_keys(sd)
assert normalize_peft_adapter_names(sd) is sd


def test_preserves_non_lora_keys_alongside_named_adapter_keys():
"""Keys that aren't lora_A/lora_B PEFT keys pass through unchanged."""
sd = {
"transformer_blocks.0.attn.to_q.lora_A.default.weight": _t(),
"transformer_blocks.0.attn.to_q.lora_B.default.weight": _t(),
"transformer_blocks.0.attn.to_q.alpha": _t(),
"metadata_like.dora_scale": _t(),
}
result = normalize_peft_adapter_names(sd)
assert "transformer_blocks.0.attn.to_q.lora_A.weight" in result
assert "transformer_blocks.0.attn.to_q.lora_B.weight" in result
assert "transformer_blocks.0.attn.to_q.alpha" in result
assert "metadata_like.dora_scale" in result


def test_preserves_integer_keys():
"""Non-string keys (some PyTorch state dicts use ints) are passed through."""
sd: dict = {
0: _t(),
"x.lora_A.default.weight": _t(),
}
result = normalize_peft_adapter_names(sd)
assert 0 in result
assert "x.lora_A.weight" in result
Loading