
Commit 5ec96fd

LookAround0301 and Delphine-Nic authored
[long_seq_Feat] support chunk prefill (#4158)
### What this PR does / why we need it?

1. Qwen GQA attention_v1 optimization.
2. DeepSeek MLA refactor: all-gather q -> all-gather kv.
3. Model runner refactor for chunk prefill; removes some unused code.

- vLLM version: v0.11.0
- vLLM main: vllm-project/vllm@2918c1b

---------

Signed-off-by: LookAround <[email protected]>
Signed-off-by: Delphine-Nic <[email protected]>
Co-authored-by: Delphine-Nic <[email protected]>
1 parent 7294f89 commit 5ec96fd
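For context on item 2 of the commit message: with context parallelism, "all-gather q -> all-gather kv" means each rank keeps its local queries and gathers the key/value shards instead, then attends with local q against the full kv. The sketch below only illustrates why the two layouts produce the same result; `cp_size`, the shapes, and the plain `scaled_dot_product_attention` call are illustrative assumptions, not the NPU kernel path in `mla_v1.py`.

```python
import torch
import torch.nn.functional as F

# Simulate cp_size ranks, each holding a slice of the sequence.
cp_size, seq_per_rank, heads, dim = 2, 4, 8, 64
q_shards = [torch.randn(heads, seq_per_rank, dim) for _ in range(cp_size)]
k_shards = [torch.randn(heads, seq_per_rank, dim) for _ in range(cp_size)]
v_shards = [torch.randn(heads, seq_per_rank, dim) for _ in range(cp_size)]

# "all gather kv": every rank sees the full key/value sequence ...
k_full = torch.cat(k_shards, dim=1)
v_full = torch.cat(v_shards, dim=1)

# ... and computes attention only for its local queries.
out_shards = [
    F.scaled_dot_product_attention(q_shards[r], k_full, v_full)
    for r in range(cp_size)
]

# Stitching the per-rank outputs back together matches full attention,
# because each attention row depends only on its own query.
q_full = torch.cat(q_shards, dim=1)
reference = F.scaled_dot_product_attention(q_full, k_full, v_full)
assert torch.allclose(torch.cat(out_shards, dim=1), reference, atol=1e-5)
```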

File tree

6 files changed (+421, -943 lines)


tests/ut/attention/test_mla_v1.py

Lines changed: 0 additions & 3 deletions
@@ -484,9 +484,6 @@ def test_compute_prefill_context(self, mock_ring, mock_load):
         chunk_ctx.chunk_seq_lens = [torch.tensor([8])]
         chunk_ctx.chunk_seq_lens_npu = [torch.tensor([8])]
         chunk_ctx.starts = [torch.tensor([0])]
-        chunk_ctx.max_chunk_num = 1
-        chunk_ctx.mask_for_non_zero_chunk = [True]
-        chunk_ctx.local_chunked_kv_lens = [[[[8]]]]
 
         prefill_meta = MagicMock()
         prefill_meta.chunked_context = chunk_ctx
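The test above feeds `test_compute_prefill_context` a chunked context: cached KV is processed one chunk at a time and the partial attention outputs are merged. The sketch below shows only the math of that merge (softmax renormalization via log-sum-exp); it is not the NPU kernel path, and the helper names and shapes are invented for illustration.

```python
import torch


def attend(q, k, v):
    # Partial attention over one KV chunk, returning the output and the
    # per-row log-sum-exp needed to merge chunks later.
    scores = q @ k.transpose(-1, -2) / q.shape[-1]**0.5
    lse = torch.logsumexp(scores, dim=-1)
    return torch.softmax(scores, dim=-1) @ v, lse


def merge(out_a, lse_a, out_b, lse_b):
    # Renormalize two partial results so the combination equals attention
    # over the union of their KV chunks.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse


heads, q_len, dim = 2, 3, 8
q = torch.randn(heads, q_len, dim)
kv_chunks = [(torch.randn(heads, 8, dim), torch.randn(heads, 8, dim))
             for _ in range(3)]

out, lse = attend(q, *kv_chunks[0])
for k, v in kv_chunks[1:]:
    out, lse = merge(out, lse, *attend(q, k, v))

# Processing the context chunk by chunk matches attending to all of it at once.
ref, _ = attend(q, torch.cat([k for k, _ in kv_chunks], dim=1),
                torch.cat([v for _, v in kv_chunks], dim=1))
assert torch.allclose(out, ref, atol=1e-4)
```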

vllm_ascend/attention/attention_v1.py

Lines changed: 152 additions & 217 deletions
(Large diff not rendered.)

vllm_ascend/attention/mla_v1.py

Lines changed: 248 additions & 460 deletions
(Large diff not rendered.)

vllm_ascend/attention/utils.py

Lines changed: 0 additions & 24 deletions
@@ -20,13 +20,6 @@ class AscendPrefillContextParallelMetadata:
 
     num_computed_tokens_of_pcp_dcp: Optional[list[list[list[int]]]] = None
 
-    local_chunked_kv_lens: Optional[list[Optional[list[Optional[list[Optional[
-        list[int]]]]]]]] = None
-
-    mask_for_non_zero_chunk: Optional[List[bool]] = None
-
-    max_chunk_num: int = 0
-
     q_head_idx_tensor: torch.Tensor = None
 
     q_tail_idx_tensor: torch.Tensor = None
@@ -115,23 +108,6 @@ class AscendCommonAttentionMetadata:
         AscendPrefillContextParallelMetadata] = None
 
 
-def extract_req_dcp_by_chunk_pcp(lst,
-                                 chunk_idx,
-                                 dcp_size,
-                                 pcp_rank,
-                                 fill_value=0):
-    num_reqs = len(lst)
-    results: List[List[int]] = []
-    for i in range(num_reqs):
-        if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
-            # empty req or this req has no corresponding chunk, fill 0
-            results.append([fill_value] * dcp_size)
-            continue
-        dcp_values = lst[i][chunk_idx][pcp_rank]
-        results.append(dcp_values)
-    return results
-
-
 def filter_chunked_req_indices(
         seq_len: torch.Tensor,
         mask_for_non_zero_chunk: Optional[List[bool]],
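For reference, the helper deleted above assumed `local_chunked_kv_lens` was laid out as `lst[request][chunk][pcp_rank] -> list of per-DCP-rank KV lengths`. A standalone copy of the deleted function with a made-up input (the values are purely illustrative) shows the indexing and the zero-fill behaviour:

```python
from typing import List


def extract_req_dcp_by_chunk_pcp(lst,
                                 chunk_idx,
                                 dcp_size,
                                 pcp_rank,
                                 fill_value=0):
    num_reqs = len(lst)
    results: List[List[int]] = []
    for i in range(num_reqs):
        if len(lst[i]) == 0 or chunk_idx >= len(lst[i]):
            # empty req or this req has no corresponding chunk, fill 0
            results.append([fill_value] * dcp_size)
            continue
        dcp_values = lst[i][chunk_idx][pcp_rank]
        results.append(dcp_values)
    return results


# Hypothetical metadata: 2 requests, up to 2 chunks, 2 PCP ranks, 2 DCP ranks.
local_chunked_kv_lens = [
    [[[8, 8], [4, 4]], [[2, 2], [1, 1]]],  # request 0 has two chunks
    [[[6, 6], [3, 3]]],                    # request 1 has one chunk
]
print(extract_req_dcp_by_chunk_pcp(local_chunked_kv_lens, chunk_idx=1,
                                   dcp_size=2, pcp_rank=0))
# [[2, 2], [0, 0]] -- request 1 has no chunk 1, so it is zero-filled
```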
