diff --git a/.github/workflows/ci-gpu.yaml b/.github/workflows/ci-gpu.yaml index 198724b4c..1b7000275 100644 --- a/.github/workflows/ci-gpu.yaml +++ b/.github/workflows/ci-gpu.yaml @@ -210,7 +210,13 @@ jobs: - name: Run unit tests run: | - pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ + if [[ "${{ contains(matrix.os, 'mi35x') }}" == 'true' ]]; then + # TODO: water-related tests segfault on mi35x + pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ --ignore=tests/unittests/index_sequence_difference_test.py --ignore=tests/unittests/location_exception.py + else + pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ + fi + pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface - name: Test TKW runtime related stack on amdgpu if: ${{ env.HAS_GPU == 'true' }} @@ -233,10 +239,9 @@ jobs: - name: Run LIT tests env: - WAVE_TEST_WATER: ${{ env.IS_CDNA3 == 'true' && '1' || '0' }} + WAVE_TEST_WATER: ${{ (env.IS_CDNA3 == 'true' || env.IS_CDNA4 == 'true') && '1' || '0' }} WAVE_TEST_DWARFDUMP: ${{ env.IS_RDNA4 == 'false' && '1' || '0' }} run: | - # TODO: mlir_converter tests segfault on mi35x # TODO: can't sudo to install dwarfdump on rdna4 echo "WAVE_TEST_WATER=$WAVE_TEST_WATER" echo "WAVE_TEST_DWARFDUMP=$WAVE_TEST_DWARFDUMP" diff --git a/.github/workflows/ci-happy.yml b/.github/workflows/ci-happy.yml index 3c4eeb058..5fa87faec 100644 --- a/.github/workflows/ci-happy.yml +++ b/.github/workflows/ci-happy.yml @@ -159,6 +159,7 @@ jobs: echo "Run unit tests" WAVE_CACHE_ON=0 python3 -m pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ + WAVE_CACHE_ON=0 python3 -m pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface ' # 5. TKW runtime related e2e diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4c70be342..4a7cb8cf8 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -81,6 +81,7 @@ jobs: - name: Run unit tests run: | pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ + pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface - name: Run LIT tests run: | diff --git a/lit_tests/kernel/wave/infer_index_exprs.py b/lit_tests/kernel/wave/infer_index_exprs.py new file mode 100644 index 000000000..8e8da65e8 --- /dev/null +++ b/lit_tests/kernel/wave/infer_index_exprs.py @@ -0,0 +1,123 @@ +# REQUIRES: water +# RUN: python %s +# The point of this test is to avoid crashing or asserting, so just run it under lit. + +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import wave_lang.kernel.lang as tkl +import wave_lang.kernel.wave as tkw +from wave_lang.kernel.wave.wave import LaunchableWave +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile + +from wave_lang.kernel.lang.global_symbols import * +from wave_lang.kernel.wave.constraints import MMAType +from wave_lang.kernel.wave.utils.general_utils import torch_dtype_to_wave + +import torch + + +# TODO: use the generic template, currently blocked by water not handling wave constraints.
+def _get_gemm_kernel( + shape: tuple[int, int, int], + mfma_variant: MMAType, + dtype: torch.dtype = torch.float16, + block_shape: tuple[int, int, int] | None = None, + waves_per_block: tuple[int, int] | None = None, +) -> tuple[LaunchableWave, dict[tkl.IndexSymbol, tkl.IndexExpr]]: + if not block_shape: + # BLOCK_M, BLOCK_N, BLOCK_K + block_shape = (64, 64, 32) + + if not waves_per_block: + # WAVE_M, WAVE_N + waves_per_block = (2, 2) + + assert len(block_shape) == 3, "block_shape needs to be rank 3 for M, N, K." + assert len(waves_per_block) == 2, "waves_per_block needs to be rank 2 for M, N." + + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE + dtype = torch_dtype_to_wave(dtype) + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + + # TODO: dialect expects waves_per_block to be rank 3, so we append a 1 to the end. + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=mfma_variant, + waves_per_block=waves_per_block + (1,), + ) + ] + + # Wave-level micro-kernel. + # Since warps are not directly addressable, there is no + # explicit notion of a warp id (like a workgroup or thread id). + # This kernel uses the input sizes M, N, K throughout, as the tiling + # and data movement strategy is determined during the compilation process. + # These can be influenced by introducing constraints. + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype], + b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + # This microkernel encodes the fact that if the iterate + # dimension were tiled, then we would need to materialize a loop. + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + # a_reg: tkw.Register[M, K, dtype] + a_reg = tkw.read(a) + # b_reg: tkw.Register[N, K, dtype] + b_reg = tkw.read(b) + # acc: tkw.Register[M, N, tkl.f32] + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + # repeat represents the results of the loop + tkw.write(repeat, c) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + BLOCK_M: block_shape[0], + BLOCK_N: block_shape[1], + BLOCK_K: block_shape[2], + M: shape[0], + N: shape[1], + K: shape[2], + } + return gemm, hyperparams + + +def testGemm(): + gemm, hyperparams = _get_gemm_kernel( + shape=(1024, 1024, 1024), mfma_variant=MMAType.F32_16x16x16_F16 + ) + options = WaveCompileOptions( + subs=hyperparams, + run_bench=False, + check_water_analysis=True, + print_mlir_after_water=True, + ) + compiled_gemm = wave_compile(options, gemm) + assert compiled_gemm is not None + + +if __name__ == "__main__": + testGemm() diff --git a/lit_tests/kernel/wave/mlir_converter.py b/lit_tests/kernel/wave/mlir_converter.py index ac63ea959..761e51226 100644 --- a/lit_tests/kernel/wave/mlir_converter.py +++ b/lit_tests/kernel/wave/mlir_converter.py @@ -80,7 +80,7 @@ def failure_to_parse_override_mlir(): # Override the MLIR module after `wave_compile` so it doesn't attempt to parse it. 
options.override_mlir = "module {" - _, diagnostics = emit_wave_dialect(trace, constraints, options) + _, diagnostics, _ = emit_wave_dialect(trace, constraints, options) assert len(diagnostics) == 1 # CHECK: Unable to parse module assembly @@ -91,7 +91,9 @@ def failure_to_parse_override_mlir(): @run_test def failure_to_parse_pipeline(): trace, options, constraints = _get_dummy_trace_options_and_constraints() - _, diagnostics = emit_wave_dialect(trace, constraints, options, pipeline="module {") + _, diagnostics, _ = emit_wave_dialect( + trace, constraints, options, pipeline="module {" + ) assert len(diagnostics) == 1 # CHECK: Failed to apply transform script: Unable to parse module assembly @@ -102,7 +104,7 @@ def failure_to_parse_pipeline(): @run_test def pipeline_is_empty(): trace, options, constraints = _get_dummy_trace_options_and_constraints() - _, diagnostics = emit_wave_dialect( + _, diagnostics, _ = emit_wave_dialect( trace, constraints, options, pipeline="module {}" ) @@ -115,7 +117,7 @@ def pipeline_is_empty(): @run_test def pipeline_is_not_a_named_sequence(): trace, options, constraints = _get_dummy_trace_options_and_constraints() - _, diagnostics = emit_wave_dialect( + _, diagnostics, _ = emit_wave_dialect( trace, constraints, options, pipeline="module { module {}}" ) @@ -141,7 +143,7 @@ def pipeline_is_not_a_named_sequence(): def failure_in_pipeline(): trace, options, constraints = _get_dummy_trace_options_and_constraints() options.override_mlir = "module {}" - _, diagnostics = emit_wave_dialect( + _, diagnostics, _ = emit_wave_dialect( trace, constraints, options, pipeline=GUARANTEED_FAIL_TRANSFORM_SCRIPT ) assert len(diagnostics) == 1 @@ -158,7 +160,7 @@ def override_mlir(): module { func.func private @overridden_mlir() }""" - emitted, diagnostics = emit_wave_dialect(trace, constraints, options) + emitted, diagnostics, _ = emit_wave_dialect(trace, constraints, options) assert len(diagnostics) == 0, "Did not expect errors in overridden IR." # CHECK: func.func private @overridden_mlir() @@ -218,7 +220,7 @@ def mlir_converter_matrix_add(): constraints = matrix_add.constraints # Use the mlir_converter to emit wave MLIR dialect - mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options) + mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options) if diagnostics: for diagnostic in diagnostics: @@ -374,7 +376,7 @@ def pipeline(root: OpHandle): # Use the mlir_converter to emit wave MLIR dialect and apply the empty # pipeline. 
- mlir_output, diagnostics = emit_wave_dialect( + mlir_output, diagnostics, _ = emit_wave_dialect( trace, constraints, options, pipeline=pipeline_asm ) @@ -528,7 +530,7 @@ def mixed_memory_kernel( constraints = mixed_memory_kernel.constraints with Context(), Location.unknown(): - mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options) + mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options) assert len(diagnostics) == 0, f"Should have no diagnostics, got: {diagnostics}" @@ -582,7 +584,7 @@ def invalid_hyperparameter_kernel( # This should raise a RuntimeError due to invalid non-int hyperparameter try: with Context(), Location.unknown(): - mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options) + mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options) assert False, "Expected RuntimeError for invalid non-int hyperparameter" except RuntimeError as e: # Verify the error message is what we expect diff --git a/lit_tests/kernel/wave/mlir_converter_debug_locations.py b/lit_tests/kernel/wave/mlir_converter_debug_locations.py index cfcd98491..2a8b0da58 100644 --- a/lit_tests/kernel/wave/mlir_converter_debug_locations.py +++ b/lit_tests/kernel/wave/mlir_converter_debug_locations.py @@ -95,7 +95,7 @@ def mlir_converter_location(): constraints = matrix_add.constraints # Use the mlir_converter to emit wave MLIR dialect - mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options) + mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options) if diagnostics: print(diagnostics) @@ -210,7 +210,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: constraints = matmul.constraints # Use the mlir_converter to emit wave MLIR dialect - mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options) + mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options) if diagnostics: print(diagnostics) diff --git a/lit_tests/kernel/wave/mlir_converter_diagnostics.py b/lit_tests/kernel/wave/mlir_converter_diagnostics.py index 15ac7aa1a..dcd0412ba 100644 --- a/lit_tests/kernel/wave/mlir_converter_diagnostics.py +++ b/lit_tests/kernel/wave/mlir_converter_diagnostics.py @@ -85,7 +85,7 @@ def mlir_converter_diagnostics_emission(): constraints = matrix_add.constraints # Use the mlir_converter to emit wave MLIR dialect - _, diagnostics = emit_wave_dialect( + _, diagnostics, _ = emit_wave_dialect( trace, constraints, options, test_diagnostic_emission=True ) diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 000000000..2bebad9a7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,9 @@ +[pytest] +testpaths = ./tests +filterwarnings = + # TODO: Remove once flatbuffer 'imp' usage resolved. + ignore::DeprecationWarning +# Because of the pytest collection process, it will import all modules from all +# tests, which is undesirable for mlir/wave interfacing tests. Exclude them +# from the global run. +addopts = "--ignore=tests/mlir_wave_iface" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 358360671..000000000 --- a/setup.cfg +++ /dev/null @@ -1,6 +0,0 @@ -[tool:pytest] -testpaths = - ./tests -filterwarnings = - # TODO: Remove once flatbuffer 'imp' usage resolved.
- ignore::DeprecationWarning diff --git a/tests/kernel/common/utils.py b/tests/kernel/common/utils.py index 744eaa518..0b5253612 100644 --- a/tests/kernel/common/utils.py +++ b/tests/kernel/common/utils.py @@ -71,7 +71,7 @@ def param_bool(name, shortname=None, values=None): def _is_water_and_ee_available() -> bool: - from wave_lang.kernel.wave.water import is_water_available + from wave_lang.support.detect_water import is_water_available from wave_lang.kernel.wave.execution_engine import is_execution_engine_available return is_water_available() and is_execution_engine_available() diff --git a/tests/kernel/test_water.py b/tests/kernel/test_water.py index 63b48423d..8cc8fecdf 100644 --- a/tests/kernel/test_water.py +++ b/tests/kernel/test_water.py @@ -9,6 +9,8 @@ from unittest.mock import patch from wave_lang.kernel.wave.water import ( apply_water_middle_end_passes, +) +from wave_lang.support.detect_water import ( find_binary, get_water_opt, is_water_available, @@ -36,7 +38,7 @@ def test_apply_water_middle_end_passes_unavailable(self): # Mock find_binary to return None, simulating water-opt not being found get_water_opt.cache_clear() # get_water_opt caches find_binary result try: - with patch("wave_lang.kernel.wave.water.find_binary", return_value=None): + with patch("wave_lang.support.detect_water.find_binary", return_value=None): with pytest.raises(RuntimeError, match="water-opt binary not found"): apply_water_middle_end_passes("module {}") finally: diff --git a/tests/mlir_wave_iface/mlir_to_wave_test.py b/tests/mlir_wave_iface/mlir_to_wave_test.py new file mode 100644 index 000000000..b977027f4 --- /dev/null +++ b/tests/mlir_wave_iface/mlir_to_wave_test.py @@ -0,0 +1,501 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest +import sympy +import importlib.util +import os +import sys + +# These are fine since there is no transitive dependency on IREE. +from wave_lang.support.indexing import ( + MMA_ACC_SYMBOL_NAME, + IndexSequence, + index_symbol, + sym, +) +from wave_lang.support.detect_water import is_water_available + +# XXX: This ugly hack adds the wave_lang/kernel/wave directory to the +# import paths so we can import water_mlir without prefixing it and therefore +# bypass initializers of the wave_lang.kernel.wave package that may import IREE, +# which would clash with Water MLIR bindings. Get the path to `wave_lang/` +# since that does not import IREE, then manually concatenate `wave` to that path +# to avoid touching wave_lang/kernel/wave/__init__.py, which would transitively +# import IREE. +# TODO: Remove this hack once either (1) the main wave package doesn't +# systematically import IREE until needed or (2) there's no dependency on IREE +# anymore. +__wave_lang_spec = importlib.util.find_spec("wave_lang") +assert __wave_lang_spec is not None and __wave_lang_spec.origin is not None +__wave_lang_wave_path = os.path.join( + os.path.dirname(__wave_lang_spec.origin), "kernel", "wave" +) +if __wave_lang_wave_path not in sys.path: + sys.path.append(__wave_lang_wave_path) + + +# Only import water_mlir components if water_mlir is available; skip testing otherwise.
+if is_water_available(): + from water_mlir.water_mlir import ir + from water_mlir.water_mlir.dialects import wave + + from mlir_converter.mlir_to_wave import ( + _convert_affine_expr_to_sympy_expr, + _convert_index_mapping_attr_to_sympy, + _convert_index_mapping_dict_to_sympy, + convert_index_mapping_array_to_sympy, + _make_piecewise_sequence, + ITER_SYMBOL_NAME_WAVE_PREFIX, + ) + +pytestmark = pytest.mark.skipif( + not is_water_available(), reason="water_mlir not available" +) + + +@pytest.fixture(autouse=True) +def ir_context(): + """Fixture to create and manage IR context for all tests.""" + if is_water_available(): + with ir.Context() as ctx: + wave.register_dialect(ctx) + yield ctx + else: + yield None + + +class TestConvertAffineExprToSympyExpr: + """Tests for _convert_affine_expr_to_sympy_expr function.""" + + def test_constant_expr(self): + """Test conversion of constant affine expressions.""" + expr = ir.AffineConstantExpr.get(42) + result = _convert_affine_expr_to_sympy_expr(expr, []) + assert result == sympy.Integer(42) + + def test_symbol_expr(self): + """Test conversion of symbol affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + z = sympy.Symbol("z") + symbol_mapping = [x, y, z] + + # Test s0 + expr = ir.AffineSymbolExpr.get(0) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == x + + # Test s1 + expr = ir.AffineSymbolExpr.get(1) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == y + + # Test s2 + expr = ir.AffineSymbolExpr.get(2) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == z + + def test_add_expr(self): + """Test conversion of addition affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + symbol_mapping = [x, y] + + # Test s0 + s1 + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + expr = s0 + s1 + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == x + y + + # Test s0 + 5 + expr = s0 + 5 + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == x + 5 + + def test_mul_expr(self): + """Test conversion of multiplication affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + symbol_mapping = [x, y] + + # Test s0 * 3 + s0 = ir.AffineSymbolExpr.get(0) + expr = s0 * 3 + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == 3 * x + + # Test s0 * s1 + s1 = ir.AffineSymbolExpr.get(1) + expr = s0 * s1 + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == x * y + + def test_floor_div_expr(self): + """Test conversion of floor division affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + symbol_mapping = [x, y] + + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + + # Test s0 floordiv 2 + expr = ir.AffineExpr.get_floor_div(s0, ir.AffineConstantExpr.get(2)) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == sympy.floor(x / 2) + + # Test s0 floordiv s1 + expr = ir.AffineExpr.get_floor_div(s0, s1) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == sympy.floor(x / y) + + def test_ceil_div_expr(self): + """Test conversion of ceiling division affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + symbol_mapping = [x, y] + + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + + # Test s0 ceildiv 4 + expr = ir.AffineExpr.get_ceil_div(s0, 
ir.AffineConstantExpr.get(4)) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == sympy.ceiling(x / 4) + + # Test s0 ceildiv s1 + expr = ir.AffineExpr.get_ceil_div(s0, s1) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == sympy.ceiling(x / y) + + def test_mod_expr(self): + """Test conversion of modulo affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + symbol_mapping = [x, y] + + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + + # Test s0 mod s1 + expr = s0 % s1 + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == x % y + + # Test s0 mod 8 + expr = s0 % 8 + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == x % 8 + + def test_complex_expr(self): + """Test conversion of complex nested affine expressions.""" + x = sympy.Symbol("x") + y = sympy.Symbol("y") + z = sympy.Symbol("z") + symbol_mapping = [x, y, z] + + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + s2 = ir.AffineSymbolExpr.get(2) + + # Test (s0 * 2 + s1) floordiv s2 + expr = ir.AffineExpr.get_floor_div(s0 * 2 + s1, s2) + result = _convert_affine_expr_to_sympy_expr(expr, symbol_mapping) + assert result == sympy.floor((2 * x + y) / z) + + def test_unsupported_expr_raises_error(self): + """Test that unsupported expression types raise ValueError.""" + # Create a dimension expression (not supported by the function) + expr = ir.AffineDimExpr.get(0) + + with pytest.raises(ValueError, match="Unsupported affine expression"): + _convert_affine_expr_to_sympy_expr(expr, []) + + +class TestConvertIndexMappingAttrToSympy: + """Tests for _convert_index_mapping_attr_to_sympy function.""" + + def test_basic_index_mapping_with_symbol_attr(self): + """Test conversion of basic index mapping with WaveSymbolAttr.""" + # Create symbols + symbols = [ + wave.WaveSymbolAttr.get("M"), + wave.WaveSymbolAttr.get("N"), + ] + + # Create affine maps + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + start_map = ir.AffineMap.get(0, 2, [s0]) + step_map = ir.AffineMap.get(0, 2, [s1]) + stride_map = ir.AffineMap.get(0, 2, [ir.AffineConstantExpr.get(1)]) + + # Create index mapping attribute + attr = wave.WaveIndexMappingAttr.get(symbols, start_map, step_map, stride_map) + + # Convert to sympy + result = _convert_index_mapping_attr_to_sympy(attr) + + # Check the result + assert isinstance(result, IndexSequence) + assert result.start == index_symbol("M") + assert result.size == index_symbol("N") + assert result.stride == 1 + + def test_index_mapping_with_index_symbol_attr(self): + """Test conversion with WaveIndexSymbolAttr (special symbols like $WG0).""" + # Create symbols including WaveIndexSymbolAttr + symbols = [ + wave.WaveIndexSymbolAttr.get(wave.WaveIndexSymbol.WORKGROUP_0), + wave.WaveSymbolAttr.get("BLOCK_M"), + ] + + s0 = ir.AffineSymbolExpr.get(0) + s1 = ir.AffineSymbolExpr.get(1) + start_map = ir.AffineMap.get(0, 2, [s0 * 3]) + step_map = ir.AffineMap.get(0, 2, [s1]) + stride_map = ir.AffineMap.get(0, 2, [s0 + s1]) + + attr = wave.WaveIndexMappingAttr.get(symbols, start_map, step_map, stride_map) + result = _convert_index_mapping_attr_to_sympy(attr) + + assert isinstance(result, IndexSequence) + # $WG0 should be converted to index_symbol("$WG0") + assert result.start == index_symbol("$WG0") * 3 + assert result.size == index_symbol("BLOCK_M") + assert result.stride == index_symbol("$WG0") + index_symbol("BLOCK_M") + + def 
test_index_mapping_with_iter_symbol_attr(self): + """Test conversion with WaveIterSymbolAttr (iteration symbols).""" + # Create symbols including WaveIterSymbolAttr + symbols = [ + wave.WaveIterSymbolAttr.get("i"), + ] + + s0 = ir.AffineSymbolExpr.get(0) + start_map = ir.AffineMap.get(0, 1, [s0 * 2]) + step_map = ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(16)]) + stride_map = ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(1)]) + + attr = wave.WaveIndexMappingAttr.get(symbols, start_map, step_map, stride_map) + result = _convert_index_mapping_attr_to_sympy(attr) + + assert isinstance(result, IndexSequence) + # Iter symbol should be prefixed with ITER_SYMBOL_NAME_WAVE_PREFIX ($ARG) + assert result.start == index_symbol(ITER_SYMBOL_NAME_WAVE_PREFIX + "i") * 2 + assert result.size == 16 + assert result.stride == 1 + + +class TestConvertIndexMappingDictToSympy: + """Tests for _convert_index_mapping_dict_to_sympy function.""" + + def test_single_mapping(self): + """Test conversion of dict with single index mapping.""" + # Create a simple index mapping + symbols = [wave.WaveSymbolAttr.get("M")] + s0 = ir.AffineSymbolExpr.get(0) + start_map = ir.AffineMap.get(0, 1, [s0]) + step_map = ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(16)]) + stride_map = ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(1)]) + mapping_attr = wave.WaveIndexMappingAttr.get( + symbols, start_map, step_map, stride_map + ) + + # Create dict attribute + dict_attr = ir.DictAttr.get({"dim0": mapping_attr}) + + result = _convert_index_mapping_dict_to_sympy(dict_attr) + + assert isinstance(result, dict) + assert index_symbol("dim0") in result + assert isinstance(result[index_symbol("dim0")], IndexSequence) + assert result[index_symbol("dim0")].start == index_symbol("M") + assert result[index_symbol("dim0")].size == 16 + + def test_multiple_mappings(self): + """Test conversion of dict with multiple index mappings.""" + # Create first mapping + symbols1 = [wave.WaveSymbolAttr.get("M")] + s0 = ir.AffineSymbolExpr.get(0) + mapping1 = wave.WaveIndexMappingAttr.get( + symbols1, + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(16)]), + ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(1)]), + ) + + # Create second mapping + symbols2 = [wave.WaveSymbolAttr.get("N")] + mapping2 = wave.WaveIndexMappingAttr.get( + symbols2, + ir.AffineMap.get(0, 1, [s0 * 2]), + ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(32)]), + ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(2)]), + ) + + dict_attr = ir.DictAttr.get({"m": mapping1, "n": mapping2}) + result = _convert_index_mapping_dict_to_sympy(dict_attr) + + assert len(result) == 2 + assert index_symbol("m") in result + assert index_symbol("n") in result + assert result[index_symbol("m")].size == 16 + assert result[index_symbol("n")].size == 32 + + +class TestMakePiecewiseSequence: + """Tests for _make_piecewise_sequence function.""" + + def test_two_component_piecewise(self): + """Test piecewise sequence with two components.""" + # Create two index sequences + seq1 = IndexSequence(start=0, size=10, stride=1) + seq2 = IndexSequence(start=100, size=20, stride=2) + + # Create conditions + cond1 = sympy.Symbol("x") < 5 + cond2 = sympy.Symbol("x") >= 5 + + # Create piecewise sequence + result = _make_piecewise_sequence((seq1, cond1), (seq2, cond2)) + + assert isinstance(result, IndexSequence) + assert isinstance(result.start, sympy.Piecewise) + assert isinstance(result.size, sympy.Piecewise) + assert isinstance(result.stride, sympy.Piecewise) + + def 
test_single_component_piecewise(self): + """Test piecewise sequence with single component.""" + seq = IndexSequence(start=5, size=15, stride=3) + cond = sympy.Symbol("flag") + + result = _make_piecewise_sequence((seq, cond)) + + assert isinstance(result, IndexSequence) + assert isinstance(result.start, sympy.Piecewise) + assert isinstance(result.size, sympy.Piecewise) + assert isinstance(result.stride, sympy.Piecewise) + + +class TestConvertIndexMappingArrayToSympy: + """Tests for convert_index_mapping_array_to_sympy function.""" + + def test_non_mma_op_single_mapping(self): + """Test conversion for non-MMA operations (expects single mapping).""" + # Create a simple mapping + symbols = [wave.WaveSymbolAttr.get("M")] + s0 = ir.AffineSymbolExpr.get(0) + mapping = wave.WaveIndexMappingAttr.get( + symbols, + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(16)]), + ir.AffineMap.get(0, 1, [ir.AffineConstantExpr.get(1)]), + ) + + dict_attr = ir.DictAttr.get({"dim": mapping}) + array_attr = ir.ArrayAttr.get([dict_attr]) + + # We don't need anything from the operation except its name, so use an empty module. + dummy_op = ir.Operation.create("builtin.module", loc=ir.Location.unknown()) + result = convert_index_mapping_array_to_sympy(dummy_op, array_attr) + + assert isinstance(result, dict) + assert index_symbol("dim") in result + + def test_mma_op_with_valid_four_mappings(self): + """Test MMA operation with correct 4 mappings creates piecewise sequence.""" + # Create symbols for M, N, K dimensions + m_sym = wave.WaveSymbolAttr.get("M") + n_sym = wave.WaveSymbolAttr.get("N") + k_sym = wave.WaveSymbolAttr.get("K") + + s0 = ir.AffineSymbolExpr.get(0) + c16 = ir.AffineConstantExpr.get(16) + c1 = ir.AffineConstantExpr.get(1) + + # LHS mapping: M and K + lhs_m_mapping = wave.WaveIndexMappingAttr.get( + [m_sym], + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [c16]), + ir.AffineMap.get(0, 1, [c1]), + ) + lhs_k_mapping = wave.WaveIndexMappingAttr.get( + [k_sym], + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [c16]), + ir.AffineMap.get(0, 1, [c1]), + ) + + # RHS mapping: N and K (K must match LHS K) + rhs_n_mapping = wave.WaveIndexMappingAttr.get( + [n_sym], + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [c16]), + ir.AffineMap.get(0, 1, [c1]), + ) + rhs_k_mapping = wave.WaveIndexMappingAttr.get( + [k_sym], + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [c16]), + ir.AffineMap.get(0, 1, [c1]), + ) + + # Accumulator mapping: M and N + acc_m_mapping = wave.WaveIndexMappingAttr.get( + [m_sym], + ir.AffineMap.get(0, 1, [c16 - s0]), + ir.AffineMap.get(0, 1, [c1]), + ir.AffineMap.get(0, 1, [c16]), + ) + acc_n_mapping = wave.WaveIndexMappingAttr.get( + [n_sym], + ir.AffineMap.get(0, 1, [s0]), + ir.AffineMap.get(0, 1, [c16]), + ir.AffineMap.get(0, 1, [c1]), + ) + + # Note that result mapping is the same as the accumulator mapping. + lhs_dict = ir.DictAttr.get({"M": lhs_m_mapping, "K": lhs_k_mapping}) + rhs_dict = ir.DictAttr.get({"N": rhs_n_mapping, "K": rhs_k_mapping}) + acc_dict = ir.DictAttr.get({"M": acc_m_mapping, "N": acc_n_mapping}) + result_dict = ir.DictAttr.get({"M": acc_m_mapping, "N": acc_n_mapping}) + + array_attr = ir.ArrayAttr.get([lhs_dict, rhs_dict, acc_dict, result_dict]) + + # Create a mock MMA operation, we only need the name, it doesn't even need to verify correctly. 
+ dummy_mma_op = ir.Operation.create("wave.mma", loc=ir.Location.unknown()) + result = convert_index_mapping_array_to_sympy(dummy_mma_op, array_attr) + + assert isinstance(result, dict) + assert len(result) == 3 # M, N, K + assert index_symbol("M") in result + assert index_symbol("N") in result + assert index_symbol("K") in result + + # M should be piecewise (combining LHS and ACC) + m_seq = result[index_symbol("M")] + assert isinstance(m_seq, IndexSequence) + # The start/size/stride should be Piecewise expressions + assert isinstance(m_seq.start, sympy.Piecewise) + assert isinstance(m_seq.size, sympy.Piecewise) + assert isinstance(m_seq.stride, sympy.Piecewise) + + assert m_seq.start == sympy.Piecewise( + (sym.M, ~index_symbol(MMA_ACC_SYMBOL_NAME)), + (sympy.sympify(16 - sym.M), index_symbol(MMA_ACC_SYMBOL_NAME)), + ) + assert m_seq.size == sympy.Piecewise( + (16, ~index_symbol(MMA_ACC_SYMBOL_NAME)), + (1, index_symbol(MMA_ACC_SYMBOL_NAME)), + ) + assert m_seq.stride == sympy.Piecewise( + (1, ~index_symbol(MMA_ACC_SYMBOL_NAME)), + (16, index_symbol(MMA_ACC_SYMBOL_NAME)), + ) diff --git a/tests/unittests/index_sequence_difference_test.py b/tests/unittests/index_sequence_difference_test.py new file mode 100644 index 000000000..7dae68a73 --- /dev/null +++ b/tests/unittests/index_sequence_difference_test.py @@ -0,0 +1,216 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pytest + +from wave_lang.kernel._support.indexing import IndexSequence, index_symbol +from wave_lang.kernel.wave.analysis.index_sequence_analysis import ( + _check_index_difference_is_zero, +) + + +def test_equal_index_sequences_returns_true(): + """Test that equal index sequences return True.""" + dim1 = index_symbol("dim1") + dim2 = index_symbol("dim2") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + dim2: IndexSequence(start=5, size=20, stride=2), + } + + index2 = { + dim1: IndexSequence(start=0, size=10, stride=1), + dim2: IndexSequence(start=5, size=20, stride=2), + } + + result = _check_index_difference_is_zero(index1, index2) + assert result is True + + +def test_equal_symbolic_index_sequences_returns_true(): + """Test that symbolically equal index sequences return True.""" + dim1 = index_symbol("dim1") + M = index_symbol("M") + N = index_symbol("N") + + index1 = { + dim1: IndexSequence(start=M, size=N, stride=M + 1), + } + + index2 = { + dim1: IndexSequence(start=M, size=N, stride=M + 1), + } + + result = _check_index_difference_is_zero(index1, index2) + assert result is True + + +def test_different_start_raises_value_error(): + """Test that different start values raise ValueError.""" + dim1 = index_symbol("dim1") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + } + + index2 = { + dim1: IndexSequence(start=5, size=10, stride=1), + } + + with pytest.raises(ValueError, match="Start difference"): + _check_index_difference_is_zero(index1, index2) + + +def test_different_size_raises_value_error(): + """Test that different size values raise ValueError.""" + dim1 = index_symbol("dim1") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + } + + index2 = { + dim1: IndexSequence(start=0, size=20, stride=1), + } + + with pytest.raises(ValueError, match="Size difference"): + _check_index_difference_is_zero(index1, index2) + + +def test_different_stride_raises_value_error(): + """Test that 
different stride values raise ValueError.""" + dim1 = index_symbol("dim1") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + } + + index2 = { + dim1: IndexSequence(start=0, size=10, stride=2), + } + + with pytest.raises(ValueError, match="Stride difference"): + _check_index_difference_is_zero(index1, index2) + + +def test_different_keys_returns_false(): + """Test that different dictionary keys return False.""" + dim1 = index_symbol("dim1") + dim2 = index_symbol("dim2") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + } + + index2 = { + dim2: IndexSequence(start=0, size=10, stride=1), + } + + result = _check_index_difference_is_zero(index1, index2) + assert result is False + + +def test_empty_dictionaries_returns_true(): + """Test that empty dictionaries return True.""" + index1 = {} + index2 = {} + + result = _check_index_difference_is_zero(index1, index2) + assert result is True + + +def test_one_empty_one_not_returns_false(): + """Test that one empty and one non-empty dictionary returns False.""" + dim1 = index_symbol("dim1") + + index1 = {} + index2 = { + dim1: IndexSequence(start=0, size=10, stride=1), + } + + result = _check_index_difference_is_zero(index1, index2) + assert result is False + + +def test_symbolic_expressions_simplify_to_zero(): + """Test that symbolic expressions that simplify to zero are considered equal.""" + dim1 = index_symbol("dim1") + M = index_symbol("M") + N = index_symbol("N") + + index1 = { + dim1: IndexSequence(start=M + N, size=M * 2, stride=N - 1), + } + + index2 = { + dim1: IndexSequence(start=N + M, size=2 * M, stride=N - 1), + } + + result = _check_index_difference_is_zero(index1, index2) + assert result is True + + +def test_symbolic_expressions_do_not_simplify_to_zero(): + """Test that symbolic expressions that don't simplify to zero raise ValueError.""" + dim1 = index_symbol("dim1") + M = index_symbol("M") + N = index_symbol("N") + + index1 = { + dim1: IndexSequence(start=M, size=N, stride=1), + } + + index2 = { + dim1: IndexSequence(start=M + 1, size=N, stride=1), + } + + with pytest.raises(ValueError, match="Start difference"): + _check_index_difference_is_zero(index1, index2) + + +def test_multiple_dimensions_all_equal(): + """Test multiple dimensions where all are equal.""" + dim1 = index_symbol("dim1") + dim2 = index_symbol("dim2") + dim3 = index_symbol("dim3") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + dim2: IndexSequence(start=5, size=20, stride=2), + dim3: IndexSequence(start=10, size=30, stride=3), + } + + index2 = { + dim1: IndexSequence(start=0, size=10, stride=1), + dim2: IndexSequence(start=5, size=20, stride=2), + dim3: IndexSequence(start=10, size=30, stride=3), + } + + result = _check_index_difference_is_zero(index1, index2) + assert result is True + + +def test_multiple_dimensions_one_different(): + """Test multiple dimensions where one differs.""" + dim1 = index_symbol("dim1") + dim2 = index_symbol("dim2") + dim3 = index_symbol("dim3") + + index1 = { + dim1: IndexSequence(start=0, size=10, stride=1), + dim2: IndexSequence(start=5, size=20, stride=2), + dim3: IndexSequence(start=10, size=30, stride=3), + } + + index2 = { + dim1: IndexSequence(start=0, size=10, stride=1), + dim2: IndexSequence(start=5, size=25, stride=2), # Different size + dim3: IndexSequence(start=10, size=30, stride=3), + } + + with pytest.raises(ValueError, match="Size difference"): + _check_index_difference_is_zero(index1, index2) diff --git a/tests/unittests/location_exception.py 
b/tests/unittests/location_exception.py index 1cde81731..3ddea77f4 100644 --- a/tests/unittests/location_exception.py +++ b/tests/unittests/location_exception.py @@ -3,7 +3,7 @@ from wave_lang.kernel._support.location import StackTraceInfo from wave_lang.kernel.lang.global_symbols import * from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile -from wave_lang.kernel.wave.water import is_water_available +from wave_lang.support.detect_water import is_water_available import pytest from iree.compiler.ir import Context diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 5a14643fd..012df8973 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -294,10 +294,21 @@ WaveIndexMappingAttr WaveIndexMappingAttr::removeUnusedInputs() const { newSymbols.push_back(symbol); } assert(newSymbols.size() == usedSymbolPositions.size()); - AffineMap start = getStart() ? getStart().replace(replacement) : AffineMap(); - AffineMap step = getStep() ? getStep().replace(replacement) : AffineMap(); + AffineMap start = + getStart() + ? getStart().replace(replacement, /*numResultDims=*/0, + /*numResultSyms=*/usedSymbolPositions.size()) + : AffineMap(); + AffineMap step = + getStep() + ? getStep().replace(replacement, /*numResultDims=*/0, + /*numResultSyms=*/usedSymbolPositions.size()) + : AffineMap(); AffineMap stride = - getStride() ? getStride().replace(replacement) : AffineMap(); + getStride() + ? getStride().replace(replacement, /*numResultDims=*/0, + /*numResultSyms=*/usedSymbolPositions.size()) + : AffineMap(); return get(getContext(), newSymbols, start, step, stride); } diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index db6292632..3ff45b05d 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -1007,7 +1007,7 @@ wave::IndexExprsLatticeStorage::withoutIterSymbols( for (wave::WaveSymbolAttr iterSymbol : iterSymbols) { auto actualIterSymbol = wave::WaveIterSymbolAttr::get(ctx, iterSymbol.getName()); - value = value.removeInput(actualIterSymbol); + value = value.removeInput(actualIterSymbol).removeUnusedInputs(); } return mlir::NamedAttribute(attr.getName(), value); }); diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 20a6eef47..e547bd9ca 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -726,8 +726,9 @@ struct MmaIndexingExprBuilder { }; auto buildOne = [&](const MmaSingleIndexExprBuilder &builder) { return wave::WaveIndexMappingAttr::get( - ctx, symbols, buildMap(builder.offsetExpr), - buildMap(builder.sizeExpr), buildMap(builder.strideExpr)); + ctx, symbols, buildMap(builder.offsetExpr), + buildMap(builder.sizeExpr), buildMap(builder.strideExpr)) + .removeUnusedInputs(); }; if (mSymbol) diff --git a/water/python/CMakeLists.txt b/water/python/CMakeLists.txt index 1053c11fa..49bce8dd5 100644 --- a/water/python/CMakeLists.txt +++ b/water/python/CMakeLists.txt @@ -36,6 +36,8 @@ if (WATER_ENABLE_PYTHON) WaterExtensionNanobind.cpp EMBED_CAPI_LINK_LIBS WaterCAPI + PRIVATE_LINK_LIBS + LLVMSupport ) add_mlir_python_common_capi_library(WaterPythonCAPI @@ -63,4 +65,29 @@ if (WATER_ENABLE_PYTHON) COMMON_CAPI_LINK_LIBS WaterPythonCAPI ) + + if (UNIX) + target_compile_options(WaterPythonCAPI + PRIVATE "-fvisibility=hidden" + ) + target_link_options(WaterPythonCAPI + PRIVATE "-fvisibility=hidden" + ) + if 
(APPLE) + # See https://github.com/numba/numba-mlir/blob/fd9609d81dab96b0658392886ebc470ea2fa4eeb/numba_mlir/numba_mlir/mlir_compiler/CMakeLists.txt#L71 + message(SEND_ERROR "symbol management not implemented on Darwin") + else() + target_link_options(WaterPythonCAPI + PRIVATE + "-Wl,--no-undefined" + "-Wl,-z,defs" + "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/capi.lds" + ) + set_target_properties(WaterPythonCAPI + PROPERTIES + LINK_DEPENDS + "${CMAKE_CURRENT_SOURCE_DIR}/capi.lds" + ) + endif() + endif() endif() diff --git a/water/python/capi.lds b/water/python/capi.lds new file mode 100644 index 000000000..c87461699 --- /dev/null +++ b/water/python/capi.lds @@ -0,0 +1,8 @@ +{ + global: + # Bindings symbols should remain visible. + mlir*; + local: + # Hide everything else. + *; +}; diff --git a/water/test/Dialect/Wave/infer-index-exprs.mlir b/water/test/Dialect/Wave/infer-index-exprs.mlir index d730a1094..88bef9d4d 100644 --- a/water/test/Dialect/Wave/infer-index-exprs.mlir +++ b/water/test/Dialect/Wave/infer-index-exprs.mlir @@ -63,20 +63,20 @@ module attributes { wave.normal_form = #wave.normal_form } { ]} { // CHECK: wave.mma // Left-hand side - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) // CHECK: }, { // Right-hand side - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { // Accumulator - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { // Result (matches the accumulator) - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: } wave.mma %a, %b, %c {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> @@ -96,35 +96,35 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [2, 3, 4]> ]} { // CHECK: wave.read - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 
mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) %a_read = wave.read %a : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16, > // CHECK: wave.read - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) %b_read = wave.read %b : (!wave.tensor<[@N, @K] of f16>) -> !wave.tensor<[@N, @K] of f16, > %cst = arith.constant 0.0 : f32 // CHECK: wave.register - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %c_reg = wave.register %cst : !wave.tensor<[@M, @N] of f32, > // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %mma = wave.mma %a_read, %b_read, %c_reg {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > // CHECK: wave.write - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 
mod 16, 1, 1) wave.write %mma, %c : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32> return } @@ -148,81 +148,81 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [1, 2, 2]> ]} { // CHECK: wave.read - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %a_read = wave.read %a : (!wave.tensor<[@M, @K] of f16>) -> !wave.tensor<[@M, @K] of f16, > // CHECK: wave.read - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %b_read = wave.read %b : (!wave.tensor<[@N, @K] of f16>) -> !wave.tensor<[@N, @K] of f16, > %cst_0 = arith.constant 0.0 : f32 // CHECK: wave.register - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %c_reg = wave.register %cst_0 : !wave.tensor<[@M, @N] of f32, > // CHECK: wave.mma - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %mma1 = wave.mma %a_read, %b_read, %c_reg {kind = 
#wave.mma_kind} : (!wave.tensor<[@M, @K] of f16, >, !wave.tensor<[@N, @K] of f16, >, !wave.tensor<[@M, @N] of f32, >) -> !wave.tensor<[@M, @N] of f32, > // CHECK: wave.cast - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %mma1_casted = wave.cast %mma1 : !wave.tensor<[@M, @N] of f32, > to !wave.tensor<[@M, @N] of f16, > // CHECK: wave.write - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) wave.write %mma1_casted, %storage : !wave.tensor<[@M, @N] of f16, >, !wave.tensor<[@M, @N] of f16> // CHECK: wave.read - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) %reloaded = wave.read %storage : (!wave.tensor<[@M, @N] of f16>) -> !wave.tensor<[@M, @N] of f16, > // Second read and register // CHECK: wave.read - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: P : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: P : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %d_read = wave.read %d : (!wave.tensor<[@P, @N] of f16>) -> !wave.tensor<[@P, @N] of f16, > %cst_1 = arith.constant 0.0 : f32 // CHECK: wave.register - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: P : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: P : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %c_reg2 = wave.register %cst_1 : !wave.tensor<[@M, @P] of f32, > // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) // CHECK: }, { - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) - // CHECK-DAG: P : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: N : 
[#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 1) + // CHECK-DAG: P : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: P : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: P : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: P : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: P : [#wave.index_symbol] -> (T0 mod 16, 1, 1) %mma2 = wave.mma %reloaded, %d_read, %c_reg2 {kind = #wave.mma_kind} : (!wave.tensor<[@M, @N] of f16, >, !wave.tensor<[@P, @N] of f16, >, !wave.tensor<[@M, @P] of f32, >) -> !wave.tensor<[@M, @P] of f32, > // CHECK: wave.write - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: P : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: P : [#wave.index_symbol] -> (T0 mod 16, 1, 1) wave.write %mma2, %c : !wave.tensor<[@M, @P] of f32, >, !wave.tensor<[@M, @P] of f32> return } @@ -240,17 +240,17 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [2, 3, 4]> ]} { // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 4, 4, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 4, 4, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 4, 4, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 4, 4, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 
1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) wave.mma %a, %b, %c {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> return @@ -269,17 +269,17 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [2, 3, 4]> ]} { // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 8, 8, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 8, 8, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 8, 8, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 8, 8, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) wave.mma %a, %b, %c {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> return @@ -298,17 +298,17 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [2, 3, 4]> ]} { // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 16 + ((T0 mod 64) floordiv 16) * 4 + GPR_NUM mod 4, 8, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 16 + ((T0 mod 64) floordiv 16) * 4 + GPR_NUM mod 4, 8, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 16 + ((T0 mod 64) floordiv 16) * 4 + GPR_NUM mod 4, 8, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 16 + ((T0 mod 64) floordiv 16) * 4 + GPR_NUM mod 4, 8, 1) + // CHECK-DAG: N : 
[#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 16, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (((T0 mod 64) floordiv 16) * 4, 4, 16) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 16, 1, 1) wave.mma %a, %b, %c {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f8E5M2>, !wave.tensor<[@N, @K] of f8E5M2>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> return @@ -327,17 +327,17 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [2, 3, 4]> ]} { // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) wave.mma %a, %b, %c {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f16>, !wave.tensor<[@N, @K] of f16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> return @@ -356,17 +356,17 @@ module attributes { wave.normal_form = #wave.normal_form } { waves_per_block = [2, 3, 4]> ]} { // CHECK: wave.mma - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, 
#wave.index_symbol] -> (T0 mod 32, 1, 1) - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 8 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 8, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 8 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 8, 1) // CHECK: }, { - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 8 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 8, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol] -> ((GPR_NUM floordiv 4) * 8 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 8, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) // CHECK: }, { - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) wave.mma %a, %b, %c {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of f8E5M2>, !wave.tensor<[@N, @K] of f8E5M2>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> return @@ -389,31 +389,31 @@ module attributes { wave.normal_form = #wave.normal_form } { %0 = arith.constant 0.0 : f32 // CHECK: wave.register - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) %c_reg = wave.register %0 : !wave.tensor<[@M, @N] of f32> // CHECK: wave.iterate // CHECK-SAME: iter_args // CHECK-SAME: index - // CHECK-DAG: M = #wave, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32)> - // CHECK-DAG: N = #wave, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1)> + // CHECK-DAG: M = #wave, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32)> + // CHECK-DAG: N = #wave] -> (T0 mod 32, 1, 1)> %mma_result = 
wave.iterate @K iter_args(%c_reg) { ^bb0(%acc: !wave.tensor<[@M, @N] of f32>): // CHECK: wave.read - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 32, 1, 1) %a_reg = wave.read %a : (!wave.tensor<[@M, @K] of bf16, >) -> !wave.tensor<[@M, @K] of bf16> // CHECK: wave.read - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) %b_reg = wave.read %b : (!wave.tensor<[@N, @K] of bf16, >) -> !wave.tensor<[@N, @K] of bf16> // CHECK: wave.mma - // CHECK-DAG: K : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: K : [#wave.index_symbol] -> (((T0 mod 64) floordiv 32) * 8, 8, 1) + // CHECK-DAG: M : [#wave.index_symbol] -> (T0 mod 32, 1, 1) %inner_acc = wave.mma %a_reg, %b_reg, %acc {kind = #wave.mma_kind} : (!wave.tensor<[@M, @K] of bf16>, !wave.tensor<[@N, @K] of bf16>, !wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M, @N] of f32> @@ -422,8 +422,8 @@ module attributes { wave.normal_form = #wave.normal_form } { } : (!wave.tensor<[@M, @N] of f32>)-> (!wave.tensor<[@M, @N] of f32>) // CHECK: wave.write - // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) - // CHECK-DAG: N : [#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T0 mod 32, 1, 1) + // CHECK-DAG: M : [#wave.index_symbol, #wave.index_symbol] -> (((GPR_NUM floordiv 4) * 8) mod 32 + ((T0 mod 64) floordiv 32) * 4 + GPR_NUM mod 4, 16, 32) + // CHECK-DAG: N : [#wave.index_symbol] -> (T0 mod 32, 1, 1) wave.write %mma_result, %c : !wave.tensor<[@M, @N] of f32> , !wave.tensor<[@M, @N] of f32, > return diff --git a/wave_lang/kernel/_support/indexing.py b/wave_lang/kernel/_support/indexing.py index 6c9173b4f..2a9e95535 100644 --- a/wave_lang/kernel/_support/indexing.py +++ b/wave_lang/kernel/_support/indexing.py @@ -2,12 +2,13 @@ import copy from abc import ABC from dataclasses import dataclass -from typing import Any, ClassVar, Optional, Type, TypeAlias, TypeVar, Union +from typing import Any, ClassVar, Optional, Type, TypeVar, Union import sympy from . import context, dtype from .shaped_type import ShapedType +from ...support.indexing import * __all__ = [ "backed_sym_index_type", @@ -33,32 +34,6 @@ class NotSetType: ... SubtypeT = TypeVar("SubtypeT") -############################################################################### -# Index symbols and expressions -# These are just light-weight helpers around sympy symbols and expressions. 
-############################################################################### - -IndexSymbol: TypeAlias = sympy.Symbol -IndexExpr: TypeAlias = sympy.Expr - - -def index_symbol(name: str) -> IndexSymbol: - """Returns a named symbol, assumed to be a non-negative integer.""" - return sympy.Symbol(name, integer=True, nonnegative=True) - - -def index_expr(value: Any) -> IndexExpr: - expr = sympy.sympify(value) - return expr - - -class _IndexSymbolExpando: - def __getattr__(self, n) -> IndexSymbol: - return index_symbol(n) - - -sym = _IndexSymbolExpando() - class xor(sympy.Function): pass @@ -475,49 +450,3 @@ def backed_sym_index_type(assumption: IndexRelation) -> Type[SymIndex]: class BackedSymIndex(SymIndex, assumption=assumption): ... return BackedSymIndex - - -@dataclass -class IndexSequence: - start: IndexExpr | int - size: IndexExpr | int - stride: IndexExpr | int = 1 - - @staticmethod - def _subs( - value: int | IndexExpr, - map: dict[IndexExpr, IndexExpr], - simultaneous: bool = False, - ) -> int | IndexExpr: - if isinstance(value, (sympy.Basic, IndexSequence)): - return value.subs(map, simultaneous=simultaneous) # type: ignore - return value - - def has(self, symbol: IndexSymbol) -> bool: - return ( - sympy.sympify(self.start).has(symbol) - or sympy.sympify(self.size).has(symbol) - or sympy.sympify(self.stride).has(symbol) - ) - - def subs(self, map: dict[IndexExpr, IndexExpr], simultaneous: bool = False): - start = self._subs(self.start, map, simultaneous) - size = self._subs(self.size, map, simultaneous) - stride = self._subs(self.stride, map, simultaneous) - return IndexSequence(start, size, stride) - - @staticmethod - def from_expr(expr: IndexExpr, subs: dict[IndexExpr, Any]): - start_subs = {k: v.start for k, v in subs.items()} - size_subs = {k: v.size for k, v in subs.items()} - stride_subs = {k: v.stride for k, v in subs.items()} - start = IndexSequence._subs(expr, start_subs) - size = IndexSequence._subs(expr, size_subs) - stride = IndexSequence._subs(expr, stride_subs) - return IndexSequence(start, size, stride) - - def __repr__(self) -> str: - return f"{self.start} : {self.size} : {self.stride}" - - def __hash__(self): - return hash((self.start, self.size, self.stride)) diff --git a/wave_lang/kernel/lang/global_symbols.py b/wave_lang/kernel/lang/global_symbols.py index a65ed2fd1..0da42050a 100644 --- a/wave_lang/kernel/lang/global_symbols.py +++ b/wave_lang/kernel/lang/global_symbols.py @@ -1,6 +1,13 @@ import sympy -from .._support.indexing import index_symbol +from .._support.indexing import ( + index_symbol, + MMA_ACC_SYMBOL_NAME, + THREAD_SYMBOL_NAMES, + WORKGROUP_SYMBOL_NAMES, + DEVICE_SYMBOL_NAMES, + GPR_SYMBOL_NAME, +) # Global symbols used throughout the code. @@ -10,28 +17,29 @@ # Device Distribution symbols. # TODO: Can only do three dimensions for now. -DEVICE_DIM_0 = index_symbol("$DD0") -DEVICE_DIM_1 = index_symbol("$DD1") -DEVICE_DIM_2 = index_symbol("$DD2") +DEVICE_DIM_0 = index_symbol(DEVICE_SYMBOL_NAMES[0]) +DEVICE_DIM_1 = index_symbol(DEVICE_SYMBOL_NAMES[1]) +DEVICE_DIM_2 = index_symbol(DEVICE_SYMBOL_NAMES[2]) # Distribution symbols. -WORKGROUP_0 = index_symbol("$WG0") -WORKGROUP_1 = index_symbol("$WG1") -WORKGROUP_2 = index_symbol("$WG2") +WORKGROUP_0 = index_symbol(WORKGROUP_SYMBOL_NAMES[0]) +WORKGROUP_1 = index_symbol(WORKGROUP_SYMBOL_NAMES[1]) +WORKGROUP_2 = index_symbol(WORKGROUP_SYMBOL_NAMES[2]) def get_workgroup_symbol(i: int): assert i >= 0, "Workgroup index must be non-negative." 
symbol_name = f"WORKGROUP_{i}" + symbol = index_symbol(WORKGROUP_SYMBOL_NAMES[i] if i < 3 else "$WG" + str(i)) if symbol_name not in globals(): - globals()[symbol_name] = index_symbol(f"$WG{i}") - return index_symbol(f"$WG{i}") + globals()[symbol_name] = symbol + return symbol -THREAD_0 = index_symbol("$T0") -THREAD_1 = index_symbol("$T1") -THREAD_2 = index_symbol("$T2") +THREAD_0 = index_symbol(THREAD_SYMBOL_NAMES[0]) +THREAD_1 = index_symbol(THREAD_SYMBOL_NAMES[1]) +THREAD_2 = index_symbol(THREAD_SYMBOL_NAMES[2]) # Input selector symbol for selecting input from different tensors. INPUT_SELECTOR = index_symbol("$INPUT_SELECTOR") @@ -39,11 +47,11 @@ def get_workgroup_symbol(i: int): # MMA symbols. MMA_LHS = index_symbol("$MMA_LHS") MMA_RHS = index_symbol("$MMA_RHS") -MMA_ACC = index_symbol("$MMA_ACC") +MMA_ACC = index_symbol(MMA_ACC_SYMBOL_NAME) MMA_LHS_SCALE = index_symbol("$MMA_LHS_SCALE") MMA_RHS_SCALE = index_symbol("$MMA_RHS_SCALE") MMA_SCALE_FP4 = index_symbol("$MMA_SCALE_FP4") -GPR_NUM = index_symbol("$GPR_NUM") +GPR_NUM = index_symbol(GPR_SYMBOL_NAME) # Scheduling symbols. READ_SHARED_DELAY = index_symbol("$READ_SHARED_DELAY") diff --git a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py index 6c33b7925..0fd1d9571 100644 --- a/wave_lang/kernel/wave/analysis/index_sequence_analysis.py +++ b/wave_lang/kernel/wave/analysis/index_sequence_analysis.py @@ -14,6 +14,8 @@ import wave_lang.kernel.lang as tkl from wave_lang.kernel._support.dtype import DataType +from wave_lang.kernel.wave.mlir_converter.mlir_converter import emit_wave_dialect +from wave_lang.kernel.wave.compile_options import WaveCompileOptions from wave_lang.support.logging import get_logger from ..._support.indexing import IndexSequence, IndexSymbol @@ -260,6 +262,104 @@ def verify_nodes(trace: CapturedTrace, constraints: list[Constraint]): ), f"Vector shapes not set for node {custom.fx_node}: {custom}" +def _set_water_id(trace: CapturedTrace): + """Set a unique identifier for each node in the trace.""" + for node in trace.walk(lambda x: x): + setattr(node, "_water_id", str(id(node))) + + +def _reset_water_id(trace: CapturedTrace): + """Remove the previously set unique identifier for each node in the trace.""" + for node in trace.walk(lambda x: x): + delattr(node, "_water_id") + + +def _check_index_difference_is_zero( + index1: dict[IndexSymbol, IndexSequence], index2: dict[IndexSymbol, IndexSequence] +) -> bool: + """Check if two index sequences are equal, raise assertions if not.""" + + def f(seq1: IndexSequence, seq2: IndexSequence) -> bool: + start = sympy.simplify(seq1.start - seq2.start) + size = sympy.simplify(seq1.size - seq2.size) + stride = sympy.simplify(seq1.stride - seq2.stride) + if start != 0: + raise ValueError(f"Start difference: {start}") + if size != 0: + raise ValueError(f"Size difference: {size}") + if stride != 0: + raise ValueError(f"Stride difference: {stride}") + return True + + return index1.keys() == index2.keys() and all( + f(seq, index2[dim]) for dim, seq in index1.items() + ) + + +def _check_water_indices(trace: CapturedTrace, inferred: dict[str, IndexSequence]): + """Check that the indices for each node in the trace match the water-inferred indices. + + Expects unique identifiers to be set on each node in the trace and uses + those to find the index inferred by Water. 
+ """ + for node in trace.walk(lambda x: x): + water_id = getattr(node, "_water_id") + custom = get_custom(node) + if isinstance(custom, (Placeholder, Output)): + continue + if water_id not in inferred: + raise RuntimeError( + f"Node {get_custom(node)} with id {water_id} not found in water-inferred index expressions." + ) + inferred_index = inferred[water_id].get("index", None) + if not getattr(node, "index", None): + assert isinstance( + custom, NestedRegionOp + ), "Index may only be missing for NestedRegionOps." + continue + # Skip GetResult because they are special-cased in Python propagation, + # making them have incorrect indexes in dataflow sense. + if isinstance(custom, GetResult): + continue + + # Check that that indices match, raise an error if they don't. Start by + # a trivial direct comparison, fall back to computing and simplifying + # the difference. The latter can raise with additional information, + # which this wants to preserve. + try: + if node.index != inferred_index and not _check_index_difference_is_zero( + node.index, inferred_index + ): + raise ValueError("mismatching indices") + except ValueError as e: + raise RuntimeError( + f"Index for node {get_custom(node)}, {get_custom(node).index} does not match inferred index {inferred_index}." + ) from e + + +def set_node_indices_water_checked( + trace: CapturedTrace, + constraints: list[Constraint], + options: WaveCompileOptions, + print_ir_before: Sequence[str] = [], + print_ir_after: Sequence[str] = [], +): + """Set the indices for each note in the trace and checks whether water infers the same indices. + + For now, indices inferred by water are discarded after comparison. Raises + errors if the indices do not match or if there was a problem communicating + with water. + """ + + _set_water_id(trace) + _, diagnostics, inferred_attributes = emit_wave_dialect(trace, constraints, options) + if diagnostics: + raise RuntimeError(f"Water indices check failed: {diagnostics}") + set_node_indices(trace, constraints, print_ir_before, print_ir_after) + _check_water_indices(trace, inferred_attributes) + _reset_water_id(trace) + + def set_node_indices( trace: CapturedTrace, constraints: list[Constraint], diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 584aef4dd..d2cb4cac8 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -34,6 +34,7 @@ # Passes from .analysis.index_sequence_analysis import ( set_node_indices, + set_node_indices_water_checked, set_post_expansion_indices, ) from .analysis.partition_strided_operators import ( @@ -447,41 +448,58 @@ def finalize_indices(): def substitute_vector_shapes(): launchable.hardware_constraints[0].subs_vector_shapes(idxc.subs) - return [ - partial(debug_log_hoist, trace, debug_handlers), - partial(initialize_iter_args, trace), - partial(launchable.create_induction_vars, trace), - partial(launchable.initialize_reductions, trace), - finalize_indices, - substitute_vector_shapes, - partial(add_get_results, trace), - partial(infer_types, trace, launchable.constraints), - partial(construct_index_mapping, trace, launchable.constraints), - partial( - debug_log_write_replace, - trace, - launchable.constraints, - options, - debug_arg_info, - ), - partial( - promote_placeholders, - trace, - launchable.constraints, - options.reorder_allocs, - ), - partial( - set_node_indices, - trace, - launchable.constraints, - print_ir_before, - print_ir_after, - ), - partial(reorder_workgroups, trace, launchable.reordering_constraints), - 
partial(expand_graph, trace, launchable.constraints), - partial(set_post_expansion_indices, trace, launchable.constraints), - partial(remove_chained_getresult, trace), - ] + return ( + [ + partial(debug_log_hoist, trace, debug_handlers), + partial(initialize_iter_args, trace), + partial(launchable.create_induction_vars, trace), + partial(launchable.initialize_reductions, trace), + finalize_indices, + substitute_vector_shapes, + partial(add_get_results, trace), + partial(infer_types, trace, launchable.constraints), + partial(construct_index_mapping, trace, launchable.constraints), + partial( + debug_log_write_replace, + trace, + launchable.constraints, + options, + debug_arg_info, + ), + partial( + promote_placeholders, + trace, + launchable.constraints, + options.reorder_allocs, + ), + ] + + ( + [ + partial( + set_node_indices_water_checked, + trace, + launchable.constraints, + options, + ) + ] + if options.check_water_analysis + else [ + partial( + set_node_indices, + trace, + launchable.constraints, + print_ir_before, + print_ir_after, + ) + ] + ) + + [ + partial(reorder_workgroups, trace, launchable.reordering_constraints), + partial(expand_graph, trace, launchable.constraints), + partial(set_post_expansion_indices, trace, launchable.constraints), + partial(remove_chained_getresult, trace), + ] + ) def _rewrite_module_for_iree_stream_abi( diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index f0f1e736c..1958cea66 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -75,6 +75,7 @@ class WaveCompileOptions: ) use_local_scope: bool = False use_water_leak_check: bool | str = False # If string, check the given IR instead. + check_water_analysis: bool = False enforce_locations: bool = True drop_debug_info_before_mlir: bool = True @@ -115,6 +116,8 @@ class WaveCompileOptions: print_signature: bool = False print_mlir: bool = False print_mlir_file: Optional[str] = None + print_mlir_before_water: bool = False + print_mlir_after_water: bool = False print_pass_times: bool = False # === ASM backend options === diff --git a/wave_lang/kernel/wave/mlir_converter/mlir_converter.py b/wave_lang/kernel/wave/mlir_converter/mlir_converter.py index 71b67890d..358b9b49f 100644 --- a/wave_lang/kernel/wave/mlir_converter/mlir_converter.py +++ b/wave_lang/kernel/wave/mlir_converter/mlir_converter.py @@ -1,4 +1,4 @@ -# Copyright 2025 The IREE Authors +# Copyright 2025 The Wave Authors # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. @@ -20,6 +20,7 @@ import subprocess import sys from pathlib import Path +from typing import Any import dill from wave_lang.kernel._support.tracing import CapturedTrace from wave_lang.kernel.wave.compile_options import WaveCompileOptions @@ -30,10 +31,9 @@ def emit_wave_dialect( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions, - *, test_diagnostic_emission: bool = False, pipeline: str = "", -) -> tuple[str, list[str]]: +) -> tuple[str, list[str], dict[str, dict[str, Any]]]: """Emit Wave MLIR by sending the pickled trace and options to the emitter. The `subs` field of options is the only option used during emission. 
If @@ -65,6 +65,20 @@ def emit_wave_dialect( stderr=subprocess.PIPE, ) + assert ( + not options.check_water_analysis or not pipeline + ), "Cannot check water analysis and use a pipeline" + if options.check_water_analysis: + pipeline = """ +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + %0 = transform.apply_registered_pass "water-wave-detect-normal-forms" to %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.apply_registered_pass "water-wave-propagate-defaults-from-constraints" to %0 : (!transform.any_op) -> !transform.any_op + transform.apply_registered_pass "water-wave-infer-index-exprs" to %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +}""" + output, err = proc.communicate( dill.dumps( { @@ -93,9 +107,16 @@ def emit_wave_dialect( ) from e diagnostics = unpickled.get("diagnostics") if isinstance(unpickled, dict) else None module = unpickled.get("module") if isinstance(unpickled, dict) else None + inferred_attributes = ( + unpickled.get("inferred_attributes") if isinstance(unpickled, dict) else None + ) # Preserve stderr messages. if err: print(err.decode("utf-8", errors="replace"), file=sys.stderr) - return module.decode("utf-8"), [d.decode("utf-8") for d in diagnostics] + return ( + module.decode("utf-8"), + [d.decode("utf-8") for d in diagnostics], + inferred_attributes, + ) diff --git a/wave_lang/kernel/wave/mlir_converter/mlir_to_wave.py b/wave_lang/kernel/wave/mlir_converter/mlir_to_wave.py new file mode 100644 index 000000000..6ea312e31 --- /dev/null +++ b/wave_lang/kernel/wave/mlir_converter/mlir_to_wave.py @@ -0,0 +1,220 @@ +# Copyright 2025 The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import sys + +if "iree" in sys.modules: + raise ImportError( + "Must not import this module when IREE is loaded. This leads to clashes between copies of MLIR bindings." + ) + +import sympy # type: ignore +from typing import Sequence +from water_mlir.water_mlir import ir +from water_mlir.water_mlir.dialects import wave + +# This is fine since it doesn't depend on IREE transitively. +from wave_lang.support.indexing import ( + MMA_ACC_SYMBOL_NAME, + IndexSequence, + index_symbol, + IndexSymbol, +) + +assert ( + "iree" not in sys.modules +), "IREE was loaded transitively by modules. This should not happen." + + +ITER_SYMBOL_NAME_WAVE_PREFIX = "$ARG" +ITER_SYMBOL_NAME_WATER_PREFIX = "_Iter_" + +# Mapping of special symbol names to WaveIndexSymbol enum values. +INDEX_SYMBOL_MAP: dict[str, wave.WaveIndexSymbol] = { + "$WG0": wave.WaveIndexSymbol.WORKGROUP_0, + "$WG1": wave.WaveIndexSymbol.WORKGROUP_1, + "$WG2": wave.WaveIndexSymbol.WORKGROUP_2, + "$T0": wave.WaveIndexSymbol.THREAD_0, + "$T1": wave.WaveIndexSymbol.THREAD_1, + "$T2": wave.WaveIndexSymbol.THREAD_2, + "$DD0": wave.WaveIndexSymbol.DEVICE_DIM_0, + "$DD1": wave.WaveIndexSymbol.DEVICE_DIM_1, + "$DD2": wave.WaveIndexSymbol.DEVICE_DIM_2, + "$GPR_NUM": wave.WaveIndexSymbol.GPR_NUMBER, +} +INDEX_SYMBOL_REVERSE_MAP: dict[wave.WaveIndexSymbol, str] = { + value: key for key, value in INDEX_SYMBOL_MAP.items() +} + + +def _convert_affine_expr_to_sympy_expr( + expr: ir.AffineExpr, + symbol_mapping: Sequence[sympy.Symbol], +) -> sympy.Expr: + """Convert an MLIR AffineExpr to a sympy expression. + + Args: + expr: The MLIR AffineExpr to convert. 
+ symbol_mapping: A list of sympy symbols co-indexed with the positional + affine symbols in the MLIR AffineExpr. + + Returns: + The sympy expression corresponding to the MLIR AffineExpr. + + Raises: + ValueError: If the expression is not supported. + """ + if ir.AffineConstantExpr.isinstance(expr): + return sympy.Integer(ir.AffineConstantExpr(expr).value) + if ir.AffineSymbolExpr.isinstance(expr): + return symbol_mapping[ir.AffineSymbolExpr(expr).position] + if ir.AffineAddExpr.isinstance(expr): + add_expr = ir.AffineAddExpr(expr) + return _convert_affine_expr_to_sympy_expr( + add_expr.lhs, symbol_mapping + ) + _convert_affine_expr_to_sympy_expr(add_expr.rhs, symbol_mapping) + if ir.AffineMulExpr.isinstance(expr): + mul_expr = ir.AffineMulExpr(expr) + return _convert_affine_expr_to_sympy_expr( + mul_expr.lhs, symbol_mapping + ) * _convert_affine_expr_to_sympy_expr(mul_expr.rhs, symbol_mapping) + if ir.AffineFloorDivExpr.isinstance(expr): + floor_div_expr = ir.AffineFloorDivExpr(expr) + return sympy.floor( + _convert_affine_expr_to_sympy_expr(floor_div_expr.lhs, symbol_mapping) + / _convert_affine_expr_to_sympy_expr(floor_div_expr.rhs, symbol_mapping) + ) + if ir.AffineCeilDivExpr.isinstance(expr): + ceil_div_expr = ir.AffineCeilDivExpr(expr) + return sympy.ceiling( + _convert_affine_expr_to_sympy_expr(ceil_div_expr.lhs, symbol_mapping) + / _convert_affine_expr_to_sympy_expr(ceil_div_expr.rhs, symbol_mapping) + ) + if ir.AffineModExpr.isinstance(expr): + mod_expr = ir.AffineModExpr(expr) + return _convert_affine_expr_to_sympy_expr( + mod_expr.lhs, symbol_mapping + ) % _convert_affine_expr_to_sympy_expr(mod_expr.rhs, symbol_mapping) + raise ValueError(f"Unsupported affine expression: {expr} of type {type(expr)}") + + +def _convert_index_mapping_attr_to_sympy( + attr: wave.WaveIndexMappingAttr, +) -> IndexSequence: + """Convert a WaveIndexMappingAttr to a Wave IndexSequence. + + Args: + attr: The WaveIndexMappingAttr to convert. + + Returns: + The Wave IndexSequence corresponding to the WaveIndexMappingAttr. + + Raises: + ValueError: If any subexpression in the mapping is not supported. 
+ """ + + def wrap_symbol(symbol_name: ir.Attribute) -> sympy.Symbol: + if isinstance(symbol_name, wave.WaveSymbolAttr): + return index_symbol(symbol_name.name) + elif isinstance(symbol_name, wave.WaveIterSymbolAttr): + return index_symbol(ITER_SYMBOL_NAME_WAVE_PREFIX + symbol_name.name) + elif isinstance(symbol_name, wave.WaveIndexSymbolAttr): + index_symbol_var = INDEX_SYMBOL_REVERSE_MAP.get(symbol_name.value, None) + if index_symbol_var is None: + raise ValueError(f"Unsupported index symbol: {symbol_name.value}") + return index_symbol(index_symbol_var) + else: + raise ValueError(f"Unsupported symbol attribute: {symbol_name}") + + symbols = list(map(wrap_symbol, attr.symbols)) + assert ( + len(attr.start.results) == 1 + ), f"Expected start map to have one expression, got {attr.start}" + assert ( + len(attr.step.results) == 1 + ), f"Expected step map to have one expression, got {attr.step}" + assert ( + len(attr.stride.results) == 1 + ), f"Expected stride map to have one expression, got {attr.stride}" + start = _convert_affine_expr_to_sympy_expr(attr.start.results[0], symbols) + step = _convert_affine_expr_to_sympy_expr(attr.step.results[0], symbols) + stride = _convert_affine_expr_to_sympy_expr(attr.stride.results[0], symbols) + return IndexSequence(start, step, stride) + + +def _convert_index_mapping_dict_to_sympy( + dict_attr: ir.DictAttr, +) -> dict[IndexSymbol, IndexSequence]: + """Convert a dictionary attribute containing WaveIndexMappingAttr to a dictionary of Wave IndexSequences.""" + result = {} + for named_attr in dict_attr: + key = named_attr.name + value = named_attr.attr + assert isinstance( + value, wave.WaveIndexMappingAttr + ), f"Unsupported index mapping attribute: {value}" + result[index_symbol(key)] = _convert_index_mapping_attr_to_sympy(value) + return result + + +def _make_piecewise_sequence( + *components: tuple[IndexSequence, sympy.Expr] +) -> IndexSequence: + """Create a Piecewise IndexSequence from a list of components. + + Args: + *components: A list of tuples: (subexpression, condition). + + Returns: + The Piecewise IndexSequence corresponding to the list of components. + """ + return IndexSequence( + start=sympy.Piecewise( + *[(component[0].start, component[1]) for component in components] + ), + size=sympy.Piecewise( + *[(component[0].size, component[1]) for component in components] + ), + stride=sympy.Piecewise( + *[(component[0].stride, component[1]) for component in components] + ), + ) + + +def convert_index_mapping_array_to_sympy( + op: ir.Operation, array_attr: ir.ArrayAttr +) -> dict[IndexSymbol, IndexSequence]: + # TODO: for some reason, isinstance(op.opview, MmaOp) is not working. Something is off with dialect loading/registration. 
+ if op.name != "wave.mma": + assert ( + len(array_attr) == 1 + ), f"Expected exactly one index mapping attribute for non-MMA op: {op}" + return _convert_index_mapping_dict_to_sympy(array_attr[0]) + + assert ( + len(array_attr) == 4 + ), f"Expected exactly four index mapping attributes for MMA op: {op}" + lhs_index = _convert_index_mapping_dict_to_sympy(array_attr[0]) + rhs_index = _convert_index_mapping_dict_to_sympy(array_attr[1]) + acc_index = _convert_index_mapping_dict_to_sympy(array_attr[2]) + result_index = _convert_index_mapping_dict_to_sympy(array_attr[3]) + mk_symbols = set(lhs_index.keys()) + nk_symbols = set(rhs_index.keys()) + m_symbol = (mk_symbols - nk_symbols).pop() + n_symbol = (nk_symbols - mk_symbols).pop() + k_symbol = (mk_symbols.intersection(nk_symbols)).pop() + assert lhs_index[k_symbol] == rhs_index[k_symbol] + assert rhs_index[n_symbol] == acc_index[n_symbol] + assert acc_index[m_symbol] == result_index[m_symbol] + assert acc_index[n_symbol] == result_index[n_symbol] + return { + m_symbol: _make_piecewise_sequence( + (lhs_index[m_symbol], ~index_symbol(MMA_ACC_SYMBOL_NAME)), + (acc_index[m_symbol], index_symbol(MMA_ACC_SYMBOL_NAME)), + ), + n_symbol: rhs_index[n_symbol], + k_symbol: lhs_index[k_symbol], + } diff --git a/wave_lang/kernel/wave/mlir_converter/water_emitter.py b/wave_lang/kernel/wave/mlir_converter/water_emitter.py index e5433aa0c..573027c29 100644 --- a/wave_lang/kernel/wave/mlir_converter/water_emitter.py +++ b/wave_lang/kernel/wave/mlir_converter/water_emitter.py @@ -22,14 +22,21 @@ _parent_dir = str(_current_file.parent.parent) # Go up to wave_lang/kernel/wave/ if _parent_dir not in sys.path: sys.path.append(_parent_dir) + # Add current directory to enable importing mlir_to_wave without full package path + _current_dir = str(_current_file.parent) + if _current_dir not in sys.path: + sys.path.append(_current_dir) + +from mlir_to_wave import ( + INDEX_SYMBOL_MAP, + ITER_SYMBOL_NAME_WATER_PREFIX, + convert_index_mapping_array_to_sympy, + ITER_SYMBOL_NAME_WAVE_PREFIX, +) if TYPE_CHECKING: - from wave_lang.kernel._support.tracing import CapturedTrace - from wave_lang.kernel.wave.compile_options import WaveCompileOptions - from wave_lang.kernel.wave.constraints import Constraint - from wave_lang.kernel.lang.wave_types import Memory, Register, IndexSymbol - from wave_lang.kernel._support.indexing import IndexSequence + from wave_lang.kernel._support.indexing import IndexSequence, IndexSymbol from wave_lang.kernel._support import dtype from wave_lang.kernel.ops.wave_ops import * @@ -79,24 +86,24 @@ AddOp, AllocateOp, DivOp, - ExtractSliceOp, Exp2Op, + ExtractSliceOp, + IterateOp, MmaOp, MulOp, ReadOp, RegisterOp, WriteOp, - IterateOp, YieldOp, - WaveExprListAttr, + DeviceConstraintAttr, HardwareConstraintAttr, - WorkgroupConstraintAttr, - WaveConstraintAttr, TilingConstraintAttr, - DeviceConstraintAttr, + WaveConstraintAttr, + WaveExprListAttr, WaveMmaKind, WaveMmaKindAttr, WaveWorkgroupDimAttr, + WorkgroupConstraintAttr, ) from water_mlir.water_mlir.sympy_to_affine_converter import ( convert_sympy_to_affine_map, @@ -274,14 +281,20 @@ def _preprocess_symbols( ) -> dict[sympy.Symbol, sympy.Symbol]: """ Preprocess symbols by: - (1) adding assumptions about all symbols being positive to later enable more simplifications. - (2) replacing `$ARG` prefix of argument symbols (e.g. `ARG0`) by `_Iter_` to match dialect expectations. + + 1. adding assumptions about all symbols being positive to later enable + more simplifications. + 2. 
replacing ITER_SYMBOL_NAME_WAVE_PREFIX (`$ARG`) prefix of argument + symbols (e.g. `ARG0`) by ITER_SYMBOL_NAME_WATER_PREFIX (`_Iter_`) to + match dialect expectations. """ result = {} for sym in symbols: - # Special case: rename $ARG* symbols to _Iter_* - if sym.name.startswith("$ARG"): - new_name = sym.name.replace("$ARG", "_Iter_") + # Special case: rename $ARG* symbols to _Iter_*. + if sym.name.startswith(ITER_SYMBOL_NAME_WAVE_PREFIX): + new_name = sym.name.replace( + ITER_SYMBOL_NAME_WAVE_PREFIX, ITER_SYMBOL_NAME_WATER_PREFIX + ) result[sym] = sympy.Symbol(new_name, positive=True) else: result[sym] = sympy.Symbol(sym.name, positive=True) @@ -295,24 +308,13 @@ def _symbol_name_to_attribute(name: str) -> ir.Attribute: Special symbols starting with $ are converted to WaveIndexSymbolAttr, while regular symbols are converted to WaveSymbolAttr. """ - # Mapping of special symbol names to WaveIndexSymbol enum values - INDEX_SYMBOL_MAP = { - "$WG0": wave.WaveIndexSymbol.WORKGROUP_0, - "$WG1": wave.WaveIndexSymbol.WORKGROUP_1, - "$WG2": wave.WaveIndexSymbol.WORKGROUP_2, - "$T0": wave.WaveIndexSymbol.THREAD_0, - "$T1": wave.WaveIndexSymbol.THREAD_1, - "$T2": wave.WaveIndexSymbol.THREAD_2, - "$DD0": wave.WaveIndexSymbol.DEVICE_DIM_0, - "$DD1": wave.WaveIndexSymbol.DEVICE_DIM_1, - "$DD2": wave.WaveIndexSymbol.DEVICE_DIM_2, - "$GPR_NUM": wave.WaveIndexSymbol.GPR_NUMBER, - } if name in INDEX_SYMBOL_MAP: return wave.WaveIndexSymbolAttr.get(INDEX_SYMBOL_MAP[name]) - elif name.startswith("_Iter_"): - return wave.WaveIterSymbolAttr.get(name.replace("_Iter_", "")) + if name.startswith(ITER_SYMBOL_NAME_WATER_PREFIX): + return wave.WaveIterSymbolAttr.get( + name.replace(ITER_SYMBOL_NAME_WATER_PREFIX, "") + ) else: return wave.WaveSymbolAttr.get(name) @@ -344,7 +346,7 @@ def _build_index_mapping_dict( induction_symbols_to_remove = { symbol for symbol in all_symbols_set - if symbol.name.startswith("$ARG") + if symbol.name.startswith(ITER_SYMBOL_NAME_WAVE_PREFIX) and symbol not in allowed_induction_symbols } if induction_symbols_to_remove: @@ -374,7 +376,9 @@ def _build_index_mapping_dict( return ir.DictAttr.get(index_mappings) -def _attach_attributes(node: CustomOp, op: ir.Operation): +def _attach_attributes( + node: CustomOp, op: ir.Operation, known_ids: set[str] | None = None +): if getattr(node, "index", None) and isinstance(node.index, dict): dict_attrs: list[ir.DictAttr] = [] @@ -428,6 +432,13 @@ def _attach_attributes(node: CustomOp, op: ir.Operation): bounds[dim.name] = wave.WaveExprListAttr.get(symbol_attrs, result) op.attributes["bounds"] = wave.WaveReadWriteBoundsAttr.get(bounds) + if water_id := getattr(node.fx_node, "_water_id", None): + op.attributes[_INTERNAL_WATER_ID_ATTR_NAME] = ir.StringAttr.get(water_id) + if known_ids is not None: + known_ids.add(water_id) + elif known_ids is not None: + raise RuntimeError(f"Water id requested but not specified for node {node}.") + def _convert_to_wave_expr_list_tuple( exprs: Sequence[sympy.Expr | int], @@ -468,6 +479,7 @@ def _emit_ops_from_graph( trace: CapturedTrace, value_map: dict[fx.Node | fx.Proxy, ir.Value], ctx: ir.Context, + known_ids: set[str] | None = None, ): # Emit in original order to preserve dependencies for fx_node in graph.nodes: @@ -495,6 +507,29 @@ def _emit_ops_from_graph( f"GetResult index is higher than number of results of corresponding iterate node ({node.res_idx} vs {len(iterate_op.results)})" ) value_map[fx_node] = iterate_op.results[node.res_idx] + + # Attach IDs of `get_result` to the loop instead so we can recover them + # 
later because `get_result` doesn't exist in the dialect. + if known_ids is not None: + water_id = getattr(fx_node, "_water_id", None) + if water_id is None: + raise RuntimeError( + f"Water id requested for 'get_result' but not specified: {node}" + ) + known_ids.add(water_id) + current_attribute = ( + iterate_op.attributes[_INTERNAL_RESULT_WATER_IRS_ATTR_NAME] + if _INTERNAL_RESULT_WATER_IRS_ATTR_NAME in iterate_op.attributes + else ir.ArrayAttr.get( + [ir.UnitAttr.get()] * len(iterate_op.results) + ) + ) + attribute_list = list(current_attribute) + attribute_list[node.res_idx] = ir.StringAttr.get(water_id) + iterate_op.attributes[_INTERNAL_RESULT_WATER_IRS_ATTR_NAME] = ( + ir.ArrayAttr.get(attribute_list) + ) + # additional handling for this op is not needed, skip rest continue if isinstance(node, SharedMemoryBarrier): @@ -534,6 +569,8 @@ def _emit_ops_from_graph( result_types = [] result_locs = [] outputs = node.outputs() + if not isinstance(outputs, Sequence): + outputs = [outputs] for fx_output in outputs: output = get_custom(fx_output) output.infer_type() @@ -563,6 +600,7 @@ def _emit_ops_from_graph( trace, value_map, ctx, + known_ids, ) # create YieldOp @@ -603,7 +641,7 @@ def _emit_ops_from_graph( f"Missing support for '{node.tkw_op_name}' operation" ) - _attach_attributes(node, mlir_op.operation) + _attach_attributes(node, mlir_op.operation, known_ids) # Add results to the value map in case they are used as # operands later @@ -667,11 +705,18 @@ def _emit_wave_constraints(constraint: Constraint) -> ir.Attribute: raise NotImplementedError(f"Unsupported constraint type: {type(constraint)}") -def _flush_output(module_str: str, diagnostics: list[str]) -> None: +def _flush_output( + module_str: str, + diagnostics: list[str], + inferred_attributes: dict[str, dict[str, Any]] | None = None, +) -> None: output = dill.dumps( { "diagnostics": [d.encode("utf-8") for d in diagnostics], "module": module_str.encode("utf-8"), + "inferred_attributes": ( + inferred_attributes if inferred_attributes is not None else {} + ), } ) sys.stdout.buffer.write(output) @@ -684,7 +729,7 @@ def _create_kernel_module( constraints: list[Constraint], options: WaveCompileOptions, test_diagnostics: bool = False, -) -> tuple[ir.Module | None, list[str]]: +) -> tuple[ir.Module | None, list[str], set[str]]: """Creates an MLIR module containing the kernel function from the captured trace. Args: @@ -697,8 +742,10 @@ def _create_kernel_module( Returns: - The created MLIR module, or None if creation failed. - List of diagnostic messages. + - Set of known water IDs if options require checking water analysis. """ diagnostics: list[str] = [] + known_ids: set[str] | None = set() if options.check_water_analysis else None def diagnostics_handler(d): diagnostics.append(f"{d.location}: {d.message}") @@ -711,9 +758,9 @@ def diagnostics_handler(d): module = ir.Module.parse(options.override_mlir, context=ctx) except ir.MLIRError as e: diagnostics.append(str(e)) - return None, diagnostics + return None, diagnostics, known_ids else: - return module, diagnostics + return module, diagnostics, known_ids # Keep track of which emitted value stems from what node to wire # arguments correctly. @@ -743,7 +790,6 @@ def diagnostics_handler(d): # should be global by now (shared memory allocation happens inside the kernel). # Thus, resolve symbolic address spaces from hyperparameters. 
- # print(t, t.address_space) if issubclass(t, Memory) and t.address_space in options.subs: # Create a new type with resolved address space resolved_address_space = options.subs[t.address_space] @@ -807,17 +853,23 @@ def diagnostics_handler(d): ] with ir.InsertionPoint(entry_block): - _emit_ops_from_graph(trace.get_root_graph(), trace, value_map, ctx) + _emit_ops_from_graph( + trace.get_root_graph(), trace, value_map, ctx, known_ids + ) func.ReturnOp(operands_=[]) - return module, diagnostics + return module, diagnostics, known_ids + + +_INTERNAL_WATER_ID_ATTR_NAME = "_water_internal.id" +_INTERNAL_RESULT_WATER_IRS_ATTR_NAME = "_water_internal.result_ids" def _emit_from_captured_trace( trace: CapturedTrace, constraints: list[Constraint], options: WaveCompileOptions, - pipeline: str, + pipeline: str = "", test_diagnostics=False, ) -> int: @@ -830,18 +882,20 @@ def _emit_from_captured_trace( if enable_debug_info and not trace.location: diagnostics.append("Missing debug location for wave trace") - with ir.Context() as ctx, ( - trace.location.to_water() if trace.location else ir.Location.unknown() + with ( + ir.Context() as ctx, + trace.location.to_water() if trace.location else ir.Location.unknown(), ): ctx.allow_unregistered_dialects = False wave.register_dialect(ctx) + wave.register_passes() - module, creation_diagnostics = _create_kernel_module( + module, creation_diagnostics, known_ids = _create_kernel_module( ctx, trace, constraints, options, test_diagnostics ) diagnostics.extend(creation_diagnostics) if module is None: - _flush_output("", diagnostics) + _flush_output("", diagnostics, None) return 0 # Verify the module before transforming or printing. @@ -856,9 +910,13 @@ def _emit_from_captured_trace( enable_debug_info=enable_debug_info, print_generic_op_form=True ), diagnostics, + None, ) return 0 + if options.print_mlir_before_water: + print(module.operation.get_asm(), file=sys.stderr) + # If a transform script was provided, parse and apply it to the module. # This expects a transform module with a named sequence as first operation. if pipeline: @@ -882,7 +940,70 @@ def _emit_from_captured_trace( diagnostics.append(f"Failed to apply transform script: {e}") module_str = module.operation.get_asm(enable_debug_info=enable_debug_info) - _flush_output(module_str, diagnostics) + if options.print_mlir_after_water: + print(module_str, file=sys.stderr) + + # Collect attributes inferred by the pass and store them in the per-id dictionary. + inferred_attributes: dict[str, dict[str, Any]] = ( + {id: {} for id in known_ids} if known_ids else {} + ) + if options.check_water_analysis: + + def extractor(op: ir.Operation) -> ir.WalkResult: + attribute: ir.Attribute | None = ( + op.attributes[_INTERNAL_WATER_ID_ATTR_NAME] + if _INTERNAL_WATER_ID_ATTR_NAME in op.attributes + else None + ) + result_attribute: ir.Attribute | None = ( + op.attributes[_INTERNAL_RESULT_WATER_IRS_ATTR_NAME] + if _INTERNAL_RESULT_WATER_IRS_ATTR_NAME in op.attributes + else None + ) + if attribute is None and result_attribute is None: + return ir.WalkResult.ADVANCE + + def record_index( + attribute: ir.Attribute, + inferred_attributes: dict[str, dict[str, Any]], + ): + assert isinstance( + attribute, ir.StringAttr + ), f"Unexpected attribute type: {attribute}." + assert ( + attribute.value in inferred_attributes + ), f"Unknown water id {attribute.value}." + assert ( + "index" not in inferred_attributes[attribute.value] + ), f"Index already set for water id {attribute.value}." 
+ assert "index" in op.attributes, f"Index not inferred for {op}." + + inferred_attributes[attribute.value].update( + { + "index": convert_index_mapping_array_to_sympy( + op, op.attributes["index"] + ) + } + ) + + if attribute is not None: + record_index(attribute, inferred_attributes) + if result_attribute is not None: + assert isinstance( + result_attribute, ir.ArrayAttr + ), f"Unexpected attribute type: {result_attribute}." + for attribute in result_attribute: + record_index(attribute, inferred_attributes) + + return ir.WalkResult.ADVANCE + + module.operation.walk(extractor) + for water_id, inferred_attribute in inferred_attributes.items(): + if "index" not in inferred_attribute: + raise RuntimeError(f"Index not inferred for water id {water_id}.") + + module_str = module.operation.get_asm(enable_debug_info=enable_debug_info) + _flush_output(module_str, diagnostics, inferred_attributes) return 0 @@ -897,9 +1018,9 @@ def _emit_from_captured_trace( args = parser.parse_args() - trace, constraints, options, pipeline = _parse_input() + trace, constraints, options, pass_pipeline = _parse_input() sys.exit( _emit_from_captured_trace( - trace, constraints, options, pipeline, args.test_diagnostic_emission + trace, constraints, options, pass_pipeline, args.test_diagnostic_emission ) ) diff --git a/wave_lang/kernel/wave/templates/gemm.py b/wave_lang/kernel/wave/templates/gemm.py index 48fd833ba..49237141d 100644 --- a/wave_lang/kernel/wave/templates/gemm.py +++ b/wave_lang/kernel/wave/templates/gemm.py @@ -56,8 +56,8 @@ def get_gemm_kernel( constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] constraints += [tkw.TilingConstraint(K, BLOCK_K)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / waves_per_block[0])] - constraints += [tkw.WaveConstraint(N, BLOCK_N / waves_per_block[1])] + constraints += [tkw.WaveConstraint(M, sympy.floor(BLOCK_M / waves_per_block[0]))] + constraints += [tkw.WaveConstraint(N, sympy.floor(BLOCK_N / waves_per_block[1]))] constraints += [ tkw.HardwareConstraint( diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index ac518318e..6de3e2413 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -13,9 +13,9 @@ import sys import math from typing import Any, Sequence -from functools import lru_cache from wave_lang.kernel.wave.compile_options import WaveCompileOptions +from wave_lang.support.detect_water import get_water_mlir_pkg_path, get_water_opt from wave_lang.support.ir_imports import ( Attribute, BlockArgument, @@ -174,39 +174,6 @@ def replace_ops_and_collect_subspans(op: Operation) -> WalkResult: return local_module.get_asm(binary=False, print_generic_op_form=True) -def get_water_mlir_dir() -> Path: - return Path(__file__).parent / "water_mlir" - - -def find_binary(name: str) -> str | None: - tool_path = get_water_mlir_dir() / "bin" / name - if not tool_path.is_file() or not os.access(tool_path, os.X_OK): - return None - - return str(tool_path) - - -@lru_cache -def is_water_available() -> bool: - """Returns True if the water_mlir package is available.""" - return (get_water_mlir_dir()).exists() - - -@lru_cache -def is_water_binary_available() -> bool: - """Returns True if the water-opt binary is available and executable.""" - return find_binary("water-opt") is not None - - -@lru_cache -def get_water_opt() -> str: - path = find_binary("water-opt") - if path is None: - raise RuntimeError("water-opt binary not found") - - return path - - def 
make_linear_pass_pipeline( pipeline: Sequence[ tuple[str, dict[str, Any]] | tuple[str, dict[str, Any], str] | str @@ -453,7 +420,7 @@ def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any] llvm_opt_level = 3 if options.optimization_level else 0 dump_intermediates = options.dump_intermediates or "" - toolkit_path = get_water_mlir_dir() + toolkit_path = get_water_mlir_pkg_path() pipeline = [ "lower-affine", diff --git a/wave_lang/support/detect_water.py b/wave_lang/support/detect_water.py new file mode 100644 index 000000000..9838ea231 --- /dev/null +++ b/wave_lang/support/detect_water.py @@ -0,0 +1,50 @@ +# Copyright 2025, The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from pathlib import Path +import os +from functools import lru_cache + + +@lru_cache +def get_water_mlir_pkg_path() -> Path: + """Returns the path to the water_mlir package.""" + # Assumes we are located at wave_lang/support/detect_water.py + assert Path(__file__).parent.name == "support" + assert Path(__file__).parent.parent.name == "wave_lang" + wave_lang_path = Path(__file__).parent.parent + return wave_lang_path / "kernel" / "wave" / "water_mlir" + + +def find_binary(name: str) -> str | None: + """Returns the path to the water binary with the given name.""" + tool_path = get_water_mlir_pkg_path() / "bin" / name + if not tool_path.is_file() or not os.access(tool_path, os.X_OK): + return None + + return str(tool_path) + + +@lru_cache +def is_water_available() -> bool: + """Returns True if the water_mlir package is available.""" + return (get_water_mlir_pkg_path() / "water_mlir").exists() + + +@lru_cache +def is_water_binary_available() -> bool: + """Returns True if the water-opt binary is available and executable.""" + return find_binary("water-opt") is not None + + +@lru_cache +def get_water_opt() -> str: + """Returns the path to the water-opt binary.""" + path = find_binary("water-opt") + if path is None: + raise RuntimeError("water-opt binary not found") + + return path diff --git a/wave_lang/support/indexing.py b/wave_lang/support/indexing.py new file mode 100644 index 000000000..b16d97602 --- /dev/null +++ b/wave_lang/support/indexing.py @@ -0,0 +1,102 @@ +# Copyright 2025, The Wave Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from dataclasses import dataclass +from typing import Any, TypeAlias + +import sympy # type: ignore + +__all__ = [ + "sym", + "IndexExpr", + "IndexSequence", + "IndexSymbol", + "index_symbol", + "index_expr", + "MMA_ACC_SYMBOL_NAME", + "THREAD_SYMBOL_NAMES", + "WORKGROUP_SYMBOL_NAMES", + "DEVICE_SYMBOL_NAMES", + "GPR_SYMBOL_NAME", +] + +MMA_ACC_SYMBOL_NAME = "$MMA_ACC" +THREAD_SYMBOL_NAMES = ("$T0", "$T1", "$T2") +WORKGROUP_SYMBOL_NAMES = ("$WG0", "$WG1", "$WG2") +DEVICE_SYMBOL_NAMES = ("$DD0", "$DD1", "$DD2") +GPR_SYMBOL_NAME = "$GPR_NUM" + +############################################################################### +# Index symbols and expressions +# These are just light-weight helpers around sympy symbols and expressions. 
+############################################################################### + +IndexSymbol: TypeAlias = sympy.Symbol +IndexExpr: TypeAlias = sympy.Expr + + +def index_symbol(name: str) -> IndexSymbol: + """Returns a named symbol, assumed to be a non-negative integer.""" + return sympy.Symbol(name, integer=True, nonnegative=True) + + +def index_expr(value: Any) -> IndexExpr: + expr = sympy.sympify(value) + return expr + + +class _IndexSymbolExpando: + def __getattr__(self, n) -> IndexSymbol: + return index_symbol(n) + + +sym = _IndexSymbolExpando() + + +@dataclass +class IndexSequence: + start: IndexExpr | int + size: IndexExpr | int + stride: IndexExpr | int = 1 + + @staticmethod + def _subs( + value: int | IndexExpr, + map: dict[IndexExpr, IndexExpr], + simultaneous: bool = False, + ) -> int | IndexExpr: + if isinstance(value, (sympy.Basic, IndexSequence)): + return value.subs(map, simultaneous=simultaneous) # type: ignore + return value + + def has(self, symbol: IndexSymbol) -> bool: + return ( + sympy.sympify(self.start).has(symbol) + or sympy.sympify(self.size).has(symbol) + or sympy.sympify(self.stride).has(symbol) + ) + + def subs(self, map: dict[IndexExpr, IndexExpr], simultaneous: bool = False): + start = self._subs(self.start, map, simultaneous) + size = self._subs(self.size, map, simultaneous) + stride = self._subs(self.stride, map, simultaneous) + return IndexSequence(start, size, stride) + + @staticmethod + def from_expr(expr: IndexExpr, subs: dict[IndexExpr, Any]): + start_subs = {k: v.start for k, v in subs.items()} + size_subs = {k: v.size for k, v in subs.items()} + stride_subs = {k: v.stride for k, v in subs.items()} + start = IndexSequence._subs(expr, start_subs) + size = IndexSequence._subs(expr, size_subs) + stride = IndexSequence._subs(expr, stride_subs) + return IndexSequence(start, size, stride) + + def __repr__(self) -> str: + return f"{self.start} : {self.size} : {self.stride}" + + def __hash__(self): + return hash((self.start, self.size, self.stride))
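
Note (illustrative, not part of the patch): the IndexSequence helpers above keep a per-dimension (start, size, stride) triple of sympy expressions. A minimal usage sketch of the new wave_lang.support.indexing module, with made-up symbol names:

from wave_lang.support.indexing import IndexSequence, index_expr, sym

# start/size/stride can be plain ints or sympy expressions over index symbols.
seq = IndexSequence(start=(sym.LANE % 64) * 4, size=4, stride=1)

# subs() substitutes into all three fields at once.
concrete = seq.subs({sym.LANE: index_expr(7)})
assert repr(concrete) == "28 : 4 : 1"

# from_expr() applies one expression field-wise to the sequences bound to its symbols.
combined = IndexSequence.from_expr(
    sym.A + sym.B, {sym.A: seq, sym.B: IndexSequence(1, 1, 1)}
)
assert combined.size == 5 and combined.stride == 2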
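
Note (standalone sketch of the comparison idea, not the checker in index_sequence_analysis.py): the water index check treats two sequences as matching when the symbolic difference of each field simplifies to zero. Roughly, assuming only sympy and the new indexing module:

import sympy

from wave_lang.support.indexing import IndexSequence, index_symbol

T0 = index_symbol("$T0")  # thread-id symbol, as used elsewhere in this patch

a = IndexSequence(4 * (T0 + 2), 4, 1)
b = IndexSequence(4 * T0 + 8, 4, 1)  # same start, spelled differently
c = IndexSequence(4 * T0, 4, 1)  # genuinely different start


def sequences_match(lhs: IndexSequence, rhs: IndexSequence) -> bool:
    # Fields may be ints or sympy expressions, so sympify before subtracting.
    pairs = (
        (lhs.start, rhs.start),
        (lhs.size, rhs.size),
        (lhs.stride, rhs.stride),
    )
    return all(
        sympy.simplify(sympy.sympify(x) - sympy.sympify(y)) == 0 for x, y in pairs
    )


assert sequences_match(a, b)
assert not sequences_match(a, c)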
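
Note (usage sketch): the relocated detection helpers in wave_lang/support/detect_water.py can be probed before invoking the emitter; no API beyond what the file above defines is assumed:

from wave_lang.support.detect_water import (
    get_water_opt,
    is_water_available,
    is_water_binary_available,
)

if is_water_available() and is_water_binary_available():
    # Path to kernel/wave/water_mlir/bin/water-opt inside the installed package.
    print(get_water_opt())
else:
    print("water_mlir package or water-opt binary not available")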