154 changes: 154 additions & 0 deletions lit_tests/kernel/wave/mlir_converter_e2e.py
@@ -13,6 +13,7 @@
from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.lang.wave_types import *
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.kernel.wave.utils.general_utils import run_test
@@ -135,3 +136,156 @@ def matrix_add(
# CHECK: arith.addf
# CHECK-NOT: wave.write
# CHECK: vector.maskedstore


@run_test
def test_matmul_water_e2e():
"""Test Water PassManager with matmul kernel and e2e execution."""
torch.manual_seed(0)

# Input sizes
M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
# Workgroup tile sizes
BLOCK_M = tkl.sym.BLOCK_M
BLOCK_N = tkl.sym.BLOCK_N
BLOCK_K = tkl.sym.BLOCK_K
# Address space (for GPU, shared(1) or global(0))
ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
dtype = tkl.f16

# Define constraints for matmul
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.TilingConstraint(K, BLOCK_K)]
constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2))]
constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2))]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, mma_type=MMAType.F32_32x32x8_F16)
]
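# With BLOCK_M = BLOCK_N = 64 and per-wave tiles of BLOCK/2 = 32, each
# workgroup gets a 2x2 wave layout: 4 waves x 64 threads = 256 threads.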

@tkw.wave(constraints)
def matmul(
a: tkl.Memory[M, K, ADDRESS_SPACE, dtype],
b: tkl.Memory[N, K, ADDRESS_SPACE, dtype],
c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
):
c_reg = tkl.Register[M, N, tkl.f32](0.0)

@tkw.iterate(K, init_args=[c_reg])
def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
a_reg = tkw.read(a)
b_reg = tkw.read(b)
acc = tkw.mma(a_reg, b_reg, acc)
return acc

tkw.write(repeat, c)
Comment on lines +167 to +183 (Contributor):
Can we use one from templates?
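
A sketch of what the reviewer's suggestion could look like. The module path,
helper name, and signature below are assumptions for illustration; the actual
templates API in wave_lang may differ:

# Hypothetical: reuse a ready-made GEMM template instead of hand-writing
# the constraints and @tkw.wave body above (path and signature assumed).
from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel

matmul, hyperparams = get_gemm_kernel(
    shape=(1024, 5120, 640),  # (M, N, K)
    mfma_variant=MMAType.F32_32x32x8_F16,
)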


m = 1024
n = 5120
k = 640
# Set parameters for compilation
subs: dict[str | IndexSymbol, Any] = {
ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
BLOCK_M: 64,
BLOCK_N: 64,
BLOCK_K: 32,
M: m,
N: n,
K: k,
}

options_mlir = WaveCompileOptions(
subs=subs,
compile_to_mlir=True,
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
enforce_locations=False,
print_mlir=True,
)

Comment on print_mlir=True (Contributor):
Do we need to print mlir?

options_mlir = set_default_run_config(options_mlir)

compiled_kernel = wave_compile(options_mlir, matmul)
trace = compiled_kernel.compiled_graph
constraints = matmul.constraints

# Emit Wave dialect MLIR
wave_dialect_mlir, diagnostics, _ = emit_wave_dialect(
trace, constraints, options_mlir
)

# Apply Water PassManager lowering
lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir)

print(lowered_mlir)
Comment on print(lowered_mlir) (Contributor):
Let's have FileCheck comments here so it doesn't look like forgotten debug output.
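
A sketch of what that could look like: FileCheck directives keyed to the
printed lowered module (illustrative only; they overlap with the fuller
CHECK block at the bottom of this test):

# CHECK: func.func @kernel
# CHECK: scf.for
# CHECK: amdgpu.mfma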


# Create test tensors
a_tensor = device_randn(m, k, dtype=torch.float16)
b_tensor = device_randn(n, k, dtype=torch.float16) # Note: transposed in matmul
c_tensor = device_zeros(m, n, dtype=torch.float32)

# Expected result (CPU computation)
expected = torch.matmul(a_tensor.float(), b_tensor.T.float())

Comment on "# Expected result (CPU computation)" (Contributor):
The code above creates tensors on device, why is this called CPU computation?

Comment on the expected-value computation (Contributor):
It is a bad idea to compute expected values with a higher precision than actual values.
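
A sketch of the matched-precision alternative the reviewer hints at (an
assumed reading of the comment, not a confirmed fix): keep the reference
product in f16 so its rounding tracks the kernel's f16 inputs, casting only
for the comparison against the kernel's f32 output.

# f16 operands, like the kernel's; cast the product to f32 for comparison.
expected = torch.matmul(a_tensor, b_tensor.T).to(torch.float32)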


# Test execution with lowered MLIR
options_e2e = WaveCompileOptions(
subs=subs,
canonicalize=True,
location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE),
enforce_locations=False,
override_mlir=lowered_mlir,
)
options_e2e = set_default_run_config(options_e2e)

compiled_e2e = wave_compile(options_e2e, matmul)

compiled_e2e(a_tensor, b_tensor, c_tensor)

assert_close(c_tensor, expected, rtol=1e-3, atol=1e-3)
Comment on the assert_close tolerances (Contributor):
1e-3 looks a bit too lax, do we really need it?
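
For reference, torch.testing.assert_close defaults to rtol=1.3e-6 and
atol=1e-5 for float32 tensors; those are likely too strict for f16 inputs
reduced over K=640, but a bound tighter than 1e-3 may well hold. A sketch,
with values that are assumptions to be tuned empirically:

assert_close(c_tensor, expected, rtol=1e-4, atol=1e-4)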



# CHECK-LABEL: test_matmul_water_e2e
# CHECK: module {
# CHECK-NOT: wave.normal_form
#
# Verify function signature with correct memref types.
# CHECK: func.func @kernel(%{{.*}}: memref<1024x640xf16, #gpu.address_space<global>>, %{{.*}}: memref<5120x640xf16, #gpu.address_space<global>>, %{{.*}}: memref<1024x5120xf32, #gpu.address_space<global>>)
#
# Verify shared memory allocations for A and B tiles.
# CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
# CHECK-DAG: memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
#
# Verify K-loop structure (20 iterations = 640/32).
# CHECK-NOT: wave.iterate
# CHECK: scf.for %{{.*}} = %c0 to %c20 step %c1 iter_args(%{{.*}} = %{{.*}}) -> (vector<16xf32>)
#
# Verify global memory loads inside the loop.
# CHECK-NOT: wave.read
# CHECK: vector.load %arg0[%{{.*}}, %{{.*}}] : memref<1024x640xf16, #gpu.address_space<global>>, vector<8xf16>
#
# Verify LDS barriers for synchronization.
# CHECK: amdgpu.lds_barrier
#
# Verify stores to shared memory.
# CHECK: vector.store %{{.*}}, %{{.*}} : memref<64x36xf16, #gpu.address_space<workgroup>>, vector<8xf16>
#
# Verify load from B matrix.
# CHECK: vector.load %arg1[%{{.*}}, %{{.*}}] : memref<5120x640xf16, #gpu.address_space<global>>, vector<8xf16>
# CHECK: vector.store %{{.*}}, %{{.*}} : memref<64x36xf16, #gpu.address_space<workgroup>>, vector<8xf16>
# CHECK: amdgpu.lds_barrier
#
# Verify loads from shared memory for MMA operands.
# CHECK: vector.load %{{.*}} : memref<64x36xf16, #gpu.address_space<workgroup>>, vector<4xf16>
#
# Verify MMA operations (4 mfma 32x32x8 ops per iteration).
# CHECK-NOT: wave.mma
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: amdgpu.mfma 32x32x8 %{{.*}} * %{{.*}} + %{{.*}} {{.*}} : vector<4xf16>, vector<4xf16>, vector<16xf32>
# CHECK: scf.yield %{{.*}} : vector<16xf32>
#
# Verify extract_strided_slice and stores for output (16 elements per thread).
# CHECK: vector.extract_strided_slice %{{.*}} {offsets = [0], sizes = [1], strides = [1]} : vector<16xf32> to vector<1xf32>
# CHECK-NOT: wave.write
# CHECK: vector.store %{{.*}}, %arg2[%{{.*}}, %{{.*}}] : memref<1024x5120xf32, #gpu.address_space<global>>, vector<1xf32>