tests/ut/attention/test_attention_v1.py: 23 additions & 105 deletions
@@ -242,27 +242,6 @@ def setUp(self):
             attn_type=self.attention_type.DECODER,
             kv_sharing_target_layer_name=None)

-    @patch('torch.ops.vllm.unified_ascend_attention_with_output')
-    def test_forward_trace_flag_true(self, mock_unified_attention):
-        """Test forward pass when trace_flag is True"""
-        query = torch.randn(10, 8 * 64)
-        key = torch.randn(10, 8 * 64)
-        value = torch.randn(10, 8 * 64)
-        kv_cache = torch.empty(2, 0, 0, 8, 64)
-        metadata = self.attn_metadata
-        layer = self.layer
-
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=True)
-
-        mock_unified_attention.assert_called_once()
-        assert output.shape == (10, 8 * 64)
-
     @patch('torch_npu._npu_paged_attention_splitfuse')
     def test_forward_with_quant_method(self, mock_paged_attention):
         """Test forward pass when layer has quant_method"""
@@ -284,13 +263,8 @@ def test_forward_with_quant_method(self, mock_paged_attention):
         layer.quant_method = MagicMock()
         layer.quant_method.apply.return_value = ret_value

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)

         layer.quant_method.apply.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -303,13 +277,7 @@ def test_forward_no_attn_metadata(self):
         kv_cache = torch.empty(2, 0, 0, 8, 64)
         layer = self.layer_no_quant

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   None,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache, None)

         assert output.shape == (10, 8 * 64)

@@ -331,13 +299,8 @@ def test_forward_prefill_no_cache(self, mock_flash_attention,
         layer = self.layer_no_quant
         # layer.quant_method.apply.return_value = metadata
         print(self.layer_no_quant._v_scale_float)
-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)

         mock_reshape_cache.assert_called_once()
         mock_flash_attention.assert_called_once()
@@ -362,13 +325,8 @@ def test_forward_prefill_cache_hit(self, mock_flash_attention_qlens,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)

         mock_flash_attention_qlens.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -394,13 +352,8 @@ def test_forward_decode_only(self, mock_paged_attention,

         mock_get_forward_context.return_value = MagicMock(capturing=False)

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)

         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -501,13 +454,8 @@ def test_paged_attention_with_existing_workspace(
         mock_get_forward_context.return_value = MagicMock(capturing=True)
         mock_get_graph_params.return_value = graph_params

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)

         mock_paged_attention.assert_called_once()
         self.assertEqual(len(graph_params.handles[num_tokens]), 0)
@@ -530,13 +478,8 @@ def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
         layer = self.layer_no_quant
         mock_fused_infer_attention_score.return_value = (torch.ones(10, 8,
                                                                     64), 1)
-        output = self.impl_swa.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_swa.forward(layer, query, key, value, kv_cache,
+                                       metadata)
         print(output.shape)
         mock_fused_infer_attention_score.assert_called_once()
         assert output.shape == (10, 8 * 64)
@@ -566,13 +509,8 @@ def test_forward_decode_only_swa_seq_len_mismatch(

         mock_get_forward_context.return_value = MagicMock(capturing=False)

-        output = self.impl_swa.forward(self.layer_no_quant,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_swa.forward(self.layer_no_quant, query, key, value,
+                                       kv_cache, metadata)

         mock_paged_attention.assert_called_once()
         mock_fused_infer_attention_score.assert_not_called()
@@ -601,13 +539,8 @@ def test_forward_head_size_192(self, mock_vanilla_prefill,
         layer = self.layer_no_quant
         mock_vanilla_prefill.return_value = MagicMock()

-        output = self.impl_192.forward(layer,
-                                       query,
-                                       key,
-                                       value,
-                                       kv_cache,
-                                       metadata,
-                                       trace_flag=False)
+        output = self.impl_192.forward(layer, query, key, value, kv_cache,
+                                       metadata)

         mock_vanilla_prefill.assert_called_once()
         assert output.shape == (10, 8 * 192)
@@ -630,13 +563,8 @@ def test_forward_normal_v1_situation(self, mock_paged_attention,
         metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
         layer = self.layer_no_quant

-        output = self.impl.forward(layer,
-                                   query,
-                                   key,
-                                   value,
-                                   kv_cache,
-                                   metadata,
-                                   trace_flag=False)
+        output = self.impl.forward(layer, query, key, value, kv_cache,
+                                   metadata)

         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
Expand All @@ -663,13 +591,8 @@ def test_forward_310p_device(self, mock_is_310p, mock_paged_attention,
layer = self.layer_no_quant

mock_npu_format_cast.return_value = metadata.attn_mask
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
output = self.impl.forward(layer, query, key, value, kv_cache,
metadata)

mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@@ -690,10 +613,5 @@ def test_forward_raise_error(self, mock_paged_attention):
         layer = self.layer_no_quant

         with self.assertRaises(NotImplementedError):
-            self.impl_error.forward(layer,
-                                    query,
-                                    key,
-                                    value,
-                                    kv_cache,
-                                    metadata,
-                                    trace_flag=False)
+            self.impl_error.forward(layer, query, key, value, kv_cache,
+                                    metadata)