diff --git a/lit_tests/kernel/wave/mlir_converter.py b/lit_tests/kernel/wave/mlir_converter.py index efa64eb582..f622781282 100644 --- a/lit_tests/kernel/wave/mlir_converter.py +++ b/lit_tests/kernel/wave/mlir_converter.py @@ -19,6 +19,7 @@ from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect from wave_lang.kernel.wave.utils.run_utils import set_default_run_config from wave_lang.kernel.wave.utils.general_utils import run_test +from wave_lang.kernel.wave.water import apply_water_middle_end_passes from wave_lang.support.location_config import ( LocationCaptureConfig, LocationCaptureLevel, @@ -231,9 +232,13 @@ def mlir_converter_matrix_add(): len(diagnostics) == 0 ), "dialect emission should create valid IR, therefore diagnostics should be empty" - # Print to stdout for FileCheck + # Print to stdout for FileCheck. print(mlir_output) + # Apply Water middle-end pipeline. + lowered_mlir = apply_water_middle_end_passes(mlir_output) + print(lowered_mlir) + # CHECK-LABEL: mlir_converter_matrix_add # CHECK: module # CHECK-NEXT: func.func @kernel(%[[ARG0:.*]]: !wave.tensor<[@M, @N] of f16, >, %[[ARG1:.*]]: !wave.tensor<[@M, @N] of f16, >, %[[ARG2:.*]]: !wave.tensor<[@M, @N] of f32, >) @@ -294,6 +299,17 @@ def mlir_converter_matrix_add(): # CHECK: return + # Water lowered output. + # CHECK: module { + # CHECK: func.func @kernel( + # CHECK-NOT: wave.read + # CHECK: vector.maskedload + # CHECK: vector.maskedload + # CHECK-NOT: wave.add + # CHECK: arith.addf + # CHECK-NOT: wave.write + # CHECK: vector.maskedstore + @run_test def mlir_converter_matmul(): @@ -396,7 +412,7 @@ def pipeline(root: OpHandle): len(diagnostics) == 0 ), "dialect emission should create valid IR, therefore diagnostics should be empty" - # Print to stdout for FileCheck + # Print to stdout for FileCheck. # CHECK-LABEL: mlir_converter_matmul print(pipeline_asm) # CHECK: module @@ -512,6 +528,20 @@ def pipeline(root: OpHandle): # CHECK-NEXT: wave.write %[[SLICE_15]], %[[ARG2]] # CHECK-NEXT: return + # Apply Water middle-end pipeline. + lowered_mlir = apply_water_middle_end_passes(mlir_output) + print(lowered_mlir) + + # Water lowered output. + # CHECK: module { + # CHECK: func.func @kernel( + # CHECK: memref.alloc() : memref<9216xi8, #gpu.address_space> + # CHECK: memref.view + # CHECK-NOT: wave.iterate + # CHECK: scf.for + # CHECK-NOT: wave.mma + # CHECK: amdgpu.mfma 32x32x8 + @run_test def mlir_converter_mixed_memory_spaces(): diff --git a/lit_tests/kernel/wave/mlir_converter_e2e.py b/lit_tests/kernel/wave/mlir_converter_e2e.py deleted file mode 100644 index 82788e6fca..0000000000 --- a/lit_tests/kernel/wave/mlir_converter_e2e.py +++ /dev/null @@ -1,137 +0,0 @@ -# REQUIRES: water -# RUN: python %s | FileCheck %s - -import torch -from torch.testing import assert_close -from typing import Any -import sympy - -from wave_lang.kernel._support.indexing import IndexSymbol -import wave_lang.kernel.wave as wave -import wave_lang.kernel.lang as tkl -import wave_lang.kernel.wave as tkw -from wave_lang.kernel.lang.global_symbols import * -from wave_lang.kernel.lang.wave_types import * -from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile -from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect -from wave_lang.kernel.wave.utils.run_utils import set_default_run_config -from wave_lang.kernel.wave.utils.general_utils import run_test -from wave_lang.kernel.wave.utils.torch_utils import device_randn, device_zeros -from wave_lang.kernel.wave.water import apply_water_middle_end_passes -from wave_lang.support.location_config import ( - LocationCaptureConfig, - LocationCaptureLevel, -) - - -@run_test -def test_matrix_add_water_e2e(): - """Test Water PassManager with Wave MLIR dialect generation and e2e execution.""" - torch.manual_seed(0) - - # Simple matrix addition kernel - M = tkl.sym.M - N = tkl.sym.N - BLOCK_M = tkl.sym.BLOCK_M - BLOCK_N = tkl.sym.BLOCK_N - ADDRESS_SPACE_A = tkl.sym.ADDRESS_SPACE_A - ADDRESS_SPACE_B = tkl.sym.ADDRESS_SPACE_B - ADDRESS_SPACE_C = tkl.sym.ADDRESS_SPACE_C - - # Define constraints for the kernel - constraints = [ - tkw.WorkgroupConstraint(M, BLOCK_M, 0), - tkw.WorkgroupConstraint(N, BLOCK_N, 1), - tkw.WaveConstraint(M, sympy.floor(BLOCK_M / 2)), - tkw.WaveConstraint(N, sympy.floor(BLOCK_N / 2)), - tkw.HardwareConstraint( - threads_per_wave=64, vector_shapes={M: BLOCK_M, N: BLOCK_N} - ), - ] - - @wave.wave(constraints) - def matrix_add( - a: Memory[M, N, ADDRESS_SPACE_A, tkl.f16], - b: Memory[M, N, ADDRESS_SPACE_B, tkl.f16], - c: Memory[M, N, ADDRESS_SPACE_C, tkl.f16], - ): - # Load values from memory into registers - a_reg = wave.read(a) - b_reg = wave.read(b) - - # Compute the sum - c_reg = a_reg + b_reg - - # Write results back to memory - wave.write(c_reg, c) - - # Set parameters for compilation - subs: dict[str | IndexSymbol, Any] = { - ADDRESS_SPACE_A: GLOBAL_ADDRESS_SPACE, - ADDRESS_SPACE_B: GLOBAL_ADDRESS_SPACE, - ADDRESS_SPACE_C: GLOBAL_ADDRESS_SPACE, - BLOCK_M: 64, - BLOCK_N: 64, - M: 128, - N: 128, - } - - options_mlir = WaveCompileOptions( - subs=subs, - compile_to_mlir=True, - location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE), - enforce_locations=False, - ) - options_mlir = set_default_run_config(options_mlir) - - compiled_kernel = wave_compile(options_mlir, matrix_add) - trace = compiled_kernel.compiled_graph - constraints = matrix_add.constraints - - # Emit Wave dialect MLIR - wave_dialect_mlir, diagnostics, _ = emit_wave_dialect( - trace, constraints, options_mlir - ) - - # Apply Water PassManager lowering - lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir) - - print(lowered_mlir) - - # Create test tensors - shape = (128, 128) - a_tensor = device_randn(*shape, dtype=torch.float16) - b_tensor = device_randn(*shape, dtype=torch.float16) - c_tensor = device_zeros(*shape, dtype=torch.float16) - - # Expected result (CPU computation) - expected = a_tensor + b_tensor - - # Test execution with lowered MLIR - options_e2e = WaveCompileOptions( - subs=subs, - canonicalize=True, - location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE), - enforce_locations=False, - override_mlir=lowered_mlir, - ) - options_e2e = set_default_run_config(options_e2e) - - compiled_e2e = wave_compile(options_e2e, matrix_add) - - compiled_e2e(a_tensor, b_tensor, c_tensor) - - assert_close(c_tensor, expected, rtol=1e-4, atol=1e-4) - - -# CHECK-LABEL: test_matrix_add_water_e2e -# CHECK: module -# CHECK-NOT: wave.normal_form -# CHECK: func.func @kernel( -# CHECK-NOT: wave.read -# CHECK: vector.maskedload -# CHECK: vector.maskedload -# CHECK-NOT: wave.add -# CHECK: arith.addf -# CHECK-NOT: wave.write -# CHECK: vector.maskedstore diff --git a/tests/kernel/wave/water_e2e_test.py b/tests/kernel/wave/water_e2e_test.py new file mode 100644 index 0000000000..3292c22ac0 --- /dev/null +++ b/tests/kernel/wave/water_e2e_test.py @@ -0,0 +1,88 @@ +"""End-to-end tests for Water middle-end pipeline.""" + +import torch +from torch.testing import assert_close + +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import MMAType +from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect +from wave_lang.kernel.wave.utils.run_utils import set_default_run_config +from wave_lang.kernel.wave.utils.torch_utils import device_randn, device_zeros +from wave_lang.kernel.wave.water import apply_water_middle_end_passes +from wave_lang.support.location_config import ( + LocationCaptureConfig, + LocationCaptureLevel, +) + +from tests.kernel.common.utils import require_e2e, require_water_and_ee + + +def _run_matmul_water_e2e(minimize_shared_allocs: bool): + """Test Water PassManager with matmul kernel and e2e execution.""" + from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel + + m = 1024 + n = 5120 + k = 640 + + gemm, hyperparams, _ = get_gemm_kernel( + shape=(m, n, k), + dynamic_dims=False, + mfma_variant=MMAType.F32_32x32x8_F16, + block_shape=(64, 64, 32), + waves_per_block=(2, 2), + ) + + options_mlir = WaveCompileOptions( + subs=hyperparams, + compile_to_mlir=True, + location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE), + enforce_locations=False, + minimize_shared_allocs=minimize_shared_allocs, + ) + options_mlir = set_default_run_config(options_mlir) + + compiled_kernel = wave_compile(options_mlir, gemm) + trace = compiled_kernel.compiled_graph + constraints = gemm.constraints + + wave_dialect_mlir, diagnostics, _ = emit_wave_dialect( + trace, constraints, options_mlir + ) + + lowered_mlir = apply_water_middle_end_passes(wave_dialect_mlir) + + a_tensor = device_randn(m, k, dtype=torch.float16) + b_tensor = device_randn(n, k, dtype=torch.float16) + c_tensor = device_zeros(m, n, dtype=torch.float32) + + expected = torch.matmul(a_tensor, b_tensor.T).float() + + options_e2e = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + location_capture_config=LocationCaptureConfig(level=LocationCaptureLevel.NONE), + enforce_locations=False, + override_mlir=lowered_mlir, + minimize_shared_allocs=minimize_shared_allocs, + ) + options_e2e = set_default_run_config(options_e2e) + + compiled_e2e = wave_compile(options_e2e, gemm) + compiled_e2e(a_tensor, b_tensor, c_tensor) + + assert_close(c_tensor, expected, rtol=1e-3, atol=1e-3) + + +@require_e2e +@require_water_and_ee +def test_matmul_water_e2e(): + """Test matmul with separate shared memory allocations.""" + _run_matmul_water_e2e(minimize_shared_allocs=False) + + +@require_e2e +@require_water_and_ee +def test_matmul_water_e2e_minimize_shared_allocs(): + """Test matmul with minimized shared memory allocations (parent allocations).""" + _run_matmul_water_e2e(minimize_shared_allocs=True) diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 4fee351d2b..b9b277e469 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -477,8 +477,11 @@ def apply_water_middle_end_passes(mlir_text: str) -> str: This function applies the following passes: - water-wave-detect-normal-forms - - water-wave-propagate-elements-per-thread (nested in normalform.module) - - lower-wave-to-mlir (nested in normalform.module) + - (nested in normalform.module) + - water-wave-propagate-elements-per-thread + - water-wave-resolve-distributed-allocations + - water-wave-detect-normal-forms + - lower-wave-to-mlir - canonicalize - cse @@ -500,7 +503,12 @@ def apply_water_middle_end_passes(mlir_text: str) -> str: pass_pipeline = ( "--pass-pipeline=builtin.module(" "water-wave-detect-normal-forms," - "normalform.module(water-wave-propagate-elements-per-thread,lower-wave-to-mlir)," + "normalform.module(" + "water-wave-propagate-elements-per-thread," + "water-wave-resolve-distributed-allocations," + "water-wave-detect-normal-forms," + "lower-wave-to-mlir" + ")," "lower-normalform-module," "canonicalize," "cse"