diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index e95db1a93f..492f2036b8 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -242,27 +242,6 @@ def setUp(self):
             attn_type=self.attention_type.DECODER,
             kv_sharing_target_layer_name=None)
 
-    @patch('torch.ops.vllm.unified_ascend_attention_with_output')
-    def test_forward_trace_flag_true(self, mock_unified_attention):
-        """Test forward pass when trace_flag is True"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 0, 0, 8, 64)
-        metadata = self.attn_metadata
-        layer = self.layer
-
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=True)
-
-        mock_unified_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_paged_attention_splitfuse')
     def test_forward_with_quant_method(self, mock_paged_attention):
         """Test forward pass when layer has quant_method"""
@@ -284,13 +263,8 @@ def test_forward_with_quant_method(self, mock_paged_attention):
         layer.quant_method = MagicMock()
         layer.quant_method.apply.return_value = ret_value
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
 
         layer.quant_method.apply.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -303,13 +277,7 @@ def test_forward_no_attn_metadata(self):
         kv_cache = torch.empty(2, 0, 0, 8, 64)
         layer = self.layer_no_quant
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   None,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache, None)
 
         assert output.shape == (10, 8 * 64)
 
@@ -331,13 +299,8 @@ def test_forward_prefill_no_cache(self, mock_flash_attention,
         layer = self.layer_no_quant
         # layer.quant_method.apply.return_value = metadata
         print(self.layer_no_quant._v_scale_float)
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
         mock_reshape_cache.assert_called_once()
         mock_flash_attention.assert_called_once()
 
@@ -362,13 +325,8 @@ def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
 
         mock_flash_attention_qlens.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -394,13 +352,8 @@ def test_forward_decode_only(self, mock_paged_attention,
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
 
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -501,13 +454,8 @@ def test_paged_attention_with_existing_workspace(
         mock_get_forward_context.return_value = MagicMock(capturing=True)
         mock_get_graph_params.return_value = graph_params
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
 
         mock_paged_attention.assert_called_once()
         self.assertEqual(len(graph_params.handles[num_tokens]), 0)
 
@@ -530,13 +478,8 @@ def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
         layer = self.layer_no_quant
         mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                      64), 1)
-        output = self.impl_swa.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_swa.forward(layer, query, key, value, kv_cache,
+                                       metadata)
         print(output.shape)
         mock_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -566,13 +509,8 @@ def test_forward_decode_only_swa_seq_len_mismatch(
 
         mock_get_forward_context.return_value = MagicMock(capturing=False)
 
-        output = self.impl_swa.forward(self.layer_no_quant,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_swa.forward(self.layer_no_quant, query, key, value,
+                                       kv_cache, metadata)
 
         mock_paged_attention.assert_called_once()
         mock_fused_infer_attention_score.assert_not_called()
@@ -601,13 +539,8 @@ def test_forward_head_size_192(self, mock_vanilla_prefill,
         layer = self.layer_no_quant
         mock_vanilla_prefill.return_value = MagicMock()
 
-        output = self.impl_192.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_192.forward(layer, query, key, value, kv_cache,
+                                       metadata)
 
         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
@@ -630,13 +563,8 @@ def test_forward_normal_v1_situation(self, mock_paged_attention,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
 
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -663,13 +591,8 @@ def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
         layer = self.layer_no_quant
         mock_npu_format_cast.return_value = metadata.attn_mask
 
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)
 
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -690,10 +613,5 @@ def test_forward_raise_error(self, mock_paged_attention):
         layer = self.layer_no_quant
 
         with self.assertRaises(NotImplementedError):
-            self.impl_error.forward(layer,
-                                    query,
-                                    key,
-                                    value,
-                                    kv_cache,
-                                    metadata,
-                                    trace_flag=False)
+            self.impl_error.forward(layer, query, key, value, kv_cache,
+                                    metadata)
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index fce9e8c3c2..a3d4ec5a5d 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -26,14 +26,12 @@
                                               AttentionLayer, AttentionType)
 from vllm.config import VllmConfig
 from vllm.forward_context import ForwardContext, get_forward_context
-from vllm.utils import cdiv, direct_register_custom_op
+from vllm.utils import cdiv
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         maybe_save_kv_layer_to_connector,
-                                         wait_for_kv_layer_from_connector)
+from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                                update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
@@ -590,158 +588,94 @@ def forward(
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: AscendMetadata,
         output: Optional[torch.Tensor] = None,
-        trace_flag: bool = True,
+        output_scale: Optional[torch.Tensor] = None,
+        output_block_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:
-            query: shape = [batch_size, seq_len, num_heads * head_size]
-            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
-            kv_cache: shape = [key_cache, value_cache]
-                key_cache = [num_blocks, block_size,
-                             num_kv_heads, head_size]
-                value_cache = [num_blocks, block_size,
-                               num_kv_heads, head_size]
+            query: shape = [num_tokens, num_heads, head_size]
+            key: shape = [num_tokens, num_kv_heads, head_size]
+            value: shape = [num_tokens, num_kv_heads, head_size]
+            kv_cache: shape =
+                [2, num_blocks, block_size, num_kv_heads, head_size]
             attn_metadata: Metadata for attention.
         Returns:
-            shape = [batch_size * seq_len, num_heads, head_size]
+            shape = [num_tokens, num_heads * head_size]
         """
-        num_tokens = query.shape[0]
-        use_kv_cache_int8 = len(
-            kv_cache) > 0 and kv_cache[0].dtype == torch.int8
-        if output is None:
-            output = torch.empty(num_tokens,
-                                 self.num_heads,
-                                 self.head_size,
-                                 dtype=query.dtype,
-                                 device=query.device)
-        ori_output = output
-        if trace_flag:
-            torch.ops.vllm.unified_ascend_attention_with_output(
-                query=query,
-                key=key,
-                value=value,
-                output=output,
-                layer_name=layer.layer_name)
+        assert output is not None, "Output tensor must be provided."
+
+        if output_scale is not None or output_block_scale is not None:
+            raise NotImplementedError(
+                "fused output quantization is not yet supported"
+                " for AscendAttentionBackendImpl")
 
-        elif hasattr(layer, 'quant_method') and use_kv_cache_int8:
-            output = layer.quant_method.apply(layer, query, key, value,
-                                              kv_cache, attn_metadata,
-                                              self.attn_type, self.scale,
-                                              output)
+        num_tokens = query.shape[0]
+        if attn_metadata is None:
+            return output
+        # NOTE: Currently, we have various attention paths for different
+        # scenarios, and not all of them are in-place operations. Therefore,
+        # we need to create a separate tensor to hold the attention result.
+        # In the future, we may consolidate them into fewer paths, which will
+        # hopefully allow us to use in-place operation by default.
+        intermediate_output: torch.Tensor
+
+        num_actual_tokens = attn_metadata.num_actual_tokens
+        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
+        attn_type = self.attn_type
+        if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
+            raise NotImplementedError("Encoder/decoder cross-attention "
+                                      "are not implemented for "
+                                      "PallasAttentionBackendImpl")
+
+        if len(kv_cache) > 1:
+            if self.key_cache is None:
+                self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
+            slots = attn_metadata.slot_mapping
+            torch_npu._npu_reshape_and_cache(key=key[:num_actual_tokens],
+                                             value=value[:num_actual_tokens],
+                                             key_cache=self.key_cache,
+                                             value_cache=self.value_cache,
+                                             slot_indices=slots)
+        if attn_type == AttentionType.ENCODER_ONLY:
+            # TODO(zzzwwjj): Deal with this `cum_seq_len` more elegantly.
+            cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
+            intermediate_output = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                head_num=self.num_heads,
+                input_layout="TND",
+                scale=self.scale,
+                sparse_mode=4,
+                atten_mask=attn_metadata.attn_mask,
+                pre_tockens=attn_metadata.max_query_len,
+                next_tockens=attn_metadata.max_query_len,
+                actual_seq_qlen=cum_seq_len,
+                actual_seq_kvlen=cum_seq_len,
+            )[0]
+        # V0-Style scheduler situation.
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+            intermediate_output = self._forward_prefill_no_cache(
+                query, key, value, attn_metadata, output, num_tokens)
+        elif attn_metadata.attn_state == \
+                AscendAttentionState.PrefillCacheHit:
+            intermediate_output = self._forward_prefill_cache_hit(
+                query, attn_metadata, output)
+        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            intermediate_output = self._forward_decode_only(
+                query, attn_metadata, output)
+        # Normal V1 situation.
         else:
-            if attn_metadata is None:
-                return output.view(num_tokens, self.hidden_size)
-            num_actual_tokens = attn_metadata.num_actual_tokens
-            assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
-            attn_type = self.attn_type
-            if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
-                raise NotImplementedError("Encoder/decoder cross-attention "
-                                          "are not implemented for "
-                                          "PallasAttentionBackendImpl")
-            # View q k v to BSH.
-            query = query.view(-1, self.num_heads, self.head_size)
-            key = key.view(-1, self.num_kv_heads, self.head_size)
-            value = value.view(-1, self.num_kv_heads, self.head_size)
-            # TODO: Remove this contiguous in the future.
-            value = value.contiguous()
-
-            if len(kv_cache) > 1:
-                if self.key_cache is None:
-                    self.key_cache, self.value_cache = kv_cache[0], kv_cache[1]
-                slots = attn_metadata.slot_mapping
-                torch_npu._npu_reshape_and_cache(
-                    key=key[:num_actual_tokens],
-                    value=value[:num_actual_tokens],
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    slot_indices=slots)
-            if attn_type == AttentionType.ENCODER_ONLY:
-                cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
-                attn_out = torch_npu.npu_fusion_attention(
-                    query,
-                    key,
-                    value,
-                    head_num=self.num_heads,
-                    input_layout="TND",
-                    scale=self.scale,
-                    sparse_mode=4,
-                    atten_mask=attn_metadata.attn_mask,
-                    pre_tockens=attn_metadata.max_query_len,
-                    next_tockens=attn_metadata.max_query_len,
-                    actual_seq_qlen=cum_seq_len,
-                    actual_seq_kvlen=cum_seq_len,
-                )
-                output = attn_out[0]
-            # V0-Style scheduler situation.
-            elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
-                output = self._forward_prefill_no_cache(
-                    query, key, value, attn_metadata, output, num_tokens)
-            elif attn_metadata.attn_state == \
-                    AscendAttentionState.PrefillCacheHit:
-                output = self._forward_prefill_cache_hit(
-                    query, attn_metadata, output)
-            elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                output = self._forward_decode_only(query, attn_metadata,
-                                                   output)
-            # Normal V1 situation.
-            else:
-                if torch.version.cann.startswith("8.3"):
-                    # npu_fused_infer_attention_score does not support cases
-                    # where query.shape[0] != attn_metadata.query_start_loc[-1].
-                    # Thus we need unpad it here.
-                    num_tokens = attn_metadata.query_start_loc[-1]
-                    query = query[:num_tokens]
-                output = self._forward_v1_style(query, attn_metadata, output)
-
-        # to make in-place change to the output tensor
-        if hasattr(layer, 'quant_method') and use_kv_cache_int8:
-            output = output.view(num_tokens, self.num_heads, self.head_size)
-            ori_output[:num_tokens, :, :] = output[:num_tokens, :, :]
-        return output.view(num_tokens, self.hidden_size)
-
-
-def unified_ascend_attention_with_output(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    output: torch.Tensor,
-    layer_name: str,
-) -> None:
-    wait_for_kv_layer_from_connector(layer_name)
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if isinstance(attn_metadata, dict):
-        attn_metadata = attn_metadata[layer_name]
-    self = forward_context.no_compile_layers[layer_name]
-    kv_cache = self.kv_cache[forward_context.virtual_engine]
-    self.impl.forward(self,
-                      query,
-                      key,
-                      value,
-                      kv_cache,
-                      attn_metadata,
-                      output,
-                      trace_flag=False)
-    maybe_save_kv_layer_to_connector(layer_name, kv_cache)
-    return
-
-
-def unified_attention_with_output_fake(
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    output: torch.Tensor,
-    layer_name: str,
-) -> None:
-    return
-
-
-direct_register_custom_op(
-    op_name="unified_ascend_attention_with_output",
-    op_func=unified_ascend_attention_with_output,
-    mutates_args=["output"],
-    fake_impl=unified_attention_with_output_fake,
-    dispatch_key="PrivateUse1",
-)
+            if torch.version.cann.startswith("8.3"):
+                # npu_fused_infer_attention_score does not support cases
+                # where query.shape[0] != attn_metadata.query_start_loc[-1].
+                # Thus we need unpad it here.
+                num_tokens = attn_metadata.query_start_loc[-1]
+                query = query[:num_tokens]
+            intermediate_output = self._forward_v1_style(
+                query, attn_metadata, output)
+
+        output[:num_tokens, :, :] = intermediate_output[:num_tokens, :, :]
+
+        return output
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index d8cf5251ea..e23c9130b6 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -237,9 +237,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "When enabling piecewise aclgraph, please make sure compilation_config.level == CompilationLevel.PIECEWISE and compilation_config.cudagraph_mode == CUDAGraphMode.PIECEWISE"
             compilation_config.set_splitting_ops_for_v1()
             compilation_config.use_inductor = False
-            compilation_config.splitting_ops.extend([
-                "vllm.unified_ascend_attention_with_output", "vllm.mla_forward"
-            ])
+            compilation_config.splitting_ops.extend(["vllm.mla_forward"])
             update_aclgraph_sizes(vllm_config)
         elif compilation_config.cudagraph_mode == CUDAGraphMode.FULL_DECODE_ONLY:
             logger.info(
@@ -383,6 +381,10 @@ def get_device_communicator_cls(cls) -> str:
     def is_pin_memory_available(cls):
         return True
 
+    @classmethod
+    def opaque_attention_op(cls) -> bool:
+        return True
+
     @classmethod
     def get_static_graph_wrapper_cls(cls) -> str:
         """
diff --git a/vllm_ascend/torchair/models/qwen2.py b/vllm_ascend/torchair/models/qwen2.py
index 6e4990d7b8..a5a198e040 100644
--- a/vllm_ascend/torchair/models/qwen2.py
+++ b/vllm_ascend/torchair/models/qwen2.py
@@ -125,7 +125,6 @@ def forward(
                                 v,
                                 kv_cache=kv_cache,
                                 attn_metadata=attn_metadata,
-                                trace_flag=False,
                                 **forward_kwargs)
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/vllm_ascend/torchair/models/qwen3_moe.py b/vllm_ascend/torchair/models/qwen3_moe.py
index 5302f4e7cf..87d927838a 100644
--- a/vllm_ascend/torchair/models/qwen3_moe.py
+++ b/vllm_ascend/torchair/models/qwen3_moe.py
@@ -251,7 +251,6 @@ def forward(
                                 v,
                                 kv_cache=kv_cache,
                                 attn_metadata=attn_metadata,
-                                trace_flag=False,
                                 **forward_kwargs)
         output, _ = self.o_proj(attn_output)
         return output
diff --git a/vllm_ascend/torchair/models/torchair_pangu_moe.py b/vllm_ascend/torchair/models/torchair_pangu_moe.py
index 195ffdedd7..7a0c9c0696 100644
--- a/vllm_ascend/torchair/models/torchair_pangu_moe.py
+++ b/vllm_ascend/torchair/models/torchair_pangu_moe.py
@@ -625,7 +625,7 @@ def forward(
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         if self.torchair_graph_enabled:
-            forward_kwargs = {'trace_flag': False}
+            forward_kwargs = {}
             output_shape = q.shape
             attn_output = torch.empty(output_shape,
                                       dtype=q.dtype,
diff --git a/vllm_ascend/torchair/torchair_attention.py b/vllm_ascend/torchair/torchair_attention.py
index 9f1b40e5e1..b229f45eb5 100644
--- a/vllm_ascend/torchair/torchair_attention.py
+++ b/vllm_ascend/torchair/torchair_attention.py
@@ -314,7 +314,6 @@ def forward(
         kv_cache: torch.Tensor,
         attn_metadata: AscendTorchairMetadata,
         output: Optional[torch.Tensor] = None,
-        trace_flag: bool = False,
     ) -> torch.Tensor:
         """Forward pass with Ascend attention.
         Args:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 0bfe0f847c..211cb4de9a 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -3406,7 +3406,6 @@ def initialize_aclgraph_capture(self) -> None:
         splitting_ops_contain_attention = (
             self.compilation_config.splitting_ops is not None
             and all(op in self.compilation_config.splitting_ops for op in [
-                "vllm.unified_ascend_attention_with_output",
                 "vllm.mla_forward",
             ]))
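
Note (illustrative, not part of the patch): with trace_flag removed, attention is exposed as an opaque platform op (opaque_attention_op() returns True and vllm.unified_ascend_attention_with_output is dropped from splitting_ops), and AscendAttentionBackendImpl.forward now expects the caller to hand in a preallocated output buffer. A minimal sketch of the new calling convention, assuming impl, layer, and attn_metadata (an AscendMetadata) are already constructed as in the unit tests above; names and shapes here are illustrative only:

    import torch

    num_tokens, num_heads, head_size = 10, 8, 64
    query = torch.randn(num_tokens, num_heads, head_size)
    key = torch.randn(num_tokens, num_heads, head_size)
    value = torch.randn(num_tokens, num_heads, head_size)
    kv_cache = torch.empty(2, 0, 0, num_heads, head_size)

    # forward() now asserts that `output` is provided and copies the result
    # into it (output[:num_tokens] = intermediate_output[:num_tokens]).
    output = torch.empty(num_tokens, num_heads, head_size, dtype=query.dtype)
    out = impl.forward(layer, query, key, value, kv_cache, attn_metadata,
                       output=output)

The new fused-output-quantization arguments (output_scale, output_block_scale) must stay None for now; passing either raises NotImplementedError.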