[wave2water] E2E execution of matmul kernel via water middle-end #672
base: main
@@ -13,6 +13,7 @@
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.lang.wave_types import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.kernel.wave.utils.general_utils import run_test
@@ -135,3 +136,156 @@ def matrix_add(
# CHECK: arith.addf
# CHECK-NOT: wave.write
# CHECK: vector.maskedstore

@run_test
def test_matmul_water_e2e():
    """Test Water PassManager with matmul kernel and e2e execution."""
    torch.manual_seed(0)

    # Input sizes
    M = tkl.sym.M
    N = tkl.sym.N
    K = tkl.sym.K
    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = tkl.sym.BLOCK_K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    dtype = tkl.f16

    # Define constraints for matmul
    constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]
    constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2))]
    constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2))]
    constraints += [
        tkw.HardwareConstraint(threads_per_wave=64, mma_type=MMAType.F32_32x32x8_F16)
    ]

    @tkw.wave(constraints)
    def matmul(
        a: tkl.Memory[M, K, ADDRESS_SPACE, dtype],
        b: tkl.Memory[N, K, ADDRESS_SPACE, dtype],
        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[M, N, tkl.f32](0.0)

        @tkw.iterate(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
            a_reg = tkw.read(a)
            b_reg = tkw.read(b)
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        tkw.write(repeat, c)

    m = 1024
    n = 5120
    k = 640
    # Set parameters for compilation
    subs: dict[str | IndexSymbol, Any] = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        BLOCK_M: 64,
        BLOCK_N: 64,
        BLOCK_K: 32,
        M: m,
        N: n,
        K: k,
    }

    options_mlir = WaveCompileOptions(
        subs=subs,
        compile_to_mlir=True,
        location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
        enforce_locations=False,
        print_mlir=True,

Contributor: Do we need to print mlir?

    )
    options_mlir = set_default_run_config(options_mlir)

    compiled_kernel = wave_compile(options_mlir, matmul)
    trace = compiled_kernel.compiled_graph
    constraints = matmul.constraints

    # Emit Wave dialect MLIR
    wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(
        trace, constraints, options_mlir
    )

    # Apply Water PassManager lowering
    lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)

    print(lowered_mlir)

Contributor: Let's have FileCheck comments here so it doesn't look like forgotten debug output.

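A sketch of what that could look like, anchoring directives at the print site (illustrative only; the patterns mirror the module-level CHECK block at the end of this file, and the exact ops to match depend on the lowering):

    print(lowered_mlir)
    # FileCheck scans the source file for CHECK lines wherever they appear, so
    # directives placed here make clear the printed IR is consumed by the test.
    # CHECK: module {
    # CHECK-NOT: wave.normal_form
    # CHECK: scf.for
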
    # Create test tensors
    a_tensor = device_randn(m, k, dtype=torch.float16)
    b_tensor = device_randn(n, k, dtype=torch.float16)  # Note: transposed in matmul
    c_tensor = device_zeros(m, n, dtype=torch.float32)

    # Expected result (CPU computation)

Contributor: The code above creates tensors on device, why is this called CPU computation?

    expected = torch.matmul(a_tensor.float(), b_tensor.T.float())

Contributor: It is a bad idea to compute expected values with a higher precision than actual values.

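One reading of this comment, as a minimal sketch: keep the operands at the kernel's f16 input precision and widen only the result for comparison (assumes the device backend supports an f16 matmul):

    # Reference computed from the same f16 operands the kernel consumes; the
    # widening to f32 happens only for the comparison against c_tensor.
    expected = torch.matmul(a_tensor, b_tensor.T).to(torch.float32)
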
    # Test execution with lowered MLIR
    options_e2e = WaveCompileOptions(
        subs=subs,
        canonicalize=True,
        location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
        enforce_locations=False,
        override_mlir=lowered_mlir,
    )
    options_e2e = set_default_run_config(options_e2e)

    compiled_e2e = wave_compile(options_e2e, matmul)

    compiled_e2e(a_tensor, b_tensor, c_tensor)

    assert_close(c_tensor, expected, rtol=1e-3, atol=1e-3)

Contributor: 1e-3 looks a bit too lax, do we really need it?

# CHECK-LABEL: test_matmul_water_e2e
# CHECK: module {
# CHECK-NOT: wave.normal_form
#
# Verify function signature with correct memref types.
# CHECK: func.func @kernel(%{{.*}}: memref<1024x640xf16, #gpu.address_space<global>>, %{{.*}}: memref<5120x640xf16, #gpu.address_space<global>>, %{{.*}}: memref<1024x5120xf32, #gpu.address_space<global>>)
#
# Verify shared memory allocations for A and B tiles.
# CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
#
# Verify K-loop structure (20 iterations = 640/32).
# CHECK-NOT: wave.iterate
# CHECK: scf.for %{{.*}} = %c0 to %c20 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (vector<16xf32>)
#
# Verify global memory loads inside the loop.
# CHECK-NOT: wave.read
# CHECK: vector.load %arg0[%{{.*}}, %{{.*}}] : memref<1024x640xf16, #gpu.address_space<global>>, vector<8xf16>
#
# Verify LDS barriers for synchronization.
# CHECK: amdgpu.lds_barrier
#
# Verify stores to shared memory.
# CHECK: vector.store %{{.*}}, %{{.*}} : memref<64x36xf16, #gpu.address_space<workgroup>>, vector<8xf16>
#
# Verify load from B matrix.
# CHECK: vector.load %arg1[%{{.*}}, %{{.*}}] : memref<5120x640xf16, #gpu.address_space<global>>, vector<8xf16>
# CHECK: vector.store %{{.*}}, %{{.*}} : memref<64x36xf16, #gpu.address_space<workgroup>>, vector<8xf16>
# CHECK: amdgpu.lds_barrier
#
# Verify loads from shared memory for MMA operands.
# CHECK: vector.load %{{.*}} : memref<64x36xf16, #gpu.address_space<workgroup>>, vector<4xf16>
#
# Verify MMA operations (4 mfma 32x32x8 ops per iteration).
# CHECK-NOT: wave.mma
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: scf.yield %{{.*}} : vector<16xf32>
#
# Verify extract_strided_slice and stores for output (16 elements per thread).
# CHECK: vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [1], strides = [1]} : vector<16xf32> to vector<1xf32>
# CHECK-NOT: wave.write
# CHECK: vector.store %{{.*}}, %arg2[%{{.*}}, %{{.*}}] : memref<1024x5120xf32, #gpu.address_space<global>>, vector<1xf32>

Contributor: Can we use one from templates?
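
If the suggestion is to reuse a kernel from the templates package rather than defining the matmul inline, a hypothetical sketch follows; the module path, helper name, and signature are assumptions for illustration and should be checked against the actual templates before use:

    # Hypothetical: assumes a gemm template exists and returns the kernel
    # together with its preferred hyperparameters.
    from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel  # assumed

    matmul, hyperparams = get_gemm_kernel(
        shape=(1024, 5120, 640),               # (M, N, K), matching m, n, k above
        mfma_variant=MMAType.F32_32x32x8_F16,  # same MMA type as the inline kernel
    )
    # The subs dict could then be assembled from hyperparams instead of by hand.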