8 changes: 5 additions & 3 deletions diffsynth_engine/args.py
@@ -17,8 +17,10 @@ def _parse_tuple(value: str) -> Tuple[int, int] | int:
raise ValueError(f"Cannot parse tuple: {value}, format should be '256,256' or '256'")


def _parse_attention_type(attn_type_str: str) -> AttentionType:
def _parse_attention_type(attn_type_str: str | None) -> AttentionType | None:
"""Convert string to AttentionType enum"""
if attn_type_str is None:
return None
return AttentionType[attn_type_str.upper()]


@@ -106,9 +108,9 @@ def parse_cli_args() -> Dict[str, Any]:
attn_group.add_argument(
"--attn-type",
type=str,
default="sdpa",
default=None,
choices=attn_type_choices,
help="Attention type (default: sdpa)",
help="Attention type (default: auto, SDPA on GPU, MINDIE on NPU)",
)
attn_group.add_argument(
"--sparge-topk",
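A quick sketch of the intended flag behavior, assuming the parsed --attn-type value is later run through _parse_attention_type (the call site is not part of this hunk):

    from diffsynth_engine.args import _parse_attention_type
    from diffsynth_engine.layers.attention.backends.abstract import AttentionType

    # Explicit values still resolve by enum name, including the new backend.
    assert _parse_attention_type("sdpa") is AttentionType.SDPA
    assert _parse_attention_type("mindie") is AttentionType.MINDIE

    # Omitting --attn-type now yields None, which defers the SDPA-vs-MINDIE
    # decision to get_attn_backend at runtime instead of hard-coding "sdpa".
    assert _parse_attention_type(None) is None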
2 changes: 1 addition & 1 deletion diffsynth_engine/configs/base.py
@@ -36,7 +36,7 @@ class PipelineConfig:
vae_tile_stride: int | Tuple[int, int] = (192, 192)

# attention
attn_type: AttentionType = AttentionType.SDPA
attn_type: AttentionType | None = None # None = auto-detect
attn_params: Optional[AttentionParams] = None

# parallelism
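With the config default changed to None, pipelines that do not pin a backend pick one up at runtime. A minimal sketch, assuming PipelineConfig can be instantiated with defaults for the fields not shown in this hunk:

    from diffsynth_engine.configs.base import PipelineConfig
    from diffsynth_engine.layers.attention.backends.abstract import AttentionType

    auto_cfg = PipelineConfig()  # attn_type stays None -> auto-detect later
    assert auto_cfg.attn_type is None

    # Passing an explicit enum member still bypasses auto-detection.
    sdpa_cfg = PipelineConfig(attn_type=AttentionType.SDPA)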
1 change: 1 addition & 0 deletions diffsynth_engine/layers/attention/backends/abstract.py
@@ -22,6 +22,7 @@ class AttentionType(enum.Enum):
SAGE2 = enum.auto()
SAGE3 = enum.auto()
SPARGE = enum.auto()
MINDIE = enum.auto()

def __str__(self) -> str:
return self.name.lower()
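The new member round-trips through the same string conversion the CLI relies on; a small check, assuming AttentionType is importable from the path above:

    from diffsynth_engine.layers.attention.backends.abstract import AttentionType

    assert str(AttentionType.MINDIE) == "mindie"                    # __str__ lowercases the name
    assert AttentionType["mindie".upper()] is AttentionType.MINDIE  # CLI-style lookup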
79 changes: 79 additions & 0 deletions diffsynth_engine/layers/attention/backends/mindie_attn.py
@@ -0,0 +1,79 @@
import torch
from diffsynth_engine.layers.attention.backends.abstract import (
    AttentionBackend,
    AttentionImpl,
    AttentionMetadata,
    AttentionType,
)
from diffsynth_engine.utils.import_utils import is_npu_available


class MindieAttentionBackend(AttentionBackend):
    @staticmethod
    def check_availability() -> None:
        if not is_npu_available():
            raise RuntimeError("NPU is not available, cannot use MINDIE attention backend")

    @staticmethod
    def get_type() -> AttentionType:
        return AttentionType.MINDIE

    @staticmethod
    def get_impl_cls() -> type["AttentionImpl"]:
        return MindieAttentionImpl

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AttentionMetadata

    @staticmethod
    def get_builder_cls() -> type:
        return None

    @staticmethod
    def get_supported_head_sizes() -> list[int]:
        return []


class MindieAttentionImpl(AttentionImpl):
    def __init__(
        self,
        num_heads: int,
        head_size: int,
        softmax_scale: float | None = None,
        causal: bool = False,
        num_kv_heads: int | None = None,
        **extra_impl_args,
    ) -> None:
        if num_kv_heads is None:
            num_kv_heads = num_heads
        self.num_kv_groups = num_heads // num_kv_heads
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.num_heads = num_heads
        self.head_size = head_size

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
        attn_metadata=None,
    ) -> torch.Tensor:
        from mindiesd.layers.flash_attn.attention_forward import attention_forward

        scale = self.softmax_scale
        if scale is None:
            scale = self.head_size ** -0.5

        out = attention_forward(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            scale=scale,
            fused=True,
            head_first=False,
        )
        return out
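A usage sketch for the new backend; this only runs on an Ascend NPU with torch_npu and MindIE-SD installed, and the [batch, seq_len, num_heads, head_size] layout is an assumption inferred from head_first=False:

    import torch
    from diffsynth_engine.layers.attention.backends.mindie_attn import MindieAttentionBackend

    MindieAttentionBackend.check_availability()  # raises RuntimeError off-NPU
    impl = MindieAttentionBackend.get_impl_cls()(num_heads=24, head_size=64)

    q = torch.randn(1, 1024, 24, 64, device="npu", dtype=torch.bfloat16)
    k = torch.randn(1, 1024, 24, 64, device="npu", dtype=torch.bfloat16)
    v = torch.randn(1, 1024, 24, 64, device="npu", dtype=torch.bfloat16)

    out = impl.forward(q, k, v)  # softmax scale defaults to head_size ** -0.5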
9 changes: 6 additions & 3 deletions diffsynth_engine/layers/attention/selector.py
@@ -1,7 +1,7 @@
from functools import cache

from diffsynth_engine.layers.attention.backends.abstract import AttentionBackend, AttentionType
from diffsynth_engine.utils.import_utils import LazyImport
from diffsynth_engine.utils.import_utils import LazyImport, is_npu_available

AiterBackend = LazyImport("diffsynth_engine.layers.attention.backends.aiter", "AiterBackend")
AiterFP8Backend = LazyImport("diffsynth_engine.layers.attention.backends.aiter", "AiterFP8Backend")
@@ -15,6 +15,7 @@
SageAttention3Backend = LazyImport("diffsynth_engine.layers.attention.backends.sage_attn_3", "SageAttention3Backend")
SDPABackend = LazyImport("diffsynth_engine.layers.attention.backends.sdpa", "SDPABackend")
SpargeAttentionBackend = LazyImport("diffsynth_engine.layers.attention.backends.sparge_attn", "SpargeAttentionBackend")
MindieAttentionBackend = LazyImport("diffsynth_engine.layers.attention.backends.mindie_attn", "MindieAttentionBackend")

_attention_backends = {
AttentionType.AITER: AiterBackend,
@@ -27,14 +28,16 @@
AttentionType.SAGE3: SageAttention3Backend,
AttentionType.SDPA: SDPABackend,
AttentionType.SPARGE: SpargeAttentionBackend,
AttentionType.MINDIE: MindieAttentionBackend,
}


@cache
def get_attn_backend(head_size: int, attn_type: AttentionType | None = None) -> type["AttentionBackend"]:
# use SDPA as default
if attn_type is None:
attn_type = AttentionType.SDPA
# Auto-detect: NPU → MINDIE, otherwise → SDPA
attn_type = AttentionType.MINDIE if is_npu_available() else AttentionType.SDPA

selected_backend = _attention_backends[attn_type]
selected_backend.check_availability()
if not selected_backend.supports_head_size(head_size):
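A sketch of the new resolution logic from the caller's side; note that @cache memoizes per (head_size, attn_type) pair, so the NPU probe for the None case runs once per head size:

    from diffsynth_engine.layers.attention.selector import get_attn_backend
    from diffsynth_engine.layers.attention.backends.abstract import AttentionType

    # attn_type=None now auto-detects: MINDIE on an NPU, SDPA everywhere else.
    backend = get_attn_backend(head_size=64)

    # An explicit attn_type still short-circuits the detection.
    sdpa_backend = get_attn_backend(head_size=64, attn_type=AttentionType.SDPA)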
72 changes: 72 additions & 0 deletions diffsynth_engine/layers/mlp.py
@@ -0,0 +1,72 @@
import torch.nn as nn
import torch.nn.functional as F
from diffsynth_engine.utils.import_utils import is_npu_available

try:
    import torch_npu
except ImportError:
    torch_npu = None


class _GELUProj(nn.Module):
    """Wrapper to match diffusers FeedForward GELU structure with internal proj.

    This wrapper holds the first Linear layer as .proj to match checkpoint keys.
    """

    def __init__(self, dim, inner_dim):
        super().__init__()
        self.proj = nn.Linear(dim, inner_dim)

    def forward(self, x):
        return F.gelu(x, approximate="tanh")


class FastGELUMLP(nn.Module):
    """MLP with npu_fast_gelu on NPU, fallback to F.gelu on other devices.

    Functionally equivalent to diffusers.models.attention.FeedForward(
        dim=dim, dim_out=dim, activation_fn="gelu-approximate"
    )
    """

    def __init__(self, dim, dim_out=None, mult=4):
        """Initialize MLP.

        Args:
            dim: Input and output dimension
            dim_out: Output dimension, defaults to dim
            mult: inner_dim = dim * mult, defaults to 4
        """
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = dim_out or dim

        # Match diffusers FeedForward structure: net[0]=GELU(proj), net[2]=output
        # net[1] is Dropout which is skipped in inference
        self.net = nn.ModuleList([
            _GELUProj(dim, inner_dim),
            nn.Dropout(0.0),
            nn.Linear(inner_dim, dim_out),
        ])

    def forward(self, hidden_states):
        """Forward pass.

        Args:
            hidden_states: Input tensor, shape [B, S, dim]

        Returns:
            Output tensor, shape [B, S, dim_out]
        """
        # net[0] = _GELUProj with internal proj (dim → inner_dim)
        hidden_states = self.net[0].proj(hidden_states)

        if is_npu_available() and torch_npu is not None:
            hidden_states = torch_npu.npu_fast_gelu(hidden_states)
        else:
            hidden_states = F.gelu(hidden_states, approximate="tanh")

        # net[2] = output Linear (inner_dim → dim_out)
        hidden_states = self.net[2](hidden_states)
        return hidden_states
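A quick shape and checkpoint-compatibility check for the fallback (non-NPU) path:

    import torch
    from diffsynth_engine.layers.mlp import FastGELUMLP

    mlp = FastGELUMLP(dim=1024)        # inner_dim = 4096, dim_out = 1024
    y = mlp(torch.randn(2, 77, 1024))  # off-NPU this uses F.gelu(..., approximate="tanh")
    assert y.shape == (2, 77, 1024)

    # Parameter names match diffusers' FeedForward, so existing checkpoints
    # load without key remapping.
    assert set(mlp.state_dict()) == {
        "net.0.proj.weight", "net.0.proj.bias", "net.2.weight", "net.2.bias"
    }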
82 changes: 82 additions & 0 deletions diffsynth_engine/layers/norm.py
@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
from diffusers.models.normalization import RMSNorm as DiffusersRMSNorm
from diffsynth_engine.utils.import_utils import is_npu_available

try:
    import torch_npu
except ImportError:
    torch_npu = None

try:
    from mindiesd.layers import layernorm_scale_shift
except ImportError:
    layernorm_scale_shift = None


class RMSNorm(nn.Module):
    """NPU-optimized RMSNorm wrapper with fallback to diffusers implementation."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        # Cache the fallback instance so forward() reuses the same weight
        # tensor. register_parameter is reference assignment (no copy), so
        # self.weight and self._fallback.weight share the same storage.
        # When a checkpoint writes to "weight", both paths see the update.
        fallback = DiffusersRMSNorm(hidden_size, eps)
        self.register_parameter("weight", fallback.weight)
        # Use object.__setattr__ to avoid registering _fallback as an
        # nn.Module submodule, which would add spurious keys to state_dict()
        # and break strict checkpoint loading.
        object.__setattr__(self, "_fallback", fallback)

    def forward(self, hidden_states):
        if is_npu_available() and torch_npu is not None:
            return torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
        else:
            return self._fallback(hidden_states)


class AdaLayerNorm(nn.Module):
    """NPU-optimized AdaLayerNorm with fallback to original implementation.

    Performs: output = layernorm(x) * (1 + scale) + shift

    Args:
        layernorm: The underlying nn.LayerNorm module (elementwise_affine=False)
    """

    def __init__(self, layernorm: nn.LayerNorm):
        super().__init__()
        self.layernorm = layernorm

    def forward(self, hidden_states: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: Input tensor, shape [B, S, H]
            scale: Scale parameter, shape [B, H] or [B, 1, H]
            shift: Shift parameter, shape [B, H] or [B, 1, H]

        Returns:
            layernorm(x) * (1 + scale) + shift
        """
        if is_npu_available() and layernorm_scale_shift is not None:
            # NPU path: use MindIE-SD fused operator
            return layernorm_scale_shift(
                layernorm=self.layernorm,
                x=hidden_states,
                scale=scale,
                shift=shift,
                fused=True
            )
        else:
            # Fallback: original Python implementation
            normed = self.layernorm(hidden_states)
            # Broadcast [B, H] scale/shift up to [B, 1, H] so they align with [B, S, H]
            if scale.dim() == 2:
                scale = scale.unsqueeze(1)
            if shift.dim() == 2:
                shift = shift.unsqueeze(1)
            return normed * (1 + scale) + shift
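A usage sketch for the fallback (non-NPU) paths of both wrappers:

    import torch
    import torch.nn as nn
    from diffsynth_engine.layers.norm import RMSNorm, AdaLayerNorm

    x = torch.randn(2, 77, 3072)

    # RMSNorm exposes only the single "weight" key, so strict checkpoint
    # loading keeps working despite the cached diffusers fallback.
    norm = RMSNorm(3072)
    assert list(norm.state_dict()) == ["weight"]
    y = norm(x)

    # AdaLayerNorm wraps an affine-free LayerNorm; [B, H] scale/shift are
    # broadcast to [B, 1, H] before the scale-and-shift.
    ada = AdaLayerNorm(nn.LayerNorm(3072, elementwise_affine=False))
    out = ada(x, scale=torch.randn(2, 3072), shift=torch.randn(2, 3072))
    assert out.shape == x.shape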