Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,11 +382,20 @@ def get_moe_group_name(group):
AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype,
record_file, total_layers)

cu_seqlens = None
has_initial_state = None

if step_context.state_offsets is not None:
q_start_loc = step_context.q_start_loc
cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int()
if not step_context.is_decoding:
has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens)

attn_meta_cls = cls.get_attention_metadata_cls()
attn_metadata = attn_meta_cls(
step_context.is_decoding,
step_context.block_offsets,
q_start_loc=None,
q_start_loc=cu_seqlens,
q_seqlens=q_seqlens_cpu,
# kv_seqlens_expanded is only expanded in paged prefill,
# otherwise it equals kv_seqlens_cpu
Expand All @@ -399,6 +408,7 @@ def get_moe_group_name(group):
max_kv_seq_len=max_kv_seq_len,
quant_policy=step_context.kv_quant_policy,
quant_meta=AscendKVQuantMeta.quant_meta,
has_initial_state=has_initial_state,
)
step_context.attn_metadata = attn_metadata

Expand Down Expand Up @@ -442,6 +452,8 @@ def init():
"""Initialize Ascend backend."""
try:
from torch_npu.contrib import transfer_to_npu # noqa: F401
from dlinfer.vendor.ascend.triton_ops.fla.triton_utils import init_device_properties_triton
init_device_properties_triton()
except ImportError:
logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. '
'Ascend initialization skipped.')
Expand Down
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
max_kv_seq_len: int = 1
quant_meta: Dict = None
cu_seq_lens_kv: Optional[Tensor] = None
has_initial_state: Optional[Tensor] = None


class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
Expand Down
Loading