From ee69f31c66dfb94ecca041e731fa0fb040d3c1e2 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 16 Jan 2026 09:58:30 +0000 Subject: [PATCH 1/5] ring_attn_fp8_kv_fusion --- lightx2v/common/ops/attn/ring_attn.py | 105 +++++++++++++++++++++++--- 1 file changed, 93 insertions(+), 12 deletions(-) diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py index e9dcaa44..24fc394e 100644 --- a/lightx2v/common/ops/attn/ring_attn.py +++ b/lightx2v/common/ops/attn/ring_attn.py @@ -4,6 +4,7 @@ from loguru import logger from lightx2v.utils.envs import * +from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from .template import AttnWeightTemplate @@ -41,7 +42,21 @@ class RingAttnWeight(AttnWeightTemplate): def __init__(self): self.config = {} - def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, enable_head_parallel=False, **kwargs): + def apply( + self, + q, + k, + v, + slice_qkv_len, + cu_seqlens_qkv, + attention_module=None, + attention_type="flash_attn2", + seq_p_group=None, + use_fp8_comm=False, + use_kv_fusion=False, + enable_head_parallel=False, + **kwargs, + ): """ 执行 Ring 注意力机制,结合图像和文本的查询、键和值。 @@ -56,13 +71,13 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att 返回: torch.Tensor: 计算得到的注意力结果 """ - assert not use_fp8_comm, "RingAttn can't support fp8 comm now." assert not enable_head_parallel, "RingAttn can't support head parallel mode." # 获取当前进程的排名和全局进程数 cur_rank = dist.get_rank(seq_p_group) world_size = dist.get_world_size(seq_p_group) + img_qkv_len = slice_qkv_len if len(cu_seqlens_qkv) == 3: txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度 @@ -85,6 +100,7 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att k = k.unsqueeze(0) v = v.unsqueeze(0) + heads, hidden_dims = k.shape[-2], k.shape[-1] img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous() txt_q, txt_k, txt_v = ( q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), @@ -92,31 +108,78 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), ) - out, lse, next_k, next_v = None, None, None, None + out, lse, next_k, next_v, next_kv = None, None, None, None, None if len(cu_seqlens_qkv) == 3: q = torch.cat((img_q, txt_q), dim=1) k = img_k v = img_v + if use_kv_fusion: + kv = torch.stack([img_k, img_v], dim=0).reshape(2, img_qkv_len, heads, hidden_dims).contiguous() + txt_kv = torch.stack([txt_k, txt_v], dim=0).reshape(2, txt_qkv_len, heads, hidden_dims).contiguous() + original_dtype = kv.dtype + original_shape = kv.shape + else: + original_dtype = k.dtype + original_shape = k.shape + for step in range(world_size): if step + 1 != world_size: - next_k = RING_COMM.send_recv(k) - next_v = RING_COMM.send_recv(v) - RING_COMM.commit() + if use_fp8_comm: + if use_kv_fusion: + kv_fp8, kv_scale = quant_fp8_vllm(kv.reshape(-1, hidden_dims)) + kv_fp8 = kv_fp8.reshape(original_shape) + kv_scale = kv_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + next_kv_fp8 = RING_COMM.send_recv(kv_fp8) + next_kv_scale = RING_COMM.send_recv(kv_scale) + else: + k_fp8, k_scale = quant_fp8_vllm(k.reshape(-1, hidden_dims)) + v_fp8, 
v_scale = quant_fp8_vllm(v.reshape(-1, hidden_dims)) + k_fp8 = k_fp8.reshape(original_shape) + v_fp8 = v_fp8.reshape(original_shape) + k_scale = k_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + v_scale = v_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + next_k_fp8 = RING_COMM.send_recv(k_fp8) + next_k_scale = RING_COMM.send_recv(k_scale) + next_v_fp8 = RING_COMM.send_recv(v_fp8) + next_v_scale = RING_COMM.send_recv(v_scale) + RING_COMM.commit() + else: + if use_kv_fusion: + next_kv = RING_COMM.send_recv(kv) + else: + next_k = RING_COMM.send_recv(k) + next_v = RING_COMM.send_recv(v) + RING_COMM.commit() if step + 1 == world_size: - k = torch.cat((k, txt_k), dim=1) - v = torch.cat((v, txt_v), dim=1) - - block_out, block_lse = self.ring_attn_sub(q, k, v) + if use_kv_fusion: + next_kv = torch.cat((kv, txt_kv), dim=1) + else: + k = torch.cat((k, txt_k), dim=1) + v = torch.cat((v, txt_v), dim=1) + if use_kv_fusion: + block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv) + else: + block_out, block_lse = self.ring_attn_sub(q, k, v) out, lse = self.update_out_and_lse(out, lse, block_out, block_lse) if step + 1 != world_size: RING_COMM.wait() - k = next_k - v = next_v + if use_fp8_comm: + if use_kv_fusion: + kv = dequant_fp8_vllm(next_kv_fp8, next_kv_scale, original_dtype) + else: + k = dequant_fp8_vllm(next_k_fp8, next_k_scale, original_dtype) + v = dequant_fp8_vllm(next_v_fp8, next_v_scale, original_dtype) + else: + if use_kv_fusion: + kv = next_kv + else: + k = next_k + v = next_v attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1) @@ -140,6 +203,24 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att return attn1 + def ring_attn_sub_kv_fusion(self, q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward( + q, + kv[:1, :, :, :], + kv[1:, :, :, :], + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + return block_out, block_lse + def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) From 37febdb82f6f930e4a7ff825fb78e041020acee5 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 16 Jan 2026 10:38:31 +0000 Subject: [PATCH 2/5] main-sync --- lightx2v/common/ops/attn/ring_attn.py | 105 +++----------------------- 1 file changed, 12 insertions(+), 93 deletions(-) diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py index 24fc394e..e9dcaa44 100644 --- a/lightx2v/common/ops/attn/ring_attn.py +++ b/lightx2v/common/ops/attn/ring_attn.py @@ -4,7 +4,6 @@ from loguru import logger from lightx2v.utils.envs import * -from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from .template import AttnWeightTemplate @@ -42,21 +41,7 @@ class RingAttnWeight(AttnWeightTemplate): def __init__(self): self.config = {} - def apply( - self, - q, - k, - v, - slice_qkv_len, - cu_seqlens_qkv, - attention_module=None, - 
attention_type="flash_attn2", - seq_p_group=None, - use_fp8_comm=False, - use_kv_fusion=False, - enable_head_parallel=False, - **kwargs, - ): + def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, enable_head_parallel=False, **kwargs): """ 执行 Ring 注意力机制,结合图像和文本的查询、键和值。 @@ -71,13 +56,13 @@ def apply( 返回: torch.Tensor: 计算得到的注意力结果 """ + assert not use_fp8_comm, "RingAttn can't support fp8 comm now." assert not enable_head_parallel, "RingAttn can't support head parallel mode." # 获取当前进程的排名和全局进程数 cur_rank = dist.get_rank(seq_p_group) world_size = dist.get_world_size(seq_p_group) - img_qkv_len = slice_qkv_len if len(cu_seqlens_qkv) == 3: txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度 @@ -100,7 +85,6 @@ def apply( k = k.unsqueeze(0) v = v.unsqueeze(0) - heads, hidden_dims = k.shape[-2], k.shape[-1] img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous() txt_q, txt_k, txt_v = ( q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), @@ -108,78 +92,31 @@ def apply( v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), ) - out, lse, next_k, next_v, next_kv = None, None, None, None, None + out, lse, next_k, next_v = None, None, None, None if len(cu_seqlens_qkv) == 3: q = torch.cat((img_q, txt_q), dim=1) k = img_k v = img_v - if use_kv_fusion: - kv = torch.stack([img_k, img_v], dim=0).reshape(2, img_qkv_len, heads, hidden_dims).contiguous() - txt_kv = torch.stack([txt_k, txt_v], dim=0).reshape(2, txt_qkv_len, heads, hidden_dims).contiguous() - original_dtype = kv.dtype - original_shape = kv.shape - else: - original_dtype = k.dtype - original_shape = k.shape - for step in range(world_size): if step + 1 != world_size: - if use_fp8_comm: - if use_kv_fusion: - kv_fp8, kv_scale = quant_fp8_vllm(kv.reshape(-1, hidden_dims)) - kv_fp8 = kv_fp8.reshape(original_shape) - kv_scale = kv_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) - next_kv_fp8 = RING_COMM.send_recv(kv_fp8) - next_kv_scale = RING_COMM.send_recv(kv_scale) - else: - k_fp8, k_scale = quant_fp8_vllm(k.reshape(-1, hidden_dims)) - v_fp8, v_scale = quant_fp8_vllm(v.reshape(-1, hidden_dims)) - k_fp8 = k_fp8.reshape(original_shape) - v_fp8 = v_fp8.reshape(original_shape) - k_scale = k_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) - v_scale = v_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) - next_k_fp8 = RING_COMM.send_recv(k_fp8) - next_k_scale = RING_COMM.send_recv(k_scale) - next_v_fp8 = RING_COMM.send_recv(v_fp8) - next_v_scale = RING_COMM.send_recv(v_scale) - RING_COMM.commit() - else: - if use_kv_fusion: - next_kv = RING_COMM.send_recv(kv) - else: - next_k = RING_COMM.send_recv(k) - next_v = RING_COMM.send_recv(v) - RING_COMM.commit() + next_k = RING_COMM.send_recv(k) + next_v = RING_COMM.send_recv(v) + RING_COMM.commit() if step + 1 == world_size: - if use_kv_fusion: - next_kv = torch.cat((kv, txt_kv), dim=1) - else: - k = torch.cat((k, txt_k), dim=1) - v = torch.cat((v, txt_v), dim=1) + k = torch.cat((k, txt_k), dim=1) + v = torch.cat((v, txt_v), dim=1) + + block_out, block_lse = self.ring_attn_sub(q, k, v) - if use_kv_fusion: - block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv) - else: - block_out, block_lse = self.ring_attn_sub(q, k, v) out, lse = self.update_out_and_lse(out, lse, block_out, 
block_lse) if step + 1 != world_size: RING_COMM.wait() - if use_fp8_comm: - if use_kv_fusion: - kv = dequant_fp8_vllm(next_kv_fp8, next_kv_scale, original_dtype) - else: - k = dequant_fp8_vllm(next_k_fp8, next_k_scale, original_dtype) - v = dequant_fp8_vllm(next_v_fp8, next_v_scale, original_dtype) - else: - if use_kv_fusion: - kv = next_kv - else: - k = next_k - v = next_v + k = next_k + v = next_v attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1) @@ -203,24 +140,6 @@ def apply( return attn1 - def ring_attn_sub_kv_fusion(self, q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) - block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward( - q, - kv[:1, :, :, :], - kv[1:, :, :, :], - dropout_p=dropout_p, - softmax_scale=softmax_scale, - causal=causal, - window_size_left=window_size[0], - window_size_right=window_size[1], - softcap=softcap, - alibi_slopes=alibi_slopes, - return_softmax=return_softmax, - ) - return block_out, block_lse - def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) From 31fcd262766622877fcdc3c9a4b015c8f92fbc43 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 16 Jan 2026 10:46:53 +0000 Subject: [PATCH 3/5] ring_attn_fp8_kv_fusion --- lightx2v/common/ops/attn/ring_attn.py | 105 +++++++++++++++++++++++--- 1 file changed, 93 insertions(+), 12 deletions(-) diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py index e9dcaa44..24fc394e 100644 --- a/lightx2v/common/ops/attn/ring_attn.py +++ b/lightx2v/common/ops/attn/ring_attn.py @@ -4,6 +4,7 @@ from loguru import logger from lightx2v.utils.envs import * +from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER from .template import AttnWeightTemplate @@ -41,7 +42,21 @@ class RingAttnWeight(AttnWeightTemplate): def __init__(self): self.config = {} - def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, attention_type="flash_attn2", seq_p_group=None, use_fp8_comm=False, enable_head_parallel=False, **kwargs): + def apply( + self, + q, + k, + v, + slice_qkv_len, + cu_seqlens_qkv, + attention_module=None, + attention_type="flash_attn2", + seq_p_group=None, + use_fp8_comm=False, + use_kv_fusion=False, + enable_head_parallel=False, + **kwargs, + ): """ 执行 Ring 注意力机制,结合图像和文本的查询、键和值。 @@ -56,13 +71,13 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att 返回: torch.Tensor: 计算得到的注意力结果 """ - assert not use_fp8_comm, "RingAttn can't support fp8 comm now." assert not enable_head_parallel, "RingAttn can't support head parallel mode." 
# 获取当前进程的排名和全局进程数 cur_rank = dist.get_rank(seq_p_group) world_size = dist.get_world_size(seq_p_group) + img_qkv_len = slice_qkv_len if len(cu_seqlens_qkv) == 3: txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度 @@ -85,6 +100,7 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att k = k.unsqueeze(0) v = v.unsqueeze(0) + heads, hidden_dims = k.shape[-2], k.shape[-1] img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous() txt_q, txt_k, txt_v = ( q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), @@ -92,31 +108,78 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), ) - out, lse, next_k, next_v = None, None, None, None + out, lse, next_k, next_v, next_kv = None, None, None, None, None if len(cu_seqlens_qkv) == 3: q = torch.cat((img_q, txt_q), dim=1) k = img_k v = img_v + if use_kv_fusion: + kv = torch.stack([img_k, img_v], dim=0).reshape(2, img_qkv_len, heads, hidden_dims).contiguous() + txt_kv = torch.stack([txt_k, txt_v], dim=0).reshape(2, txt_qkv_len, heads, hidden_dims).contiguous() + original_dtype = kv.dtype + original_shape = kv.shape + else: + original_dtype = k.dtype + original_shape = k.shape + for step in range(world_size): if step + 1 != world_size: - next_k = RING_COMM.send_recv(k) - next_v = RING_COMM.send_recv(v) - RING_COMM.commit() + if use_fp8_comm: + if use_kv_fusion: + kv_fp8, kv_scale = quant_fp8_vllm(kv.reshape(-1, hidden_dims)) + kv_fp8 = kv_fp8.reshape(original_shape) + kv_scale = kv_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + next_kv_fp8 = RING_COMM.send_recv(kv_fp8) + next_kv_scale = RING_COMM.send_recv(kv_scale) + else: + k_fp8, k_scale = quant_fp8_vllm(k.reshape(-1, hidden_dims)) + v_fp8, v_scale = quant_fp8_vllm(v.reshape(-1, hidden_dims)) + k_fp8 = k_fp8.reshape(original_shape) + v_fp8 = v_fp8.reshape(original_shape) + k_scale = k_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + v_scale = v_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + next_k_fp8 = RING_COMM.send_recv(k_fp8) + next_k_scale = RING_COMM.send_recv(k_scale) + next_v_fp8 = RING_COMM.send_recv(v_fp8) + next_v_scale = RING_COMM.send_recv(v_scale) + RING_COMM.commit() + else: + if use_kv_fusion: + next_kv = RING_COMM.send_recv(kv) + else: + next_k = RING_COMM.send_recv(k) + next_v = RING_COMM.send_recv(v) + RING_COMM.commit() if step + 1 == world_size: - k = torch.cat((k, txt_k), dim=1) - v = torch.cat((v, txt_v), dim=1) - - block_out, block_lse = self.ring_attn_sub(q, k, v) + if use_kv_fusion: + next_kv = torch.cat((kv, txt_kv), dim=1) + else: + k = torch.cat((k, txt_k), dim=1) + v = torch.cat((v, txt_v), dim=1) + if use_kv_fusion: + block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv) + else: + block_out, block_lse = self.ring_attn_sub(q, k, v) out, lse = self.update_out_and_lse(out, lse, block_out, block_lse) if step + 1 != world_size: RING_COMM.wait() - k = next_k - v = next_v + if use_fp8_comm: + if use_kv_fusion: + kv = dequant_fp8_vllm(next_kv_fp8, next_kv_scale, original_dtype) + else: + k = dequant_fp8_vllm(next_k_fp8, next_k_scale, original_dtype) + v = dequant_fp8_vllm(next_v_fp8, next_v_scale, original_dtype) + else: + if use_kv_fusion: + kv = next_kv + else: + k = next_k + v = next_v attn1 = 
out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1) @@ -140,6 +203,24 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, att return attn1 + def ring_attn_sub_kv_fusion(self, q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward( + q, + kv[:1, :, :, :], + kv[1:, :, :, :], + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size_left=window_size[0], + window_size_right=window_size[1], + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax, + ) + return block_out, block_lse + def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) From 5bf8f6dd6f4e18038daf7ec746d1177886784506 Mon Sep 17 00:00:00 2001 From: root Date: Fri, 16 Jan 2026 12:50:12 +0000 Subject: [PATCH 4/5] ring_attn_fp8_kv_fusion_gemini_code_assist --- lightx2v/common/ops/attn/ring_attn.py | 187 +++++++++++++++++++++----- 1 file changed, 150 insertions(+), 37 deletions(-) diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py index 24fc394e..5f5256ce 100644 --- a/lightx2v/common/ops/attn/ring_attn.py +++ b/lightx2v/common/ops/attn/ring_attn.py @@ -41,6 +41,7 @@ def _update_out_and_lse( class RingAttnWeight(AttnWeightTemplate): def __init__(self): self.config = {} + self.helper = RingAttnHelper() def apply( self, @@ -64,7 +65,7 @@ def apply( q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims] k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims] v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims] - img_qkv_len (int): 图像查询、键和值的长度 + slice_qkv_len (int): 图像查询、键和值的长度 cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息 attention_type (str): 注意力类型,默认为 "flash_attn2" @@ -78,12 +79,7 @@ def apply( world_size = dist.get_world_size(seq_p_group) img_qkv_len = slice_qkv_len - if len(cu_seqlens_qkv) == 3: - txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 - txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len # 文本掩码长度 - elif len(cu_seqlens_qkv) == 2: - txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len # 文本查询、键和值的长度 - txt_mask_len = 0 + txt_qkv_len, txt_mask_len = self.helper._get_text_lengths(cu_seqlens_qkv, img_qkv_len) # if RING_COMM is None: # init_ring_comm() @@ -108,7 +104,7 @@ def apply( v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(), ) - out, lse, next_k, next_v, next_kv = None, None, None, None, None + out, lse, next_k, next_v = None, None, None, None if len(cu_seqlens_qkv) == 3: q = torch.cat((img_q, txt_q), dim=1) @@ -116,10 +112,8 @@ def apply( v = img_v if use_kv_fusion: - kv = torch.stack([img_k, img_v], dim=0).reshape(2, img_qkv_len, heads, hidden_dims).contiguous() txt_kv = torch.stack([txt_k, txt_v], dim=0).reshape(2, txt_qkv_len, heads, hidden_dims).contiguous() - original_dtype = kv.dtype - original_shape = kv.shape + kv, original_dtype, original_shape = self.helper._prepare_kv_tensors(k, v, use_kv_fusion) else: original_dtype = k.dtype original_shape = k.shape @@ -128,34 +122,21 @@ def apply( if step + 1 != world_size: if use_fp8_comm: if use_kv_fusion: - kv_fp8, kv_scale = quant_fp8_vllm(kv.reshape(-1, hidden_dims)) - kv_fp8 = kv_fp8.reshape(original_shape) - 
kv_scale = kv_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) - next_kv_fp8 = RING_COMM.send_recv(kv_fp8) - next_kv_scale = RING_COMM.send_recv(kv_scale) + next_kv_fp8, next_kv_scale = self.helper._send_recv_tensor(kv, hidden_dims, RING_COMM, use_fp8_comm, original_shape) else: - k_fp8, k_scale = quant_fp8_vllm(k.reshape(-1, hidden_dims)) - v_fp8, v_scale = quant_fp8_vllm(v.reshape(-1, hidden_dims)) - k_fp8 = k_fp8.reshape(original_shape) - v_fp8 = v_fp8.reshape(original_shape) - k_scale = k_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) - v_scale = v_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) - next_k_fp8 = RING_COMM.send_recv(k_fp8) - next_k_scale = RING_COMM.send_recv(k_scale) - next_v_fp8 = RING_COMM.send_recv(v_fp8) - next_v_scale = RING_COMM.send_recv(v_scale) - RING_COMM.commit() + next_k_fp8, next_k_scale = self.helper._send_recv_tensor(k, hidden_dims, RING_COMM, use_fp8_comm, original_shape) + next_v_fp8, next_v_scale = self.helper._send_recv_tensor(v, hidden_dims, RING_COMM, use_fp8_comm, original_shape) else: if use_kv_fusion: - next_kv = RING_COMM.send_recv(kv) + next_kv = self.helper._send_recv_tensor(kv, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0] else: - next_k = RING_COMM.send_recv(k) - next_v = RING_COMM.send_recv(v) - RING_COMM.commit() + next_k = self.helper._send_recv_tensor(k, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0] + next_v = self.helper._send_recv_tensor(v, hidden_dims, RING_COMM, use_fp8_comm, original_shape)[0] + RING_COMM.commit() if step + 1 == world_size: if use_kv_fusion: - next_kv = torch.cat((kv, txt_kv), dim=1) + kv = torch.cat((kv, txt_kv), dim=1) else: k = torch.cat((k, txt_k), dim=1) v = torch.cat((v, txt_v), dim=1) @@ -164,22 +145,24 @@ def apply( block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv) else: block_out, block_lse = self.ring_attn_sub(q, k, v) + out, lse = self.update_out_and_lse(out, lse, block_out, block_lse) if step + 1 != world_size: RING_COMM.wait() + if use_fp8_comm: if use_kv_fusion: - kv = dequant_fp8_vllm(next_kv_fp8, next_kv_scale, original_dtype) + kv = self.helper._dequantize_received(next_kv_fp8, next_kv_scale, original_dtype, original_shape, use_kv_fusion=True, is_kv_fusion=True) else: - k = dequant_fp8_vllm(next_k_fp8, next_k_scale, original_dtype) - v = dequant_fp8_vllm(next_v_fp8, next_v_scale, original_dtype) + k, v = self.helper._dequantize_received( + next_k_fp8, next_k_scale, original_dtype, original_shape, use_kv_fusion=False, is_kv_fusion=False, v_fp8=next_v_fp8, v_scale=next_v_scale + ) else: if use_kv_fusion: kv = next_kv else: - k = next_k - v = next_v + k, v = next_k, next_v attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1) @@ -206,6 +189,7 @@ def apply( def ring_attn_sub_kv_fusion(self, q, kv, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False): if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) + block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward( q, kv[:1, :, :, :], @@ -259,3 +243,132 @@ def update_out_and_lse( else: out, lse = _update_out_and_lse(out, lse, block_out, block_lse) return out, lse + +class RingAttnHelper: + """辅助函数类,处理 Ring Attention 中的量化、通信和反量化逻辑""" + + @staticmethod + def _quant_and_send(tensor, hidden_dims, comm, original_shape=None): + """ + 对张量进行 FP8 量化并通过通信器发送/接收 + + 参数: + tensor: 要量化和发送的张量 + hidden_dims: 隐藏维度大小 + comm: 通信器对象 + 
original_shape: 原始形状(用于 reshape 回原始形状) + + 返回: + tuple: (量化后的张量, scale 张量) + """ + if original_shape is None: + original_shape = tensor.shape + + # 量化为 FP8 + tensor_fp8, tensor_scale = quant_fp8_vllm(tensor.reshape(-1, hidden_dims)) + + # reshape 回原始形状 + tensor_fp8 = tensor_fp8.reshape(original_shape) + tensor_scale = tensor_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1) + + # 发送/接收量化后的张量 + next_tensor_fp8 = comm.send_recv(tensor_fp8) + next_tensor_scale = comm.send_recv(tensor_scale) + + return next_tensor_fp8, next_tensor_scale + + @staticmethod + def _prepare_kv_tensors(k, v, use_kv_fusion): + """ + 准备 K 和 V 张量,根据是否使用 KV 融合返回适当的张量 + + 参数: + k: 键张量 + v: 值张量 + use_kv_fusion: 是否使用 KV 融合 + + 返回: + tuple: (主张量, 原始数据类型, 原始形状) + """ + original_dtype = k.dtype + original_shape = k.shape + + if use_kv_fusion: + # 融合 K 和 V + kv = torch.stack([k, v], dim=0).reshape(2, k.shape[1], k.shape[2], k.shape[3]).contiguous() + return kv, original_dtype, kv.shape + else: + return k, original_dtype, original_shape + + @staticmethod + def _dequantize_received(next_tensor_fp8, next_tensor_scale, original_dtype, original_shape, use_kv_fusion=False, is_kv_fusion=False, v_fp8=None, v_scale=None): + """ + 反量化接收到的 FP8 张量 + + 参数: + next_tensor_fp8: 接收到的量化张量 + next_tensor_scale: 接收到的 scale 张量 + original_dtype: 原始数据类型 + original_shape: 原始形状 + use_kv_fusion: 是否使用 KV 融合模式 + is_kv_fusion: 当前张量是否为 KV 融合张量 + v_fp8, v_scale: 分离模式下的 V 张量和 scale + + 返回: + tuple: 反量化后的张量 (k, v) 或 kv + """ + if use_kv_fusion and is_kv_fusion: + # KV 融合模式 + return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype) + elif not use_kv_fusion: + # 分离模式 + k = dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype) + v = dequant_fp8_vllm(v_fp8, v_scale, original_dtype) + return k, v + else: + # 默认返回单个张量 + return dequant_fp8_vllm(next_tensor_fp8, next_tensor_scale, original_dtype) + + @staticmethod + def _send_recv_tensor(tensor, hidden_dims, comm, use_fp8_comm, original_shape=None): + """ + 发送/接收张量,根据是否使用 FP8 选择通信方式 + + 参数: + tensor: 要发送的张量 + hidden_dims: 隐藏维度大小 + comm: 通信器对象 + use_fp8_comm: 是否使用 FP8 通信 + original_shape: 原始形状 + + 返回: + tuple: 接收到的张量(和可能的 scale) + """ + if use_fp8_comm: + return RingAttnHelper._quant_and_send(tensor, hidden_dims, comm, original_shape) + else: + next_tensor = comm.send_recv(tensor) + return next_tensor, None + + @staticmethod + def _get_text_lengths(cu_seqlens_qkv, img_qkv_len): + """ + 从累积序列长度中获取文本长度 + + 参数: + cu_seqlens_qkv: 累积序列长度 + img_qkv_len: 图像序列长度 + + 返回: + tuple: (文本QKV长度, 文本掩码长度) + """ + if len(cu_seqlens_qkv) == 3: + txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len + txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len + elif len(cu_seqlens_qkv) == 2: + txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len + txt_mask_len = 0 + else: + raise ValueError(f"Invalid cu_seqlens_qkv length: {len(cu_seqlens_qkv)}") + + return txt_qkv_len, txt_mask_len From 19e3561bd4899912158c88773977fd7c2749e880 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 19 Jan 2026 14:05:51 +0000 Subject: [PATCH 5/5] ring_attn_fp8_comm_kv_fusion --- .../seko_talk_26_fp8_dist_fp8_comm.json | 42 +++++++++++++++++++ lightx2v/common/ops/attn/ring_attn.py | 1 + .../networks/wan/infer/transformer_infer.py | 3 ++ 3 files changed, 46 insertions(+) create mode 100644 configs/seko_talk/seko_talk_26_fp8_dist_fp8_comm.json diff --git a/configs/seko_talk/seko_talk_26_fp8_dist_fp8_comm.json b/configs/seko_talk/seko_talk_26_fp8_dist_fp8_comm.json new file mode 100644 index 00000000..74bcb376 --- /dev/null +++ 
b/configs/seko_talk/seko_talk_26_fp8_dist_fp8_comm.json @@ -0,0 +1,42 @@ +{ + "infer_steps": 4, + "target_fps": 16, + "video_duration": 20, + "audio_sr": 16000, + "target_video_length": 81, + "resize_mode": "fixed_shape", + "fixed_shape": [ + 832, + 480 + ], + "self_attn_1_type": "sage_attn3", + "cross_attn_1_type": "sage_attn3", + "cross_attn_2_type": "sage_attn3", + "sample_guide_scale": 1, + "sample_shift": 5, + "enable_cfg": false, + "use_31_block": false, + "cpu_offload": false, + "offload_granularity": "block", + "offload_ratio": 1, + "t5_cpu_offload": true, + "t5_quantized": true, + "t5_quant_scheme": "fp8-sgl", + "clip_cpu_offload": true, + "clip_quantized": false, + "audio_encoder_cpu_offload": true, + "audio_adapter_cpu_offload": true, + "adapter_quantized": true, + "adapter_quant_scheme": "fp8-sgl", + "vae_cpu_offload": true, + "use_tiling_vae": false, + "dit_quantized": true, + "dit_quant_scheme": "fp8-sgl", + "parallel": { + "seq_p_size": 4, + "seq_p_attn_type": "ring", + "seq_p_fp8_comm": true, + "seq_p_head_parallel": true, + "seq_p_tensor_fusion": false + } +} diff --git a/lightx2v/common/ops/attn/ring_attn.py b/lightx2v/common/ops/attn/ring_attn.py index 5f5256ce..fd5c5ada 100644 --- a/lightx2v/common/ops/attn/ring_attn.py +++ b/lightx2v/common/ops/attn/ring_attn.py @@ -244,6 +244,7 @@ def update_out_and_lse( out, lse = _update_out_and_lse(out, lse, block_out, block_lse) return out, lse + class RingAttnHelper: """辅助函数类,处理 Ring Attention 中的量化、通信和反量化逻辑""" diff --git a/lightx2v/models/networks/wan/infer/transformer_infer.py b/lightx2v/models/networks/wan/infer/transformer_infer.py index 4f710b4e..705e1bb2 100755 --- a/lightx2v/models/networks/wan/infer/transformer_infer.py +++ b/lightx2v/models/networks/wan/infer/transformer_infer.py @@ -64,10 +64,12 @@ def rope_wrapper(xq, xk, cos_sin_cache): self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False) self.enable_head_parallel = self.config["parallel"].get("seq_p_head_parallel", False) + self.seq_p_tensor_fusion = self.config["parallel"].get("seq_p_tensor_fusion", False) else: self.seq_p_group = None self.seq_p_fp8_comm = False self.enable_head_parallel = False + self.seq_p_tensor_fusion = False self.infer_func = self.infer_without_offload self.cos_sin = None @@ -220,6 +222,7 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa): attention_type=self.self_attn_1_type, seq_p_group=self.seq_p_group, use_fp8_comm=self.seq_p_fp8_comm, + use_tensor_fusion=self.seq_p_tensor_fusion, enable_head_parallel=self.enable_head_parallel, **attn_running_args, )
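
For reference, a minimal sketch of the FP8 exchange path these patches add to RingAttnWeight.apply(), assuming quant_fp8_vllm / dequant_fp8_vllm behave as used above (row-wise scales over the last dimension) and that `comm` exposes the same send_recv/commit/wait interface as RING_COMM; the helper name fp8_ring_exchange and its standalone form are illustrative only and not part of the patches:

    # Minimal sketch, not part of the patch series.
    from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm

    def fp8_ring_exchange(t, comm, hidden_dims):
        # t: [batch, seqlen, heads, hidden_dims] K or V slice held by this rank.
        orig_dtype, orig_shape = t.dtype, t.shape
        # Quantize row-wise: an FP8 payload plus one scale per (token, head).
        t_fp8, t_scale = quant_fp8_vllm(t.reshape(-1, hidden_dims))
        t_fp8 = t_fp8.reshape(orig_shape)
        t_scale = t_scale.reshape(orig_shape[0], orig_shape[1], orig_shape[2], 1)
        # Exchange the compressed payload and its scales with the neighbouring rank.
        next_fp8 = comm.send_recv(t_fp8)
        next_scale = comm.send_recv(t_scale)
        comm.commit()
        # In apply() the attention for the current block runs between commit() and
        # wait(), overlapping compute with the P2P transfer; the sketch just waits.
        comm.wait()
        return dequant_fp8_vllm(next_fp8, next_scale, orig_dtype)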
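
The use_kv_fusion path in the same patches packs K and V into one tensor before the exchange, so each ring step issues a single send_recv (or one FP8 payload plus one scale tensor) instead of two, and ring_attn_sub_kv_fusion slices the batch-1 K and V views back out for flash-attn. A minimal sketch of that packing, with seq_len / heads / head_dim as placeholder names:

    # Assumes k and v are the [1, seq_len, heads, head_dim] slices built in apply().
    import torch

    kv = torch.stack([k, v], dim=0).reshape(2, seq_len, heads, head_dim).contiguous()
    next_kv = comm.send_recv(kv)                 # one P2P call per step instead of two
    comm.commit()
    comm.wait()
    k_block, v_block = next_kv[:1], next_kv[1:]  # batch-1 views fed to flash-attn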