Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,27 @@ def get_total_slots():
return cls.total_slots

def get_cpu_seqlens(is_decoding, is_unpaged_prefill):
    """Get sequence lengths on CPU.

    Reads ``q_seqlens`` / ``kv_seqlens`` from the enclosing ``step_context``
    and copies them to CPU exactly once per branch.

    Args:
        is_decoding: truthy when the current step is a decoding step; only
            kv sequence lengths are read in that case.
        is_unpaged_prefill: truthy when the current step is an unpaged
            prefill; the query lengths are reused as the kv lengths.

    Returns:
        A 3-tuple ``(q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded)``:

        q_seqlens_cpu: query sequence lengths (per sequence), or ``None``
            when decoding.
        kv_seqlens_cpu: kv sequence lengths (per sequence), used for
            list/max seqlens calculation.
        kv_seqlens_expanded: kv sequence lengths expanded per token via
            repeat_interleave, used for attention metadata. Only the paged
            prefill branch actually expands; in the other two branches this
            is the very same object as ``kv_seqlens_cpu``.
    """
    if is_decoding:
        q_seqlens_cpu = None
        kv_seqlens_cpu = kv_seqlens_expanded = step_context.kv_seqlens.cpu()
    elif is_unpaged_prefill:
        q_seqlens_cpu = step_context.q_seqlens.cpu()
        kv_seqlens_cpu = kv_seqlens_expanded = q_seqlens_cpu
    else:
        q_seqlens_cpu = step_context.q_seqlens.cpu()
        kv_seqlens_cpu = step_context.kv_seqlens.cpu()
        # Expand kv_seqlens to per-token for paged prefill attention: each
        # sequence's kv length is repeated once per query token of that
        # sequence.
        kv_seqlens_expanded = kv_seqlens_cpu.repeat_interleave(q_seqlens_cpu, 0)
    return q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded

def get_list_seqlens(is_decoding, is_unpaged_prefill, q_seqlens_cpu=None, kv_seqlens_cpu=None):
if is_decoding:
Expand Down Expand Up @@ -219,7 +232,8 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s

return kv_start_indices, attention_mask

q_seqlens_cpu, kv_seqlens_cpu = get_cpu_seqlens(step_context.is_decoding, is_unpaged_prefill)
q_seqlens_cpu, kv_seqlens_cpu, kv_seqlens_expanded = get_cpu_seqlens(step_context.is_decoding,
is_unpaged_prefill)
q_seqlens_list, kv_seqlens_list = get_list_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_cpu,
kv_seqlens_cpu)
max_q_seq_len, max_kv_seq_len = get_max_seqlens(step_context.is_decoding, is_unpaged_prefill, q_seqlens_list,
Expand Down Expand Up @@ -248,7 +262,9 @@ def get_kv_start_indices_and_attention_mask(is_decoding, is_unpaged_prefill, q_s
step_context.block_offsets,
q_start_loc=None,
q_seqlens=q_seqlens_cpu,
kv_seqlens=kv_seqlens_cpu,
# kv_seqlens_expanded is only expanded in paged prefill,
# otherwise it equals kv_seqlens_cpu
kv_seqlens=kv_seqlens_expanded,
kv_start_indices=kv_start_indices,
block_size=block_size,
attention_mask=attention_mask,
Expand Down