Skip to content

Commit 5c00f43

Browse files
[TEST] Only check_threads_supported on XPU (#3736)
7d24ef4 accidentally reduced the pass rate on interpreter mode, as `get_current_device` returns xpu. Pass rate: 89.76% -> 89.99% Signed-off-by: Whitney Tsang <[email protected]>
1 parent ac6492f commit 5c00f43

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

python/test/unit/language/test_core.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,10 @@ def check_type_supported(dtype, device):
137137
pytest.xfail("float64 not supported on current xpu hardware")
138138

139139

140-
def check_threads_supported(num_warps, threads_per_warp):
141-
device = triton.runtime.driver.active.get_current_device()
142-
props = triton.runtime.driver.active.utils.get_device_properties(device)
140+
def check_threads_supported(num_warps, threads_per_warp, device):
141+
if device != "xpu":
142+
return
143+
props = triton.runtime.driver.active.utils.get_device_properties(triton.runtime.driver.active.get_current_device())
143144
if threads_per_warp not in props['sub_group_sizes']:
144145
pytest.xfail('unsupported warp size')
145146
if threads_per_warp * num_warps > props['max_work_group_size']:
@@ -2366,7 +2367,7 @@ def get_reduced_dtype(dtype_str, op):
23662367
[(64, 16), (4, THREADS_PER_WARP)] if is_xpu() else [(4, THREADS_PER_WARP)])
23672368
def test_reduce1d(op, dtype_str, shape, num_ctas, num_warps, threads_per_warp, device):
23682369
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
2369-
check_threads_supported(num_warps, threads_per_warp)
2370+
check_threads_supported(num_warps, threads_per_warp, device)
23702371

23712372
# triton kernel
23722373
@triton.jit
@@ -2475,7 +2476,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
24752476
[(64, 16), (4, THREADS_PER_WARP)] if is_xpu() else [(4, THREADS_PER_WARP)])
24762477
def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, num_warps, threads_per_warp, device):
24772478
check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested
2478-
check_threads_supported(num_warps, threads_per_warp)
2479+
check_threads_supported(num_warps, threads_per_warp, device)
24792480

24802481
@triton.jit
24812482
def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,

0 commit comments

Comments
 (0)