diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index b7ce2a6846..e48ab79f18 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -148,14 +148,27 @@ def get_total_slots(): return cls.total_slots def get_cpu_seqlens(is_decoding, is_unpaged_prefill): + """Get sequence lengths on CPU. + + Returns: + q_seqlens_cpu: query sequence lengths (per sequence). + kv_seqlens_cpu: kv sequence lengths (per sequence), used for + list/max seqlens calculation. + kv_seqlens_expanded: kv sequence lengths expanded per token via + repeat_interleave, used for attention metadata. + """ if is_decoding: - q_seqlens_cpu, kv_seqlens_cpu = None, step_context.kv_seqlens.cpu() + q_seqlens_cpu = None + kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu() elif is_unpaged_prefill: - q_seqlens_cpu = kv_seqlens_cpu = step_context.q_seqlens.cpu() + q_seqlens_cpu = step_context.q_seqlens.cpu() + kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu else: q_seqlens_cpu = step_context.q_seqlens.cpu() kv_seqlens_cpu = step_context.kv_seqlens.cpu() - return q_seqlens_cpu, kv_seqlens_cpu + # Expand kv_seqlens to per-token for paged prefill attention + kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0) + return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None): if is_decoding: @@ -219,7 +232,8 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s return kv_start_indices, attention_mask - q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill) + q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding, + is_unpaged_prefill) q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu, kv_seqlens_cpu) max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list, @@ -248,7 +262,9 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s step_context.block_offsets, q_start_loc=None, q_seqlens=q_seqlens_cpu, - kv_seqlens=kv_seqlens_cpu, + # kv_seqlens_expanded is only expanded in paged prefill, + # otherwise it equals kv_seqlens_cpu + kv_seqlens=kv_seqlens_expanded, kv_start_indices=kv_start_indices, block_size=block_size, attention_mask=attention_mask,