From 92ca5d89ffd8098d5d27f22b5f902ad069306799 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sat, 30 May 2026 05:25:31 +0200 Subject: [PATCH] feat(lora): support PEFT named-adapter LoRAs (e.g. Klein 9B) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PEFT-format LoRAs with a named adapter encode the adapter name in the weight key (e.g. `foo.lora_A.default.weight` instead of `foo.lora_A.weight`). InvokeAI's format detection and conversion only matched the standard PEFT suffix, so these files were classified as Unknown and silently ignored. Normalize named-adapter keys to the standard PEFT form in both `ModelOnDisk.load_state_dict` (detection pipeline) and `LoRALoader._load_model` (conversion pipeline). The pattern is LoRA-specific, so this is a no-op for non-LoRA state dicts. State dicts with multiple distinct adapter names are left untouched to avoid key collisions. Also widen Flux2-Klein diffusers detection to accept keys without a `transformer.` or `base_model.model.` prefix — some trainers (Modelscope / MuseAI Klein 9B finetunes) save at the top level. 
--- .../model_manager/load/model_loaders/lora.py | 5 ++ .../backend/model_manager/model_on_disk.py | 7 ++ .../flux_diffusers_lora_conversion_utils.py | 5 +- .../lora_conversions/peft_adapter_utils.py | 72 +++++++++++++++ .../test_peft_adapter_utils.py | 90 +++++++++++++++++++ 5 files changed, 178 insertions(+), 1 deletion(-) create mode 100644 invokeai/backend/patches/lora_conversions/peft_adapter_utils.py create mode 100644 tests/backend/patches/lora_conversions/test_peft_adapter_utils.py diff --git a/invokeai/backend/model_manager/load/model_loaders/lora.py b/invokeai/backend/model_manager/load/model_loaders/lora.py index 6cf06d48074..15dfa376179 100644 --- a/invokeai/backend/model_manager/load/model_loaders/lora.py +++ b/invokeai/backend/model_manager/load/model_loaders/lora.py @@ -57,6 +57,7 @@ is_state_dict_likely_in_flux_xlabs_format, lora_model_from_flux_xlabs_state_dict, ) +from invokeai.backend.patches.lora_conversions.peft_adapter_utils import normalize_peft_adapter_names from invokeai.backend.patches.lora_conversions.qwen_image_lora_conversion_utils import ( lora_model_from_qwen_image_state_dict, ) @@ -105,6 +106,10 @@ def _load_model( # To revisit later to determine if they're needed/useful. state_dict = {k: v for k, v in state_dict.items() if not k.startswith("bundle_emb")} + # Normalize PEFT named-adapter keys (e.g. `lora_A.default.weight` → `lora_A.weight`) + # so the downstream format detectors and converters see canonical PEFT keys. 
+ state_dict = normalize_peft_adapter_names(state_dict) + # At the time of writing, we support the OMI standard for base models Flux and SDXL if config.format == ModelFormat.OMI and self._model_base in [ BaseModelType.StableDiffusionXL, diff --git a/invokeai/backend/model_manager/model_on_disk.py b/invokeai/backend/model_manager/model_on_disk.py index 284c4998589..acc413b54c0 100644 --- a/invokeai/backend/model_manager/model_on_disk.py +++ b/invokeai/backend/model_manager/model_on_disk.py @@ -118,6 +118,13 @@ def load_state_dict(self, path: Optional[Path] = None) -> StateDict: raise ValueError(f"Unrecognized model extension: {path.suffix}") state_dict = checkpoint.get("state_dict", checkpoint) + + # Normalize PEFT named-adapter keys (e.g. `lora_A.default.weight` → `lora_A.weight`). + # Pattern is LoRA-specific, so this is a no-op for non-LoRA state dicts. + from invokeai.backend.patches.lora_conversions.peft_adapter_utils import normalize_peft_adapter_names + + state_dict = normalize_peft_adapter_names(state_dict) + self._state_dict_cache[path] = state_dict return state_dict diff --git a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py index e691071a397..05fe4cab297 100644 --- a/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py +++ b/invokeai/backend/patches/lora_conversions/flux_diffusers_lora_conversion_utils.py @@ -49,7 +49,10 @@ def is_state_dict_likely_in_flux_diffusers_format(state_dict: dict[str | int, to # --- Flux2 Klein diffusers key patterns (fused QKV+MLP, ff.linear_in) --- # These use Flux2Transformer2DModel naming which differs from Flux.1. - for prefix in ["transformer.", "base_model.model."]: + # An empty prefix is supported because some trainers (e.g. PEFT-style LoRAs from + # Modelscope/MuseAI Klein 9B finetunes) save keys at the top level without any + # `transformer.` or `base_model.model.` wrapper. 
+ for prefix in ["transformer.", "base_model.model.", ""]: has_single = any( k.startswith(f"{prefix}single_transformer_blocks.") and "to_qkv_mlp_proj" in k for k in state_dict ) diff --git a/invokeai/backend/patches/lora_conversions/peft_adapter_utils.py b/invokeai/backend/patches/lora_conversions/peft_adapter_utils.py new file mode 100644 index 00000000000..d680cd0fe2c --- /dev/null +++ b/invokeai/backend/patches/lora_conversions/peft_adapter_utils.py @@ -0,0 +1,72 @@ +"""Utilities for handling PEFT named-adapter LoRA state dicts. + +PEFT (HuggingFace Parameter-Efficient Fine-Tuning) supports multiple named adapters per model. +When saved, the adapter name is encoded in the weight key: + + Standard PEFT: foo.bar.lora_A.weight + Named-adapter PEFT: foo.bar.lora_A.{adapter_name}.weight + +The most common adapter name is "default", produced automatically by `model.add_adapter()` +without an explicit name. Some training tools (e.g. Diffusers' PEFT integration with +multi-adapter support, certain LoRA fine-tuning scripts) save in this format even with a +single adapter. + +InvokeAI's downstream LoRA detection and conversion code expects the standard PEFT suffix +(`lora_A.weight` / `lora_B.weight`). This module normalizes named-adapter state dicts to +that form so the rest of the pipeline can handle them transparently. +""" + +import re +from typing import Any + +# Match a named-adapter PEFT key ending: .lora_A.{adapter_name}.weight or .lora_B.{adapter_name}.weight. +# The adapter name is a single dot-free component (PEFT identifiers do not contain dots). +_NAMED_ADAPTER_RE = re.compile(r"\.lora_([AB])\.([^.]+)\.weight$") + + +def _extract_adapter_names(state_dict: dict[str | int, Any]) -> set[str]: + """Return the set of distinct PEFT adapter names found in the state dict. + + A "named adapter" key is one matching `.lora_A.{adapter_name}.weight` or `.lora_B.{adapter_name}.weight`. + Keys in the standard PEFT form (`.lora_A.weight` / `.lora_B.weight`) do not contribute. 
+ """ + names: set[str] = set() + for key in state_dict: + if not isinstance(key, str): + continue + m = _NAMED_ADAPTER_RE.search(key) + if m: + names.add(m.group(2)) + return names + + +def has_peft_named_adapter_keys(state_dict: dict[str | int, Any]) -> bool: + """Check whether the state dict contains any PEFT named-adapter keys.""" + return bool(_extract_adapter_names(state_dict)) + + +def normalize_peft_adapter_names(state_dict: dict[str | int, Any]) -> dict[str | int, Any]: + """Return a state dict with PEFT named-adapter suffixes stripped to the standard form. + + Transforms: + foo.bar.lora_A.{adapter_name}.weight → foo.bar.lora_A.weight + foo.bar.lora_B.{adapter_name}.weight → foo.bar.lora_B.weight + + Only applied when the state dict contains exactly one distinct adapter name. If the + file holds multiple adapters, the keys are left untouched (renaming would collide and + multi-adapter LoRAs are not currently supported by InvokeAI). + + If no named-adapter keys are present, the input dict is returned unchanged. 
+ """ + adapter_names = _extract_adapter_names(state_dict) + if len(adapter_names) != 1: + return state_dict + + normalized: dict[str | int, Any] = {} + for key, value in state_dict.items(): + if isinstance(key, str): + new_key = _NAMED_ADAPTER_RE.sub(r".lora_\1.weight", key) + normalized[new_key] = value + else: + normalized[key] = value + return normalized diff --git a/tests/backend/patches/lora_conversions/test_peft_adapter_utils.py b/tests/backend/patches/lora_conversions/test_peft_adapter_utils.py new file mode 100644 index 00000000000..fbc6f4e72e2 --- /dev/null +++ b/tests/backend/patches/lora_conversions/test_peft_adapter_utils.py @@ -0,0 +1,90 @@ +import torch + +from invokeai.backend.patches.lora_conversions.peft_adapter_utils import ( + has_peft_named_adapter_keys, + normalize_peft_adapter_names, +) + + +def _t() -> torch.Tensor: + return torch.zeros(1) + + +def test_no_op_when_no_named_adapter_keys(): + """State dicts without named-adapter keys are returned unchanged.""" + sd = { + "transformer_blocks.0.attn.to_q.lora_A.weight": _t(), + "transformer_blocks.0.attn.to_q.lora_B.weight": _t(), + "transformer_blocks.0.attn.to_k.lora_down.weight": _t(), + "single_blocks.0.lokr_w1": _t(), + } + assert not has_peft_named_adapter_keys(sd) + assert normalize_peft_adapter_names(sd) is sd + + +def test_strips_default_adapter_name(): + """The common `default` adapter name gets stripped from lora_A/lora_B keys.""" + sd = { + "transformer_blocks.0.attn.to_q.lora_A.default.weight": _t(), + "transformer_blocks.0.attn.to_q.lora_B.default.weight": _t(), + "transformer_blocks.0.attn.to_k.lora_A.default.weight": _t(), + "transformer_blocks.0.attn.to_k.lora_B.default.weight": _t(), + } + assert has_peft_named_adapter_keys(sd) + + result = normalize_peft_adapter_names(sd) + assert set(result.keys()) == { + "transformer_blocks.0.attn.to_q.lora_A.weight", + "transformer_blocks.0.attn.to_q.lora_B.weight", + "transformer_blocks.0.attn.to_k.lora_A.weight", + 
"transformer_blocks.0.attn.to_k.lora_B.weight", + } + + +def test_strips_custom_adapter_name(): + """Non-default adapter names are also stripped, as long as only one is present.""" + sd = { + "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.my_adapter.weight": _t(), + "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_B.my_adapter.weight": _t(), + } + result = normalize_peft_adapter_names(sd) + assert set(result.keys()) == { + "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_A.weight", + "single_transformer_blocks.0.attn.to_qkv_mlp_proj.lora_B.weight", + } + + +def test_leaves_multi_adapter_state_dict_untouched(): + """If multiple distinct adapter names are present, renaming would collide, so don't.""" + sd = { + "transformer_blocks.0.attn.to_q.lora_A.default.weight": _t(), + "transformer_blocks.0.attn.to_q.lora_A.other.weight": _t(), + } + assert has_peft_named_adapter_keys(sd) + assert normalize_peft_adapter_names(sd) is sd + + +def test_preserves_non_lora_keys_alongside_named_adapter_keys(): + """Keys that aren't lora_A/lora_B PEFT keys pass through unchanged.""" + sd = { + "transformer_blocks.0.attn.to_q.lora_A.default.weight": _t(), + "transformer_blocks.0.attn.to_q.lora_B.default.weight": _t(), + "transformer_blocks.0.attn.to_q.alpha": _t(), + "metadata_like.dora_scale": _t(), + } + result = normalize_peft_adapter_names(sd) + assert "transformer_blocks.0.attn.to_q.lora_A.weight" in result + assert "transformer_blocks.0.attn.to_q.lora_B.weight" in result + assert "transformer_blocks.0.attn.to_q.alpha" in result + assert "metadata_like.dora_scale" in result + + +def test_preserves_integer_keys(): + """Non-string keys (some PyTorch state dicts use ints) are passed through.""" + sd: dict = { + 0: _t(), + "x.lora_A.default.weight": _t(), + } + result = normalize_peft_adapter_names(sd) + assert 0 in result + assert "x.lora_A.weight" in result