Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions .github/workflows/ci-gpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,13 @@ jobs:

- name: Run unit tests
run: |
pytest -n 4 --capture=tee-sys -vv ./tests/unittests/
if [[ "${{ contains(matrix.os, 'mi35x') }}" == 'true' ]]; then
# TODO: water-related tests segfault on mi35x
pytest -n 4 --capture=tee-sys -vv ./tests/unittests/ --ignore=tests/unittests/index_sequence_difference_test.py --ignore=tests/unittests/location_exception.py
else
pytest -n 4 --capture=tee-sys -vv ./tests/unittests/
fi
pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface

- name: Test TKW runtime related stack on amdgpu
if: ${{ env.HAS_GPU == 'true' }}
Expand All @@ -233,10 +239,9 @@ jobs:

- name: Run LIT tests
env:
WAVE_TEST_WATER: ${{ env.IS_CDNA3 == 'true' && '1' || '0' }}
WAVE_TEST_WATER: ${{ (env.IS_CDNA3 == 'true' || env.IS_CDNA4 == 'true') && '1' || '0' }}
WAVE_TEST_DWARFDUMP: ${{ env.IS_RDNA4 == 'false' && '1' || '0' }}
run: |
# TODO: mlir_converter tests segfault on mi35x
# TODO: can't sudo to install dwarfdump on rdna4
echo "WAVE_TEST_WATER=$WAVE_TEST_WATER"
echo "WAVE_TEST_DWARFDUMP=$WAVE_TEST_DWARFDUMP"
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci-happy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ jobs:

echo "Run unit tests"
WAVE_CACHE_ON=0 python3 -m pytest -n 4 --capture=tee-sys -vv ./tests/unittests/
WAVE_CACHE_ON=0 python3 -m pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface
'

# 5. TKW runtime related e2e
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ jobs:
- name: Run unit tests
run: |
pytest -n 4 --capture=tee-sys -vv ./tests/unittests/
pytest -n 4 --capture=tee-sys -vv ./tests/mlir_wave_iface

- name: Run LIT tests
run: |
Expand Down
123 changes: 123 additions & 0 deletions lit_tests/kernel/wave/infer_index_exprs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# REQUIRES: water
# RUN: python %s
# The point of this test is to avoid crashing or asserting, so just run it under lit.

# Copyright 2025 The Wave Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import wave_lang.kernel.lang as tkl
import wave_lang.kernel.wave as tkw
from wave_lang.kernel.wave.wave import LaunchableWave
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile

from wave_lang.kernel.lang.global_symbols import *
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.utils.general_utils import torch_dtype_to_wave

import torch


