diff --git a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py index 484cbd1b72..f528bab2ca 100644 --- a/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py +++ b/lmdeploy/pytorch/backends/dlinfer/ascend/op_backend.py @@ -382,11 +382,20 @@ def get_moe_group_name(group): AscendKVQuantMeta.set_value(step_context.block_offsets.device, step_context.model_config.dtype, record_file, total_layers) + cu_seqlens = None + has_initial_state = None + + if step_context.state_offsets is not None: + q_start_loc = step_context.q_start_loc + cu_seqlens = torch.cat((q_start_loc, step_context.q_seqlens.sum().unsqueeze(0))).int() + if not step_context.is_decoding: + has_initial_state = ~(step_context.q_seqlens == step_context.kv_seqlens) + attn_meta_cls = cls.get_attention_metadata_cls() attn_metadata = attn_meta_cls( step_context.is_decoding, step_context.block_offsets, - q_start_loc=None, + q_start_loc=cu_seqlens, q_seqlens=q_seqlens_cpu, # kv_seqlens_expanded is only expanded in paged prefill, # otherwise it equals kv_seqlens_cpu @@ -399,6 +408,7 @@ def get_moe_group_name(group): max_kv_seq_len=max_kv_seq_len, quant_policy=step_context.kv_quant_policy, quant_meta=AscendKVQuantMeta.quant_meta, + has_initial_state=has_initial_state, ) step_context.attn_metadata = attn_metadata @@ -442,6 +452,8 @@ def init(): """Initialize Ascend backend.""" try: from torch_npu.contrib import transfer_to_npu # noqa: F401 + from dlinfer.vendor.ascend.triton_ops.fla.triton_utils import init_device_properties_triton + init_device_properties_triton() except ImportError: logger.warning('Failed to import torch_npu. Please make sure torch_npu is installed correctly. ' 'Ascend initialization skipped.') diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py index 8566187021..8b481730e5 100644 --- a/lmdeploy/pytorch/backends/dlinfer/attention.py +++ b/lmdeploy/pytorch/backends/dlinfer/attention.py @@ -18,6 +18,7 @@ class DlinferAttentionMetadata(AttentionMetadata): max_kv_seq_len: int = 1 quant_meta: Dict = None cu_seq_lens_kv: Optional[Tensor] = None + has_initial_state: Optional[Tensor] = None class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):