From 4088b47b0f47b4cecc8c1a221d26e25bd4feda6d Mon Sep 17 00:00:00 2001 From: Eric Shi Date: Tue, 18 Feb 2025 11:35:04 -0800 Subject: [PATCH] Improve error reporting in GEMM benchmark --- warp/examples/benchmarks/benchmark_gemm.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/warp/examples/benchmarks/benchmark_gemm.py b/warp/examples/benchmarks/benchmark_gemm.py index 43855bf32..ef92992cc 100644 --- a/warp/examples/benchmarks/benchmark_gemm.py +++ b/warp/examples/benchmarks/benchmark_gemm.py @@ -95,7 +95,11 @@ def benchmark_warp(A: wp.array, B: wp.array, config: List[int], warm_up: int, it # check output if warm_up > 0: - assert np.allclose(output.numpy(), A.numpy() @ B.numpy(), atol=1e-3, rtol=1e-3) + try: + np.testing.assert_allclose(output.numpy(), A.numpy() @ B.numpy(), atol=1e-3, rtol=1e-3) + except AssertionError as e: + print(f"Failed with {TILE_M=}, {TILE_N=}, {TILE_K=}, {BLOCK_DIM=}") + raise e # benchmark with wp.ScopedTimer("warp", print=False, synchronize=True, cuda_filter=wp.TIMING_KERNEL) as timer: