39 changes: 25 additions & 14 deletions tests/others/test_attention_backends.py
@@ -11,8 +11,7 @@
pytest tests/others/test_attention_backends.py
```

Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).

Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
@@ -24,6 +23,8 @@
import pytest
import torch

from ..testing_utils import numpy_cosine_similarity_distance


pytestmark = pytest.mark.skipif(
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
@@ -36,51 +37,61 @@
FORWARD_CASES = [
(
"flash_hub",
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
1e-4
),
(
"_flash_3_hub",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
1e-4
),
(
"native",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
),
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
1e-4
),
(
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
5e-4
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
1e-4
)
]

COMPILE_CASES = [
(
"flash_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True
True,
1e-4
),
(
"_flash_3_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True,
1e-4
),
(
"native",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
True,
1e-4
),
(
"_native_cudnn",
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
5e-4,
),
(
"aiter",
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
True,
1e-4
)
]
# fmt: on
@@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str):
return False


def _check_if_slices_match(output, expected_slice):
def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
img = output.images.detach().cpu()
generated_slice = img.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff


@pytest.fixture(scope="session")
@@ -126,23 +137,23 @@ def pipe(device):
return pipe


@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice, expected_diff):
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

modified_pipe = out[0]
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
_check_if_slices_match(out, expected_slice, expected_diff)


@pytest.mark.parametrize(
"backend_name,expected_slice,error_on_recompile",
"backend_name,expected_slice,error_on_recompile,expected_diff",
COMPILE_CASES,
ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")

@@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
):
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))

_check_if_slices_match(out, expected_slice)
_check_if_slices_match(out, expected_slice, expected_diff)
4 changes: 4 additions & 0 deletions tests/testing_utils.py
@@ -131,6 +131,10 @@ def torch_all_close(a, b, *args, **kwargs):


def numpy_cosine_similarity_distance(a, b):
if isinstance(a, torch.Tensor):
a = a.detach().cpu().float().numpy()
if isinstance(b, torch.Tensor):
b = b.detach().cpu().float().numpy()
similarity = np.dot(a, b) / (norm(a) * norm(b))
distance = 1.0 - similarity.mean()

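For reference, below is a minimal, self-contained sketch of the comparison these changes introduce: the `numpy_cosine_similarity_distance` helper now accepts torch tensors directly, and `_check_if_slices_match` asserts that the cosine distance between the generated and expected slices stays under a per-backend tolerance. The slice values and the 1e-4 tolerance are taken from the `flash_hub` case above; the perturbation used to stand in for backend-level numerical noise is illustrative only, not part of the tests.

```python
import numpy as np
import torch
from numpy.linalg import norm


def numpy_cosine_similarity_distance(a, b):
    # Same logic as the updated helper in tests/testing_utils.py:
    # torch tensors are moved to CPU and converted to float32 numpy arrays first.
    if isinstance(a, torch.Tensor):
        a = a.detach().cpu().float().numpy()
    if isinstance(b, torch.Tensor):
        b = b.detach().cpu().float().numpy()
    similarity = np.dot(a, b) / (norm(a) * norm(b))
    return 1.0 - similarity.mean()


# First eight values of the "flash_hub" expected slice from FORWARD_CASES.
expected = torch.tensor(
    [0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016], dtype=torch.bfloat16
)
# A slightly perturbed "generated" slice standing in for backend-level numerical noise.
generated = expected + 1e-3

# Mirrors the assertion in _check_if_slices_match with the 1e-4 tolerance used for flash_hub.
assert numpy_cosine_similarity_distance(generated, expected) < 1e-4
```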