Feature spargeattn #18


Draft
wants to merge 2 commits into base: main
2 changes: 1 addition & 1 deletion 3rd/flash-attention
4 changes: 4 additions & 0 deletions lightx2v/__main__.py
@@ -341,6 +341,8 @@ def run_vae(latents, generator, args):
parser.add_argument("--use_bfloat16", action="store_true", default=True)
parser.add_argument("--lora_path", type=str, default=None)
parser.add_argument("--strength_model", type=float, default=1.0)
parser.add_argument("--sparge", action="store_true", help="enable sparge attention")
parser.add_argument("--sparge_ckpt", type=str, default=None, help="path of sparge ckpts")
args = parser.parse_args()

start_time = time.time()
@@ -368,6 +370,8 @@ def run_vae(latents, generator, args):
"parallel_attn_type": args.parallel_attn_type,
"parallel_vae": args.parallel_vae,
"use_bfloat16": args.use_bfloat16,
"sparge": args.sparge,
"sparge_ckpt": args.sparge_ckpt,
}

if args.config_path is not None:
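For context, a minimal sketch of how the two new options are wired together; the module invocation in the comment and the checkpoint path are illustrative placeholders, not taken from this PR:

    # hypothetical invocation (other required arguments omitted):
    #   python -m lightx2v --sparge --sparge_ckpt /path/to/sparge_tuned.pt
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--sparge", action="store_true", help="enable sparge attention")
    parser.add_argument("--sparge_ckpt", type=str, default=None, help="path to the sparge attention checkpoint")
    args = parser.parse_args(["--sparge", "--sparge_ckpt", "/path/to/sparge_tuned.pt"])

    # both values are copied into the model config built above; downstream, "sparge" gates which
    # attention weight class each transformer block instantiates, and "sparge_ckpt" supplies the
    # tuned SpargeAttn parameters that SpargeAttnWeight.load() consumes
    if args.sparge:
        assert args.sparge_ckpt, "sparge_ckpt must be set when sparge is True"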
1 change: 1 addition & 0 deletions lightx2v/common/ops/__init__.py
@@ -1,3 +1,4 @@
from .mm import *
from .norm import *
from .conv import *
from .attn import *
1 change: 1 addition & 0 deletions lightx2v/common/ops/attn/__init__.py
@@ -0,0 +1 @@
from .attn_weight import *
77 changes: 77 additions & 0 deletions lightx2v/common/ops/attn/attn_weight.py
@@ -0,0 +1,77 @@
import torch
import torch.nn as nn
from abc import ABCMeta, abstractmethod
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from lightx2v.attentions import attention
from spas_sage_attn.autotune import SparseAttentionMeansim


class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
self.config = {}

def load(self, weight_dict):
pass

@abstractmethod
def apply(self, input_tensor):
pass

def set_config(self, config=None):
if config is not None:
self.config = config

def to_cpu(self, non_blocking=False):
self.weight = self.weight.to("cpu", non_blocking=non_blocking)

def to_cuda(self, non_blocking=False):
self.weight = self.weight.cuda(non_blocking=non_blocking)


@ATTN_WEIGHT_REGISTER("Default")
class DefaultAttnWeightTemplate(AttnWeightTemplate):
def __init__(self, attn_type):
self.attn_type = attn_type
self.config = {}

def load(self, weight_dict):
pass

def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
return attention(self.attn_type, q, k, v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, model_cls=model_cls)

def set_config(self, config=None):
if config is not None:
self.config = config


@ATTN_WEIGHT_REGISTER("Sparge")
class SpargeAttnWeight(AttnWeightTemplate):
def __init__(self, weight_name, verbose=False, l1=0.07, pv_l1=0.08, tune_pv=True, inner_attn_type="flash_attn3"):
        self.verbose = verbose
        self.l1 = l1
        self.pv_l1 = pv_l1
        self.tune_pv = tune_pv
self.inner_attn_type = inner_attn_type
self.inner_cls = SparseAttentionMeansim(l1=l1, pv_l1=pv_l1, tune_pv=tune_pv)
super().__init__(weight_name)

def load(self, weight_dict):
        # match all keys with the weight_name prefix
for key in weight_dict.keys():
if key.startswith(self.weight_name):
sub_name = key.split(".")[-1]
setattr(self.inner_cls, sub_name, nn.Parameter(weight_dict[key], requires_grad=False))

def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None):
if len(q.shape) == 3:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)

x = self.inner_cls(q, k, v, tensor_layout="NHD")
x = x.flatten(2)
x = x.squeeze(0)

return x
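A minimal usage sketch of the two registered wrappers; the shapes, device, dtype, attention type, and checkpoint path are illustrative assumptions, and the real call sites also pass the cu_seqlens/max_seqlen values computed in the model hunks below:

    import torch
    from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

    # illustrative single-sequence inputs with layout (seq_len, heads, head_dim)
    seq_len, heads, head_dim = 1024, 12, 128
    q = k = v = torch.randn(seq_len, heads, head_dim, dtype=torch.bfloat16, device="cuda")
    cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device="cuda")

    # default path: thin wrapper around the existing attention() dispatch
    default_attn = ATTN_WEIGHT_REGISTER["Default"](attn_type="flash_attn3")
    out = default_attn.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens,
                             max_seqlen_q=seq_len, max_seqlen_kv=seq_len)

    # sparge path: wraps SparseAttentionMeansim; tuned parameters are attached via load()
    sparge_attn = ATTN_WEIGHT_REGISTER["Sparge"]("blocks.0")
    sparge_attn.load(torch.load("/path/to/sparge_tuned.pt"))  # placeholder path
    out = sparge_attn.apply(q, k, v)  # a 3-D input is unsqueezed to a batch of 1 internally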
@@ -97,9 +97,7 @@ def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, co
cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=seq_lens)

if not self.parallel_attention:
attn_out = attention(
attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
)
attn_out = weights.self_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])
else:
attn_out = self.parallel_attention(
attention_type=self.attention_type,
@@ -128,10 +126,9 @@ def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, co

cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))

attn_out = attention(
attention_type=self.attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"]
)
attn_out = weights.cross_attn_1.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_k, max_seqlen_q=lq, max_seqlen_kv=lk, model_cls=self.config["model_cls"])

# todo fix i2v
if self.task == "i2v":
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
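Note on the call-site change above: q, k and v arrive as (seq_len, heads, head_dim), and SpargeAttnWeight.apply adds a batch dimension for SparseAttentionMeansim's NHD layout, then flattens the heads back out. A small sketch of that shape contract (the sizes are illustrative, and the assertion reflects what the downstream output projection is assumed to expect):

    import torch

    seq_len, heads, head_dim = 1024, 12, 128
    attn_nhd = torch.randn(1, seq_len, heads, head_dim)  # stand-in for the batched attention output

    # flatten(2) merges heads and head_dim; squeeze(0) drops the batch dim added in apply()
    out = attn_nhd.flatten(2).squeeze(0)
    assert out.shape == (seq_len, heads * head_dim)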
@@ -1,4 +1,5 @@
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, ATTN_WEIGHT_REGISTER
import torch
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
from lightx2v.common.ops.norm.rms_norm_weight import RMSWeightTemplate
@@ -13,11 +14,13 @@ def __init__(self, config):
self.mm_type = "Calib"
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.weight_list = []

def load_weights(self, weight_dict):
self.blocks_weights = [WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)]
for block in self.blocks_weights:
block.load_weights(weight_dict)
self.weight_list.append(block.weight_list)

def to_cpu(self):
for block in self.blocks_weights:
@@ -34,6 +37,7 @@ def __init__(self, block_index, task, mm_type, config):
self.mm_type = mm_type
self.task = task
self.config = config
self.weight_list = []

def load_weights(self, weight_dict):
self.self_attn_q = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.self_attn.q.weight", f"blocks.{self.block_index}.self_attn.q.bias")
@@ -54,6 +58,18 @@ def load_weights(self, weight_dict):
self.ffn_0 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.0.weight", f"blocks.{self.block_index}.ffn.0.bias")
self.ffn_2 = MM_WEIGHT_REGISTER[self.mm_type](f"blocks.{self.block_index}.ffn.2.weight", f"blocks.{self.block_index}.ffn.2.bias")
self.modulation = weight_dict[f"blocks.{self.block_index}.modulation"]

# attention weights
if self.config["sparge"]:
# print("sparge is True, just replace self attn weights")
assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
self.self_attn_1 = ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}")
self.cross_attn_1 = ATTN_WEIGHT_REGISTER["Default"](attn_type=self.config["attention_type"])
else:
# print("sparge is False")
self.self_attn_1 = ATTN_WEIGHT_REGISTER["Default"](attn_type=self.config["attention_type"])
self.cross_attn_1 = ATTN_WEIGHT_REGISTER["Default"](attn_type=self.config["attention_type"])

self.weight_list = [
self.self_attn_q,
@@ -86,6 +102,15 @@ def load_weights(self, weight_dict):
mm_weight.set_config(self.config["mm_config"])
mm_weight.load(weight_dict)

# load attn weights
if self.config["sparge"]:
assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
sparge_ckpt = torch.load(self.config["sparge_ckpt"])
self.self_attn_1.load(sparge_ckpt)
else:
# do not load weights
pass

def to_cpu(self):
for mm_weight in self.weight_list:
if isinstance(mm_weight, (MMWeightTemplate, LNWeightTemplate, RMSWeightTemplate)):
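The sparge checkpoint is loaded from the same file once per block, and SpargeAttnWeight.load() keeps only the keys that start with its blocks.{i} prefix, attaching each value to SparseAttentionMeansim under the last key component. A sketch of that matching logic with a hypothetical key layout (the real parameter names come from SpargeAttn tuning and may differ):

    import torch

    # hypothetical checkpoint contents; the key suffixes are placeholders
    sparge_ckpt = {
        "blocks.0.self_attn.tuned_param": torch.zeros(1),
        "blocks.1.self_attn.tuned_param": torch.zeros(1),
    }

    # what SpargeAttnWeight("blocks.0").load(sparge_ckpt) effectively does:
    prefix = "blocks.0"
    matched = {k.split(".")[-1]: v for k, v in sparge_ckpt.items() if k.startswith(prefix)}
    print(sorted(matched))  # ['tuned_param'] -> set as nn.Parameter attributes on inner_cls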
1 change: 1 addition & 0 deletions lightx2v/utils/registry_factory.py
@@ -45,6 +45,7 @@ def items(self):


MM_WEIGHT_REGISTER = Register()
ATTN_WEIGHT_REGISTER = Register()
RMS_WEIGHT_REGISTER = Register()
LN_WEIGHT_REGISTER = Register()
CONV3D_WEIGHT_REGISTER = Register()
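For reference, a minimal sketch of a registry compatible with how ATTN_WEIGHT_REGISTER is used in this PR (decorator registration by name, item lookup, and an items() view); the actual Register class in lightx2v/utils/registry_factory.py may differ:

    class Register:
        """Maps a string name to a class; usable as a decorator factory."""

        def __init__(self):
            self._registry = {}

        def __call__(self, name):
            def decorator(cls):
                self._registry[name] = cls
                return cls
            return decorator

        def __getitem__(self, name):
            return self._registry[name]

        def items(self):
            return self._registry.items()

    ATTN_WEIGHT_REGISTER = Register()

    # usage mirrors the new code above:
    #   @ATTN_WEIGHT_REGISTER("Sparge")  registers SpargeAttnWeight under "Sparge"
    #   ATTN_WEIGHT_REGISTER["Sparge"]   looks the class up again at build time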