tests/ut/attention/test_mla_v1.py (2 additions, 1 deletion)
@@ -82,7 +82,8 @@ def test_ascend_mla_prefill_metadata_with_chunked_context(self):
seq_tot=seq_tot,
max_seq_lens=max_seq_lens,
workspace=workspace,
- chunk_seq_lens=chunk_seq_lens)
+ chunk_seq_lens=chunk_seq_lens,
+ chunk_seq_lens_npu=chunk_seq_lens.npu())

metadata = AscendMLAPrefillMetadata(
attn_mask=torch.tensor([[1, 0], [1, 1]], dtype=torch.bool),

vllm_ascend/attention/mla_v1.py (5 additions, 1 deletion)
@@ -79,6 +79,7 @@ class ChunkedContextMetadata:
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
+ chunk_seq_lens_npu: torch.Tensor

attn_mask: torch.Tensor
query_lens: torch.Tensor
@@ -370,6 +371,7 @@ def build(
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
+ chunk_seq_lens_npu=chunk_seq_lens.to(device),
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
@@ -774,6 +776,8 @@ def _compute_prefill_context(
toks = prefill_metadata.chunked_context.seq_tot[i]

seq_len2 = prefill_metadata.chunked_context.chunk_seq_lens[i]
+ seq_len2_npu = prefill_metadata.chunked_context.chunk_seq_lens_npu[
+     i]
seq_len = torch.stack([seq_len1, seq_len2])
Contributor comment (severity: high):
The seq_len variable is created here but never used. Its creation involves torch.stack with seq_len2, which is a CPU tensor, and seq_len1, a device tensor. This causes an unnecessary host-to-device transfer and stream synchronization, which this PR aims to eliminate. Removing this line will improve performance by avoiding this synchronization.
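To make the reviewer's point concrete, here is a minimal sketch in plain PyTorch of the pattern this fix restores (not the actual vllm-ascend code; tensor values and helper names are made up). The host-to-device copy of the chunk lengths is hoisted out of the per-layer loop, so the hot path only indexes a tensor that already lives on the device.

import torch

# Device selection is illustrative: on Ascend this would be an NPU device
# provided by torch_npu; elsewhere it falls back to CPU so the sketch still runs.
device = "npu" if getattr(torch, "npu", None) and torch.npu.is_available() else "cpu"

chunk_seq_lens = torch.tensor([[4, 4], [8, 0]])   # CPU copy, kept for .tolist()/.max()
chunk_seq_lens_dev = chunk_seq_lens.to(device)    # one copy at metadata-build time

def lengths_for_chunk_before(i: int) -> torch.Tensor:
    # Old behaviour: a fresh host-to-device copy on every chunk of every layer,
    # which the review above identifies as an extra synchronization point.
    return chunk_seq_lens[i].to(device)

def lengths_for_chunk_after(i: int) -> torch.Tensor:
    # New behaviour: index the precomputed device tensor; no transfer in the loop.
    return chunk_seq_lens_dev[i]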

kv_c_normed = torch.empty(toks,
num_heads,
@@ -790,7 +794,7 @@
cache_kv_c,
cache_k_pe,
prefill_metadata.block_table,
- seq_len2.to(q_nope.device),
+ seq_len2_npu,
Collaborator comment:
Oops, removing this optimization might be a bug introduced by the MLA refactoring. Thanks for the fix. LGTM.

seq_starts=prefill_metadata.chunked_context.starts[i],
key=kv_c_normed,
value=k_pe,
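For reference, the builder-side half of the change corresponds roughly to the sketch below (simplified and hypothetical; the real metadata class carries many more fields). The CPU tensor continues to serve the Python-level reductions, while the device copy is created exactly once per build and is what the per-layer loop indexes. The unit-test change mirrors this by constructing the metadata with chunk_seq_lens.npu().

from dataclasses import dataclass

import torch

@dataclass
class ChunkedContextSketch:            # hypothetical stand-in for ChunkedContextMetadata
    seq_tot: list[int]                 # host-side totals, consumed as Python ints
    max_seq_lens: list[int]
    chunk_seq_lens: torch.Tensor       # stays on CPU for cheap host-side reads
    chunk_seq_lens_npu: torch.Tensor   # device copy handed to the attention kernels

def build_chunked_context(chunk_seq_lens: torch.Tensor, device) -> ChunkedContextSketch:
    return ChunkedContextSketch(
        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
        chunk_seq_lens=chunk_seq_lens,
        chunk_seq_lens_npu=chunk_seq_lens.to(device),   # single H2D copy per build
    )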