# TODO: use the generic template, currently blocked by water not handling wave constraints.
def _get_gemm_kernel(
    shape: tuple[int, int, int],
    mfma_variant: MMAType,
    dtype: torch.dtype = torch.float16,
    block_shape: tuple[int, int, int] | None = None,
    waves_per_block: tuple[int, int] | None = None,
) -> tuple[LaunchableWave, dict[tkl.IndexSymbol, tkl.IndexExpr]]:
    """Build a GEMM wave kernel together with its hyperparameter substitutions.

    Args:
        shape: Concrete sizes substituted for the M, N and K symbols, in
            that order.
        mfma_variant: MMA type forwarded to the hardware constraint.
        dtype: Torch element type of the `a` and `b` operands; it is
            converted with `torch_dtype_to_wave` before use.
        block_shape: (BLOCK_M, BLOCK_N, BLOCK_K). A falsy value selects
            (64, 64, 32).
        waves_per_block: (WAVE_M, WAVE_N). A falsy value selects (2, 2).

    Returns:
        A pair of the `@tkw.wave`-decorated `gemm` kernel and the dictionary
        mapping symbols to the concrete values chosen here.
    """
    # Default workgroup tile sizes: BLOCK_M, BLOCK_N, BLOCK_K.
    block_shape = block_shape or (64, 64, 32)
    # Default wave grid: WAVE_M, WAVE_N.
    waves_per_block = waves_per_block or (2, 2)

    assert len(block_shape) == 3, "block_shape needs to be rank 3 for M, N, K."
    assert len(waves_per_block) == 2, "waves_per_block needs to be rank 2 for M, N."

    # Symbols for the problem sizes.
    M, N, K = tkl.sym.M, tkl.sym.N, tkl.sym.K
    # Symbols for the workgroup tile sizes.
    BLOCK_M, BLOCK_N, BLOCK_K = (
        tkl.sym.BLOCK_M,
        tkl.sym.BLOCK_N,
        tkl.sym.BLOCK_K,
    )
    # Address space (for GPU, shared(1) or global(0)).
    # NOTE(review): this symbol is named GLOBAL_ADDRESS_SPACE yet is mapped to
    # SHARED_ADDRESS_SPACE below, and the kernel signature does not reference
    # it -- confirm whether that is intentional.
    ADDRESS_SPACE = tkl.sym.GLOBAL_ADDRESS_SPACE
    dtype = torch_dtype_to_wave(dtype)

    # TODO: dialect expects waves_per_block to be rank 3, so we append a 1 to the end.
    waves_per_block_rank3 = waves_per_block + (1,)

    # User-facing constraints: workgroup distribution along M and N, tiling
    # along K, and the hardware description.
    constraints: list[tkw.Constraint] = [
        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
        tkw.TilingConstraint(K, BLOCK_K),
        tkw.HardwareConstraint(
            threads_per_wave=64,
            mma_type=mfma_variant,
            waves_per_block=waves_per_block_rank3,
        ),
    ]

    # Wave-level micro-kernel. It is written purely in terms of the symbolic
    # sizes M, N, K; how the work is tiled and how data moves is decided
    # during compilation and steered by the constraints above. The kernel
    # statements are deliberately left as written because the DSL traces them.
    @tkw.wave(constraints)
    def gemm(
        a: tkl.Memory[M, K, GLOBAL_ADDRESS_SPACE, dtype],
        b: tkl.Memory[N, K, GLOBAL_ADDRESS_SPACE, dtype],
        c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
    ):
        c_reg = tkl.Register[M, N, tkl.f32](0.0)

        # Iterating over K: if that dimension is tiled, a loop has to be
        # materialized, with `acc` carried between iterations.
        @tkw.iterate(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
            # a_reg: tkw.Register[M, K, dtype]
            a_reg = tkw.read(a)
            # b_reg: tkw.Register[N, K, dtype]
            b_reg = tkw.read(b)
            # acc: tkw.Register[M, N, tkl.f32]
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        # `repeat` stands for the value produced by the loop.
        tkw.write(repeat, c)

    # Concrete values for every symbol introduced above.
    block_m, block_n, block_k = block_shape[0], block_shape[1], block_shape[2]
    size_m, size_n, size_k = shape[0], shape[1], shape[2]
    hyperparams = {
        ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
        BLOCK_M: block_m,
        BLOCK_N: block_n,
        BLOCK_K: block_k,
        M: size_m,
        N: size_n,
        K: size_k,
    }
    return gemm, hyperparams


def testGemm():
    """Compile a 1024x1024x1024 F32_16x16x16_F16 GEMM with water analysis checks on.

    The only success criterion is that compilation completes and returns an
    object; nothing is executed or benchmarked.
    """
    kernel, substitutions = _get_gemm_kernel(
        shape=(1024, 1024, 1024),
        mfma_variant=MMAType.F32_16x16x16_F16,
    )
    compile_options = WaveCompileOptions(
        subs=substitutions,
        run_bench=False,
        check_water_analysis=True,
        print_mlir_after_water=True,
    )
    assert wave_compile(compile_options, kernel) is not None


# The RUN line at the top of this file executes it as a plain script
# (`python %s`), so the test has to be invoked explicitly here.
if __name__ == "__main__":
    testGemm()
22 changes: 12 additions & 10 deletions lit_tests/kernel/wave/mlir_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def failure_to_parse_override_mlir():

# Override the MLIR module after `wave_compile` so it doesn't attempt to parse it.
options.override_mlir = "module {"
_, diagnostics = emit_wave_dialect(trace, constraints, options)
_, diagnostics, _ = emit_wave_dialect(trace, constraints, options)

assert len(diagnostics) == 1
# CHECK: Unable to parse module assembly
Expand All @@ -91,7 +91,9 @@ def failure_to_parse_override_mlir():
@run_test
def failure_to_parse_pipeline():
trace, options, constraints = _get_dummy_trace_options_and_constraints()
_, diagnostics = emit_wave_dialect(trace, constraints, options, pipeline="module {")
_, diagnostics, _ = emit_wave_dialect(
trace, constraints, options, pipeline="module {"
)

assert len(diagnostics) == 1
# CHECK: Failed to apply transform script: Unable to parse module assembly
Expand All @@ -102,7 +104,7 @@ def failure_to_parse_pipeline():
@run_test
def pipeline_is_empty():
trace, options, constraints = _get_dummy_trace_options_and_constraints()
_, diagnostics = emit_wave_dialect(
_, diagnostics, _ = emit_wave_dialect(
trace, constraints, options, pipeline="module {}"
)

Expand All @@ -115,7 +117,7 @@ def pipeline_is_empty():
@run_test
def pipeline_is_not_a_named_sequence():
trace, options, constraints = _get_dummy_trace_options_and_constraints()
_, diagnostics = emit_wave_dialect(
_, diagnostics, _ = emit_wave_dialect(
trace, constraints, options, pipeline="module { module {}}"
)

Expand All @@ -141,7 +143,7 @@ def pipeline_is_not_a_named_sequence():
def failure_in_pipeline():
trace, options, constraints = _get_dummy_trace_options_and_constraints()
options.override_mlir = "module {}"
_, diagnostics = emit_wave_dialect(
_, diagnostics, _ = emit_wave_dialect(
trace, constraints, options, pipeline=GUARANTEED_FAIL_TRANSFORM_SCRIPT
)
assert len(diagnostics) == 1
Expand All @@ -158,7 +160,7 @@ def override_mlir():
module {
func.func private @overridden_mlir()
}"""
emitted, diagnostics = emit_wave_dialect(trace, constraints, options)
emitted, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
assert len(diagnostics) == 0, "Did not expect errors in overridden IR."

# CHECK: func.func private @overridden_mlir()
Expand Down Expand Up @@ -218,7 +220,7 @@ def mlir_converter_matrix_add():
constraints = matrix_add.constraints

# Use the mlir_converter to emit wave MLIR dialect
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)

if diagnostics:
for diagnostic in diagnostics:
Expand Down Expand Up @@ -374,7 +376,7 @@ def pipeline(root: OpHandle):

# Use the mlir_converter to emit wave MLIR dialect and apply the empty
# pipeline.
mlir_output, diagnostics = emit_wave_dialect(
mlir_output, diagnostics, _ = emit_wave_dialect(
trace, constraints, options, pipeline=pipeline_asm
)

Expand Down Expand Up @@ -528,7 +530,7 @@ def mixed_memory_kernel(
constraints = mixed_memory_kernel.constraints

with Context(), Location.unknown():
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)

assert len(diagnostics) == 0, f"Should have no diagnostics, got: {diagnostics}"

Expand Down Expand Up @@ -582,7 +584,7 @@ def invalid_hyperparameter_kernel(
# This should raise a RuntimeError due to invalid non-int hyperparameter
try:
with Context(), Location.unknown():
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)
assert False, "Expected RuntimeError for invalid non-int hyperparameter"
except RuntimeError as e:
# Verify the error message is what we expect
Expand Down
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/mlir_converter_debug_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def mlir_converter_location():
constraints = matrix_add.constraints

# Use the mlir_converter to emit wave MLIR dialect
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)

if diagnostics:
print(diagnostics)
Expand Down Expand Up @@ -210,7 +210,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
constraints = matmul.constraints

# Use the mlir_converter to emit wave MLIR dialect
mlir_output, diagnostics = emit_wave_dialect(trace, constraints, options)
mlir_output, diagnostics, _ = emit_wave_dialect(trace, constraints, options)

if diagnostics:
print(diagnostics)
Expand Down
2 changes: 1 addition & 1 deletion lit_tests/kernel/wave/mlir_converter_diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def mlir_converter_diagnostics_emission():
constraints = matrix_add.constraints

# Use the mlir_converter to emit wave MLIR dialect
_, diagnostics = emit_wave_dialect(
_, diagnostics, _ = emit_wave_dialect(
trace, constraints, options, test_diagnostic_emission=True
)

Expand Down
9 changes: 9 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[pytest]
testpaths = ./tests
filterwarnings =
# TODO: Remove once flatbuffer 'imp' usage resolved.
ignore::DeprecationWarning
# Because of the pytest collection process, it will import all modules from all
# tests, which is undesirable for mlir/wave interfacing tests. Exclude them
# from the global run.
addopts = "--ignore=tests/mlir_wave_iface"
6 changes: 0 additions & 6 deletions setup.cfg

This file was deleted.

2 changes: 1 addition & 1 deletion tests/kernel/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def param_bool(name, shortname=None, values=None):


def _is_water_and_ee_available() -> bool:
from wave_lang.kernel.wave.water import is_water_available
from wave_lang.support.detect_water import is_water_available
from wave_lang.kernel.wave.execution_engine import is_execution_engine_available

return is_water_available() and is_execution_engine_available()
Expand Down
4 changes: 3 additions & 1 deletion tests/kernel/test_water.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from unittest.mock import patch
from wave_lang.kernel.wave.water import (
apply_water_middle_end_passes,
)
from wave_lang.support.detect_water import (
find_binary,
get_water_opt,
is_water_available,
Expand Down Expand Up @@ -36,7 +38,7 @@ def test_apply_water_middle_end_passes_unavailable(self):
# Mock find_binary to return None, simulating water-opt not being found
get_water_opt.cache_clear() # get_water_opt caches find_binary result
try:
with patch("wave_lang.kernel.wave.water.find_binary", return_value=None):
with patch("wave_lang.support.detect_water.find_binary", return_value=None):
with pytest.raises(RuntimeError, match="water-opt binary not found"):
apply_water_middle_end_passes("module {}")
finally:
Expand Down
Loading