@@ -137,9 +137,10 @@ def check_type_supported(dtype, device):
         pytest.xfail("float64 not supported on current xpu hardware")


-def check_threads_supported(num_warps, threads_per_warp):
-    device = triton.runtime.driver.active.get_current_device()
-    props = triton.runtime.driver.active.utils.get_device_properties(device)
+def check_threads_supported(num_warps, threads_per_warp, device):
+    if device != "xpu":
+        return
+    props = triton.runtime.driver.active.utils.get_device_properties(triton.runtime.driver.active.get_current_device())
     if threads_per_warp not in props['sub_group_sizes']:
         pytest.xfail('unsupported warp size')
     if threads_per_warp * num_warps > props['max_work_group_size']:
@@ -2366,7 +2367,7 @@ def get_reduced_dtype(dtype_str, op):
                          [(64, 16), (4, THREADS_PER_WARP)] if is_xpu() else [(4, THREADS_PER_WARP)])
 def test_reduce1d(op, dtype_str, shape, num_ctas, num_warps, threads_per_warp, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
-    check_threads_supported(num_warps, threads_per_warp)
+    check_threads_supported(num_warps, threads_per_warp, device)

     # triton kernel
     @triton.jit
@@ -2475,7 +2476,7 @@ def kernel(X, Z, BLOCK: tl.constexpr):
                          [(64, 16), (4, THREADS_PER_WARP)] if is_xpu() else [(4, THREADS_PER_WARP)])
 def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, num_warps, threads_per_warp, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
-    check_threads_supported(num_warps, threads_per_warp)
+    check_threads_supported(num_warps, threads_per_warp, device)

     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
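A minimal, self-contained sketch of the behavior the reworked helper is meant to have, for readers skimming the diff. The `_sketch` name, the hard-coded property values, and the second xfail message are assumptions for illustration; only the keys `'sub_group_sizes'` and `'max_work_group_size'` and the early return on non-XPU devices come from the change itself.

```python
import pytest


def check_threads_supported_sketch(num_warps, threads_per_warp, device):
    """Illustrative stand-in for the updated check_threads_supported."""
    if device != "xpu":
        return  # non-XPU backends are not constrained by XPU sub-group sizes
    # Stand-in for triton.runtime.driver.active.utils.get_device_properties(...);
    # the concrete values below are assumptions, not queried from hardware.
    props = {'sub_group_sizes': [8, 16, 32], 'max_work_group_size': 1024}
    if threads_per_warp not in props['sub_group_sizes']:
        pytest.xfail('unsupported warp size')
    if threads_per_warp * num_warps > props['max_work_group_size']:
        pytest.xfail('unsupported thread count')  # exact message not shown in this hunk
```

With the `(num_warps, threads_per_warp) = (64, 16)` case the XPU parametrization adds, the requested thread count is 64 * 16 = 1024, so devices whose `max_work_group_size` is below 1024 xfail cleanly instead of failing the test.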