Transformers v5 delivers significantly better performance for MoE models. We should also update the unsloth and unsloth-zoo dependencies.

Important: with Transformers v5, Unsloth saves MoE LoRAs in a different (fused) format that may be incompatible with vLLM, so we need to confirm compatibility. If it does not load, consider using a conversion script like the one below (the shape parameters may need adjusting).
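As a quick check, the two layouts are easy to tell apart from the adapter's tensor names alone. A minimal sketch, assuming the checkpoint path used in the script below:

```python
# Sketch: check whether an adapter stores MoE expert LoRA in the fused layout.
# The path is a placeholder; point it at your own checkpoint.
from safetensors import safe_open

with safe_open("checkpoints/0002/adapter_model.safetensors", framework="pt") as f:
    keys = list(f.keys())

fused = [k for k in keys if ".mlp.experts.base_layer.lora_" in k or ".mlp.experts.lora_" in k]
per_expert = [k for k in keys if ".mlp.experts.0." in k]
print(f"fused expert tensors: {len(fused)}, per-expert tensors: {len(per_expert)}")
```

If `fused` is non-empty and `per_expert` is empty, the adapter is in the fused format and needs conversion.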
"""
Convert Unsloth/PEFT fused MoE LoRA adapter to per-expert format compatible with vLLM.
Unsloth (transformers v5.x) saves MoE expert LoRA as fused 2D tensors:
mlp.experts.base_layer.lora_A [num_experts*rank, intermediate*2] (gate_up_proj)
mlp.experts.base_layer.lora_B [hidden, num_experts*rank] (gate_up_proj)
mlp.experts.lora_A [num_experts*rank, hidden] (down_proj)
mlp.experts.lora_B [intermediate, num_experts*rank] (down_proj)
vLLM expects per-expert keys like:
mlp.experts.0.gate_proj.lora_A [rank, hidden]
mlp.experts.0.gate_proj.lora_B [intermediate, rank]
... for each expert and projection
"""
import json
import os
import re
import shutil
import torch
import safetensors.torch
ADAPTER_PATH = "checkpoints/0002"
OUTPUT_PATH = "checkpoints/0002_vllm"
NUM_EXPERTS = 128
RANK = 8
INTERMEDIATE_SIZE = 768
HIDDEN_SIZE = 2048
def convert():
print("=" * 60)
print("Converting fused MoE LoRA adapter to per-expert format")
print("=" * 60)
# Load original adapter
src = os.path.join(ADAPTER_PATH, "adapter_model.safetensors")
print(f"\nLoading adapter from {src}...")
tensors = safetensors.torch.load_file(src)
new_tensors = {}
converted_expert_layers = 0
for key, tensor in tensors.items():
# Attention layers: keep as-is
if "self_attn" in key:
new_tensors[key] = tensor
continue
# Expert layers: convert fused → per-expert
# Match: base_model.model.model.layers.{N}.mlp.experts.base_layer.lora_{A|B}.weight
m = re.match(
r"(base_model\.model\.model\.layers\.(\d+)\.mlp\.experts)\.(base_layer\.lora_(A|B)|lora_(A|B))\.weight",
key,
)
if not m:
new_tensors[key] = tensor
continue
prefix = m.group(1) # base_model.model.model.layers.N.mlp.experts
is_base_layer = "base_layer" in key # True = gate_up_proj, False = down_proj
is_A = "lora_A" in key
if is_base_layer:
# gate_up_proj (fused gate + up)
if is_A:
# [num_experts*rank, intermediate*2] → per expert [rank, intermediate*2]
# Then split into gate [rank, intermediate] and up [rank, intermediate]
per_expert = tensor.reshape(NUM_EXPERTS, RANK, INTERMEDIATE_SIZE * 2)
for e in range(NUM_EXPERTS):
expert_a = per_expert[e] # [rank, intermediate*2]
gate_a = expert_a[:, :INTERMEDIATE_SIZE] # [rank, intermediate]
up_a = expert_a[:, INTERMEDIATE_SIZE:] # [rank, intermediate]
# In the 3D format, A maps from "output" dim. For nn.Linear, swap A↔B and transpose.
# nn.Linear gate_proj: weight [intermediate, hidden], lora_B [intermediate, rank]
new_tensors[f"{prefix}.{e}.gate_proj.lora_B.weight"] = gate_a.T.contiguous() # [intermediate, rank]
new_tensors[f"{prefix}.{e}.up_proj.lora_B.weight"] = up_a.T.contiguous() # [intermediate, rank]
else:
# [hidden, num_experts*rank] → per expert [hidden, rank]
per_expert = tensor.reshape(HIDDEN_SIZE, NUM_EXPERTS, RANK)
for e in range(NUM_EXPERTS):
expert_b = per_expert[:, e, :] # [hidden, rank]
# This B becomes lora_A in nn.Linear convention (transposed)
# nn.Linear gate_proj: lora_A [rank, hidden]
new_tensors[f"{prefix}.{e}.gate_proj.lora_A.weight"] = expert_b.T.contiguous() # [rank, hidden]
new_tensors[f"{prefix}.{e}.up_proj.lora_A.weight"] = expert_b.T.contiguous() # [rank, hidden]
else:
# down_proj
if is_A:
# [num_experts*rank, hidden] → per expert [rank, hidden]
per_expert = tensor.reshape(NUM_EXPERTS, RANK, HIDDEN_SIZE)
for e in range(NUM_EXPERTS):
expert_a = per_expert[e] # [rank, hidden]
# For nn.Linear down_proj [hidden, intermediate]: lora_B [hidden, rank]
new_tensors[f"{prefix}.{e}.down_proj.lora_B.weight"] = expert_a.T.contiguous() # [hidden, rank]
else:
# [intermediate, num_experts*rank] → per expert [intermediate, rank]
per_expert = tensor.reshape(INTERMEDIATE_SIZE, NUM_EXPERTS, RANK)
for e in range(NUM_EXPERTS):
expert_b = per_expert[:, e, :] # [intermediate, rank]
# For nn.Linear down_proj: lora_A [rank, intermediate]
new_tensors[f"{prefix}.{e}.down_proj.lora_A.weight"] = expert_b.T.contiguous() # [rank, intermediate]
converted_expert_layers += 1
print(f"Converted {converted_expert_layers} fused expert tensors")
print(f"Total output tensors: {len(new_tensors)} (was {len(tensors)})")
# Verify a sample: compute delta_W for original and converted, compare
print("\nVerifying conversion correctness...")
# Original gate_up_proj delta for layer 0, expert 0
orig_a = tensors["base_model.model.model.layers.0.mlp.experts.base_layer.lora_A.weight"]
orig_b = tensors["base_model.model.model.layers.0.mlp.experts.base_layer.lora_B.weight"]
# Per-expert delta: B[hidden, experts*rank] → B_e[hidden, rank], A[experts*rank, inter*2] → A_e[rank, inter*2]
orig_a_e0 = orig_a[:RANK, :] # [rank, inter*2]
orig_b_e0 = orig_b[:, :RANK] # [hidden, rank]
orig_delta = orig_b_e0 @ orig_a_e0 # [hidden, inter*2]
# Converted: gate_proj for expert 0
conv_gate_a = new_tensors["base_model.model.model.layers.0.mlp.experts.0.gate_proj.lora_A.weight"] # [rank, hidden]
conv_gate_b = new_tensors["base_model.model.model.layers.0.mlp.experts.0.gate_proj.lora_B.weight"] # [inter, rank]
conv_up_a = new_tensors["base_model.model.model.layers.0.mlp.experts.0.up_proj.lora_A.weight"] # [rank, hidden]
conv_up_b = new_tensors["base_model.model.model.layers.0.mlp.experts.0.up_proj.lora_B.weight"] # [inter, rank]
# Standard LoRA delta: B @ A = [out, rank] @ [rank, in] = [out, in]
conv_gate_delta = conv_gate_b @ conv_gate_a # [inter, hidden]
conv_up_delta = conv_up_b @ conv_up_a # [inter, hidden]
conv_delta = torch.cat([conv_gate_delta, conv_up_delta], dim=0) # [inter*2, hidden]
# Original delta is [hidden, inter*2] (transposed convention)
# Converted delta is [inter*2, hidden] (standard convention)
diff = (orig_delta.T.float() - conv_delta.float()).abs().max().item()
print(f" gate_up_proj layer 0, expert 0: max diff = {diff:.10f}")
assert diff < 1e-5, f"Verification failed! Max diff: {diff}"
print(" Verification PASSED!")
# Save converted adapter
os.makedirs(OUTPUT_PATH, exist_ok=True)
out_path = os.path.join(OUTPUT_PATH, "adapter_model.safetensors")
print(f"\nSaving converted adapter to {out_path}...")
safetensors.torch.save_file(new_tensors, out_path)
# Update adapter_config.json
with open(os.path.join(ADAPTER_PATH, "adapter_config.json")) as f:
config = json.load(f)
config["target_modules"] = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
config.pop("target_parameters", None)
config["base_model_name_or_path"] = "Qwen/Qwen3-30B-A3B-Instruct-2507"
with open(os.path.join(OUTPUT_PATH, "adapter_config.json"), "w") as f:
json.dump(config, f, indent=2)
print(f"Updated adapter_config.json (removed target_parameters, fixed base_model_name_or_path)")
# Print sample keys
print("\nSample converted keys:")
sample_keys = sorted(new_tensors.keys())[:12]
for k in sample_keys:
print(f" {k}: {new_tensors[k].shape}")
print("\n" + "=" * 60)
print(f"Converted adapter saved to: {OUTPUT_PATH}")
print("=" * 60)
if __name__ == "__main__":
convert()
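Assuming the conversion succeeds, the converted adapter should go through vLLM's standard LoRA path. A rough sketch of how to exercise it (model name, rank, and paths follow the script above; whether vLLM actually applies LoRA to the expert projections of this model is exactly what we need to confirm):

```python
# Sketch: load the converted adapter via vLLM's standard LoRA support.
# Paths and model follow the conversion script; sampling settings are arbitrary.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(
    model="Qwen/Qwen3-30B-A3B-Instruct-2507",
    enable_lora=True,
    max_lora_rank=8,  # must cover the adapter's rank (RANK = 8 above)
)
outputs = llm.generate(
    ["Hello, world."],
    SamplingParams(max_tokens=32),
    lora_request=LoRARequest("converted", 1, "checkpoints/0002_vllm"),
)
print(outputs[0].outputs[0].text)
```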