diff --git a/3rd/SpargeAttn b/3rd/SpargeAttn
index 0fd094e..2851ae1 160000
--- a/3rd/SpargeAttn
+++ b/3rd/SpargeAttn
@@ -1 +1 @@
-Subproject commit 0fd094e208d5d0b625a381b051c5b44b6ce8bbc5
+Subproject commit 2851ae1f798a7a561462b09e18832ed3e6616bfd
diff --git a/3rd/flash-attention b/3rd/flash-attention
index aa04de6..73b37aa 160000
--- a/3rd/flash-attention
+++ b/3rd/flash-attention
@@ -1 +1 @@
-Subproject commit aa04de66e22fb1810eeede8ba736ccd895f16274
+Subproject commit 73b37aaf6df0024e0ddbfb434badcef8ad0b732f
diff --git a/lightx2v/__main__.py b/lightx2v/__main__.py
index 9a80d44..a0ac25e 100755
--- a/lightx2v/__main__.py
+++ b/lightx2v/__main__.py
@@ -341,6 +341,8 @@ def run_vae(latents, generator, args):
     parser.add_argument("--use_bfloat16", action="store_true", default=True)
     parser.add_argument("--lora_path", type=str, default=None)
     parser.add_argument("--strength_model", type=float, default=1.0)
+    parser.add_argument("--sparge", action="store_true", help="enable sparge attention")
+    parser.add_argument("--sparge_ckpt", type=str, default=None, help="path of sparge ckpts")
     args = parser.parse_args()

     start_time = time.time()
@@ -368,6 +370,8 @@ def run_vae(latents, generator, args):
         "parallel_attn_type": args.parallel_attn_type,
         "parallel_vae": args.parallel_vae,
         "use_bfloat16": args.use_bfloat16,
+        "sparge": args.sparge,
+        "sparge_ckpt": args.sparge_ckpt,
     }

     if args.config_path is not None:
diff --git a/lightx2v/common/ops/__init__.py b/lightx2v/common/ops/__init__.py
index dbc3905..10cd7eb 100755
--- a/lightx2v/common/ops/__init__.py
+++ b/lightx2v/common/ops/__init__.py
@@ -1,3 +1,4 @@
 from .mm import *
 from .norm import *
 from .conv import *
+from .attn import *
diff --git a/lightx2v/common/ops/attn/__init__.py b/lightx2v/common/ops/attn/__init__.py
new file mode 100644
index 0000000..edb5a2b
--- /dev/null
+++ b/lightx2v/common/ops/attn/__init__.py
@@ -0,0 +1 @@
+from .attn_weight import *
diff --git a/lightx2v/common/ops/attn/attn_weight.py b/lightx2v/common/ops/attn/attn_weight.py
new file mode 100644
index 0000000..9ab75a6
--- /dev/null
+++ b/lightx2v/common/ops/attn/attn_weight.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+from abc import ABCMeta, abstractmethod
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v.attentions import attention
+from spas_sage_attn.autotune import SparseAttentionMeansim
+
+
+class AttnWeightTemplate(metaclass=ABCMeta):
+    def __init__(self, weight_name):
+        self.weight_name = weight_name
+        self.config = {}
+
+    def load(self, weight_dict):
+        pass
+
+    @abstractmethod
+    def apply(self, input_tensor):
+        pass
+
+    def set_config(self, config=None):
+        if config is not None:
+            self.config = config
+
+    def to_cpu(self, non_blocking=False):
+        self.weight = self.weight.to("cpu", non_blocking=non_blocking)
+
+    def to_cuda(self, non_blocking=False):
+        self.weight = self.weight.cuda(non_blocking=non_blocking)
+
+
+@ATTN_WEIGHT_REGISTER("Default")
+class DefaultAttnWeightTemplate(AttnWeightTemplate):
+    def __init__(self, attn_type):
+        self.attn_type = attn_type
+        self.config = {}
+
+    def load(self, weight_dict):
+        pass
+
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+        return attention(self.attn_type, q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, model_cls=model_cls)
+
+    def set_config(self, config=None):
+        if config is not None:
+            self.config = config
+
+
+@ATTN_WEIGHT_REGISTER("Sparge")
+class SpargeAttnWeight(AttnWeightTemplate):
+    def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
+        self.verbose = verbose
+        self.l1 = l1
+        self.pv_l1 = pv_l1
+        self.tune_pv = tune_pv
+        self.inner_attn_type = inner_attn_type
+        self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
+        super().__init__(weight_name)
+
+    def load(self, weight_dict):
+        # pick up every checkpoint key that starts with this block's weight_name prefix
+        for key in weight_dict.keys():
+            if key.startswith(self.weight_name):
+                sub_name = key.split(".")[-1]
+                setattr(self.inner_cls, sub_name, nn.Parameter(weight_dict[key], requires_grad=False))
+
+    def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
+        if len(q.shape) == 3:
+            q = q.unsqueeze(0)
+            k = k.unsqueeze(0)
+            v = v.unsqueeze(0)
+
+        x = self.inner_cls(q, k, v, tensor_layout="NHD")
+        x = x.flatten(2)
+        x = x.squeeze(0)
+
+        return x
diff --git a/lightx2v/text2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/text2v/models/networks/wan/infer/transformer_infer.py
index 23ea441..04cd65b 100755
--- a/lightx2v/text2v/models/networks/wan/infer/transformer_infer.py
+++ b/lightx2v/text2v/models/networks/wan/infer/transformer_infer.py
@@ -97,9 +97,7 @@ def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, co
         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)

         if not self.parallel_attention:
-            attn_out = attention(
-                attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
-            )
+            attn_out = weights.self_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])
         else:
             attn_out = self.parallel_attention(
                 attention_type=self.attention_type,
@@ -128,10 +126,9 @@ def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, co

         cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))

-        attn_out = attention(
-            attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
-        )
+        attn_out = weights.cross_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])

+        # TODO: fix i2v
         if self.task == "i2v":
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
diff --git a/lightx2v/text2v/models/networks/wan/weights/transformer_weights.py b/lightx2v/text2v/models/networks/wan/weights/transformer_weights.py
index 1627963..33e7960 100755
--- a/lightx2v/text2v/models/networks/wan/weights/transformer_weights.py
+++ b/lightx2v/text2v/models/networks/wan/weights/transformer_weights.py
@@ -1,4 +1,5 @@
-from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
+from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
+import torch
 from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
 from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
 from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
@@ -13,11 +14,13 @@ def __init__(self, config):
             self.mm_type = "Calib"
         else:
             self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
+        self.weight_list = []

     def load_weights(self, weight_dict):
         self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
         for block in self.blocks_weights:
             block.load_weights(weight_dict)
+            self.weight_list.append(block.weight_list)

     def to_cpu(self):
         for block in self.blocks_weights:
@@ -34,6 +37,7 @@ def __init__(self, block_index, task, mm_type, config):
         self.mm_type = mm_type
         self.task = task
         self.config = config
+        self.weight_list = []

     def load_weights(self, weight_dict):
         self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
@@ -54,6 +58,18 @@ def load_weights(self, weight_dict):
         self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
         self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
         self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]
+
+        # attention kernels
+        # sparge replaces only the self-attention kernel of this block;
+        # cross-attention always stays on the default kernel
+        if self.config["sparge"]:
+            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
+            self.self_attn_1 = ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}")
+            self.cross_attn_1 = ATTN_WEIGHT_REGISTER["Default"](attn_type=self.config["attention_type"])
+        else:
+            # without sparge, both attention paths use the default kernel
+            self.self_attn_1 = ATTN_WEIGHT_REGISTER["Default"](attn_type=self.config["attention_type"])
+            self.cross_attn_1 = ATTN_WEIGHT_REGISTER["Default"](attn_type=self.config["attention_type"])

         self.weight_list = [
             self.self_attn_q,
@@ -86,6 +102,15 @@ def load_weights(self, weight_dict):
                 mm_weight.set_config(self.config["mm_config"])
                 mm_weight.load(weight_dict)

+        # load the tuned sparge parameters for this block's self-attention
+        if self.config["sparge"]:
+            assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
+            sparge_ckpt = torch.load(self.config["sparge_ckpt"])
+            self.self_attn_1.load(sparge_ckpt)
+        else:
+            # the default attention kernel has no weights to load
+            pass
+
     def to_cpu(self):
         for mm_weight in self.weight_list:
             if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
diff --git a/lightx2v/utils/registry_factory.py b/lightx2v/utils/registry_factory.py
index c510554..14c6b83 100644
--- a/lightx2v/utils/registry_factory.py
+++ b/lightx2v/utils/registry_factory.py
@@ -45,6 +45,7 @@ def items(self):


 MM_WEIGHT_REGISTER = Register()
+ATTN_WEIGHT_REGISTER = Register()
 RMS_WEIGHT_REGISTER = Register()
 LN_WEIGHT_REGISTER = Register()
 CONV3D_WEIGHT_REGISTER = Register()
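
Usage sketch: a minimal, illustrative way to exercise the new ATTN_WEIGHT_REGISTER["Sparge"] entry on its own, assuming spas_sage_attn is installed, a CUDA device is available, and a tuned sparge checkpoint exists at the hypothetical path "sparge_ckpt.pt" (the file --sparge_ckpt would point to); the sequence length, head count, and head dim below are illustrative values, not taken from the diff.

import torch
import lightx2v.common.ops  # importing the ops package registers the attention weight classes
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

# One Sparge entry per transformer block, keyed by the same "blocks.{i}" prefix
# that SpargeAttnWeight.load() uses to match tuned parameters in the checkpoint.
self_attn = ATTN_WEIGHT_REGISTER["Sparge"]("blocks.0")
self_attn.load(torch.load("sparge_ckpt.pt"))  # hypothetical path to the tuned sparge checkpoint

# q/k/v in (seq_len, num_heads, head_dim); apply() adds the batch dim,
# runs SparseAttentionMeansim with tensor_layout="NHD", and flattens the heads.
q = torch.randn(1024, 12, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(1024, 12, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(1024, 12, 128, dtype=torch.bfloat16, device="cuda")
out = self_attn.apply(q, k, v)  # -> (seq_len, num_heads * head_dim)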