diff --git a/tests/models/quantization/test_bitsandbytes.py b/tests/models/quantization/test_bitsandbytes.py
index 24220978534c..808ffd89f47f 100644
--- a/tests/models/quantization/test_bitsandbytes.py
+++ b/tests/models/quantization/test_bitsandbytes.py
@@ -10,13 +10,14 @@
 from tests.quantization.utils import is_quant_method_supported
 from vllm.platforms import current_platform
+from vllm.platforms.rocm import on_gfx9
 
 from ...utils import compare_two_settings, multi_gpu_test
 from ..utils import check_embeddings_close, check_logprobs_close
 
 pytestmark = pytest.mark.skipif(
-    current_platform.is_rocm(),
-    reason="bitsandbytes quantization not supported on ROCm (CUDA-only kernels)",
+    current_platform.is_rocm() and on_gfx9(),
+    reason="bitsandbytes quantization not supported on Instinct (warp size 64 limitation)",
 )
 
 models_4bit_to_test = [
diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py
index 1abd6300036d..376e61451d00 100644
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -200,6 +200,9 @@ class RocmPlatform(Platform):
         "petit_nvfp4",
         "torchao",
     ]
+    # bitsandbytes quantization not supported on Instinct (warp size 64 limitation)
+    if not on_gfx9():
+        supported_quantization += ["bitsandbytes"]
 
     @classmethod
     def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
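
For reference, below is a minimal standalone sketch of the kind of device check on_gfx9() performs, assuming a ROCm build of PyTorch that exposes gcnArchName on device properties. The helper name on_gfx9_sketch is illustrative; vLLM's actual helper lives in vllm/platforms/rocm.py and may differ in detail.

# Sketch only: approximates the gfx9 (Instinct, wavefront size 64) detection
# that gates bitsandbytes support in the patch above.
import torch


def on_gfx9_sketch() -> bool:
    if not torch.cuda.is_available():
        return False
    # On ROCm builds of PyTorch, device properties expose gcnArchName,
    # e.g. "gfx90a" (MI200) or "gfx942" (MI300), possibly with feature
    # suffixes such as ":sramecc+:xnack-". Use getattr so the sketch
    # degrades gracefully on non-ROCm builds.
    arch = getattr(torch.cuda.get_device_properties(0), "gcnArchName", "")
    return arch.startswith("gfx9")


if __name__ == "__main__":
    print("gfx9 device detected:", on_gfx9_sketch())

With the patch applied, both sides of the change stay consistent: on gfx9 Instinct GPUs the bitsandbytes tests are skipped and "bitsandbytes" is absent from RocmPlatform.supported_quantization, while on other ROCm GPUs (warp size 32) the tests run and the quantization method is advertised.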