diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index c55234bc3d..ebfea50e85 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
             seq_tot=seq_tot,
             max_seq_lens=max_seq_lens,
             workspace=workspace,
-            chunk_seq_lens=chunk_seq_lens)
+            chunk_seq_lens=chunk_seq_lens,
+            chunk_seq_lens_npu=chunk_seq_lens.npu())

         metadata = AscendMLAPrefillMetadata(
             attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 86418720af..1f2a5c486a 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -79,6 +79,7 @@ class ChunkedContextMetadata:
         max_seq_lens: list[int]
         workspace: torch.Tensor
         chunk_seq_lens: torch.Tensor
+        chunk_seq_lens_npu: torch.Tensor

     attn_mask: torch.Tensor
     query_lens: torch.Tensor
@@ -370,6 +371,7 @@ def build(
                     seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                     max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                     chunk_seq_lens=chunk_seq_lens,
+                    chunk_seq_lens_npu=chunk_seq_lens.to(device),
                     workspace=self.chunked_prefill_workspace,
                 )
             prefill_input_positions = input_positions[tokens_start:]
@@ -774,6 +776,8 @@ def _compute_prefill_context(
             toks = prefill_metadata.chunked_context.seq_tot[i]

             seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+            seq_len2_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+                i]
             seq_len = torch.stack([seq_len1, seq_len2])
             kv_c_normed = torch.empty(toks,
                                       num_heads,
@@ -790,7 +794,7 @@ def _compute_prefill_context(
                 cache_kv_c,
                 cache_k_pe,
                 prefill_metadata.block_table,
-                seq_len2.to(q_nope.device),
+                seq_len2_npu,
                 seq_starts=prefill_metadata.chunked_context.starts[i],
                 key=kv_c_normed,
                 value=k_pe,
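
For context, a minimal sketch of the optimization this patch makes: chunk_seq_lens is built on the CPU, and the old code moved each per-request slice to the device with seq_len2.to(q_nope.device) inside the per-chunk loop of _compute_prefill_context, i.e. one host-to-device copy per chunk per layer. The patch hoists that transfer into the metadata builder, which runs once per step, and the loop then only indexes the device-resident tensor. The snippet below illustrates the pattern under stated assumptions; the device fallback exists only so the sketch runs without an NPU, and the loop body stands in for the real torch_npu.atb.npu_paged_cache_load call.

import torch

# Assumption: "npu" is available only with a torch_npu install; fall back to
# CPU so this sketch runs anywhere.
device = "npu" if getattr(torch, "npu", None) and torch.npu.is_available() else "cpu"

# Builder side (runs once per scheduler step): chunk lengths computed on CPU.
chunk_seq_lens = torch.randint(1, 128, (4, 8))  # hypothetical [num_requests, num_chunks]
chunk_seq_lens_npu = chunk_seq_lens.to(device)  # single host-to-device copy

# Attention side (runs per chunk, per layer): index the device-resident copy.
for i in range(chunk_seq_lens.size(0)):
    seq_len2_npu = chunk_seq_lens_npu[i]  # no H2D transfer in the hot loop
    # before the patch this was: chunk_seq_lens[i].to(device) on every iteration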