From 4223c83fe7ad805a820086bf45986e002f3b20e0 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Thu, 23 Oct 2025 17:13:18 +0800
Subject: [PATCH 1/3] [Feature] Remove stream synchronization during ring_mla

Signed-off-by: Jade Zheng
---
 vllm_ascend/attention/mla_v1.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 86418720af..a66046ed02 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -79,6 +79,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor
 
     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -370,6 +371,7 @@ def build(
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.to(device, non_blocking=True),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -774,6 +776,8 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]
 
             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len2_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -790,7 +794,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                seq_len2_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,

From 4139ebfe673e0c63ab0b20d2f840005dbaf5c972 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Thu, 23 Oct 2025 17:40:02 +0800
Subject: [PATCH 2/3] update test case

Signed-off-by: Jade Zheng
---
 tests/ut/attention/test_mla_v1.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index c55234bc3d..ebfea50e85 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens.npu())
 
         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),

From 1d6d64e27379dae76c84f84d0736c3189ee71024 Mon Sep 17 00:00:00 2001
From: Jade Zheng
Date: Fri, 24 Oct 2025 11:19:35 +0800
Subject: [PATCH 3/3] fix: remove non_blocking flag from chunk_seq_lens_npu transfer

Signed-off-by: Jade Zheng
---
 vllm_ascend/attention/mla_v1.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index a66046ed02..1f2a5c486a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -371,7 +371,7 @@ def build(
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
-                    chunk_seq_lens_npu=chunk_seq_lens.to(device, non_blocking=True),
+                    chunk_seq_lens_npu=chunk_seq_lens.to(device),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
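
A minimal sketch of the pattern PATCH 1/3 and 3/3 apply, using hypothetical names
(ChunkedContextSketch, build_chunked_context, prefill_context_loop) rather than the
real vllm-ascend classes: the host-to-device copy of chunk_seq_lens is hoisted out
of the per-chunk prefill loop and into the one-time metadata build, so the loop no
longer blocks on a .to(device) transfer every iteration.

# Sketch only, not the patched code; names below are made up for illustration.
from dataclasses import dataclass

import torch


@dataclass
class ChunkedContextSketch:
    # CPU tensor, still used for host-side bookkeeping (.tolist(), .max(), ...).
    chunk_seq_lens: torch.Tensor
    # Device-resident copy, created once when the metadata is built.
    chunk_seq_lens_npu: torch.Tensor


def build_chunked_context(chunk_seq_lens: torch.Tensor,
                          device: torch.device) -> ChunkedContextSketch:
    # One transfer here replaces one transfer per chunk in the hot loop.
    return ChunkedContextSketch(
        chunk_seq_lens=chunk_seq_lens,
        chunk_seq_lens_npu=chunk_seq_lens.to(device),
    )


def prefill_context_loop(ctx: ChunkedContextSketch, num_chunks: int) -> None:
    for i in range(num_chunks):
        # Already on the device: indexing does not trigger a host->device copy
        # or an implicit stream synchronization inside the loop.
        seq_lens_dev = ctx.chunk_seq_lens_npu[i]
        # ... pass seq_lens_dev to the attention kernel for chunk i ...
        _ = seq_lens_dev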