Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions water/include/water/Dialect/Wave/IR/WaveOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,16 @@ include "mlir/IR/BuiltinAttributeInterfaces.td"
#ifndef WATER_DIALECT_WAVE_WAVEOPS
#define WATER_DIALECT_WAVE_WAVEOPS

//-----------------------------------------------------------------------------
// Type constraints for Wave operations.
//-----------------------------------------------------------------------------

// Named type constraint shared by the operands and results of wave.iterate
// and by the operands of wave.yield, so that all of them accept the same set
// of types.
//
// A value satisfies the constraint if its type is either:
//   - WaveTensorType (the form used before the PropagateElementsPerThread
//     pass), or
//   - a rank-1 vector (the form used after the PropagateElementsPerThread
//     pass).
//
// The string argument is the human-readable summary of the constraint.
def WaveIterableType : AnyTypeOf<[WaveTensorType, VectorOfRank<[1]>],
"wave tensor or 1d vector type">;

//-----------------------------------------------------------------------------
// Base class for all Wave operations.
//-----------------------------------------------------------------------------
Expand Down Expand Up @@ -168,12 +178,16 @@ def IterateOp : Op<WaveDialect, "iterate", [

let arguments = (ins
Arg<WaveSymbolAttr, "Iterator symbol">:$iterator,
Arg<Variadic<WaveTensorType>, "Carried values">:$iter_args,
Arg<Variadic<WaveTensorType>, "Captured values">:$captures
// Accept both WaveTensorType (before PropagateElementsPerThread) and 1D vectors (after).
// We cannot use Arg<Variadic<WaveTensorInRegisters>> because Variadic
// requires a Type and not TypeConstraint.
Arg<Variadic<WaveIterableType>, "Carried values">:$iter_args,
Arg<Variadic<WaveIterableType>, "Captured values">:$captures
);

let results = (outs
Res<Variadic<WaveTensorType>, "Yielded values">:$results
// Results follow the same type constraints as inputs.
Res<Variadic<WaveIterableType>, "Yielded values">:$results
);

let regions = (region
Expand Down Expand Up @@ -214,7 +228,8 @@ def YieldOp : Op<WaveDialect, "yield",
let summary = "Yields values from the current control flow context";

let arguments = (ins
Arg<Variadic<WaveTensorType>, "Yielded values">:$values
// Must match the type constraints of wave.iterate results.
Arg<Variadic<WaveIterableType>, "Yielded values">:$values
);

let assemblyFormat = "$values attr-dict `:` type($values)";
Expand Down
149 changes: 114 additions & 35 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,11 +193,27 @@ void wave::IterateOp::makeNonIsolated(RewriterBase &rewriter) {
rewriter.mergeBlocks(originalBlock, newBlock, replacementValues);
}

bool wave::IterateOp::areTypesCompatible(Type lhs, Type rhs) {
return detail::verifyTypesCompatible(llvm::cast<wave::WaveTensorType>(lhs),
llvm::cast<wave::WaveTensorType>(rhs),
/*includeAddressSpace=*/true)
.succeeded();
bool wave::IterateOp::areTypesCompatible(mlir::Type lhs, mlir::Type rhs) {
  // Wave tensor on the left: the right-hand side must be a wave tensor as
  // well, and the pair is then judged by the shared compatibility helper,
  // address space included.
  if (auto lhsWaveTensor = llvm::dyn_cast<wave::WaveTensorType>(lhs)) {
    auto rhsWaveTensor = llvm::dyn_cast<wave::WaveTensorType>(rhs);
    if (!rhsWaveTensor)
      return false;
    return detail::verifyTypesCompatible(lhsWaveTensor, rhsWaveTensor,
                                         /*includeAddressSpace=*/true)
        .succeeded();
  }

  // Vector on the left: only the exact same vector type is accepted.
  if (auto lhsVec = llvm::dyn_cast<mlir::VectorType>(lhs)) {
    auto rhsVec = llvm::dyn_cast<mlir::VectorType>(rhs);
    return rhsVec && lhsVec == rhsVec;
  }

  // Any other combination (including tensor/vector mixes) is incompatible.
  return false;
}

OperandRange wave::IterateOp::getEntrySuccessorOperands(RegionSuccessor) {
Expand Down Expand Up @@ -258,34 +274,41 @@ LogicalResult wave::IterateOp::verify() {
<< blockIterArgTypes.size() << ") and results ("
<< resultTypes.size() << ")";
}
for (auto &&[i, iterArg, blockIterArg, result] :
llvm::enumerate(iterArgTypes, blockIterArgTypes, resultTypes)) {
auto iterArgTensor = llvm::cast<wave::WaveTensorType>(iterArg);
auto blockIterArgTensor = llvm::cast<wave::WaveTensorType>(blockIterArg);
auto resultTensor = llvm::cast<wave::WaveTensorType>(result);
for (auto &&[i, iterArg, result] :
llvm::enumerate(iterArgTypes, resultTypes)) {
// Handle verification for both wave tensors and vectors.
auto iterArgTensor = llvm::dyn_cast<wave::WaveTensorType>(iterArg);
auto resultTensor = llvm::dyn_cast<wave::WaveTensorType>(result);

auto istr = std::to_string(i);
if (llvm::failed(detail::verifyTypesCompatible(
iterArgTensor, blockIterArgTensor, /*includeAddressSpace=*/true,
getLoc(), "operand iter_arg #" + istr,
"block argument #" + istr))) {
return llvm::failure();

// Both are wave tensors - verify shapes match across all dimensions.
if (iterArgTensor && resultTensor) {
if (!iterArgTensor.getFullySpecified() ||
!resultTensor.getFullySpecified())
continue;

auto allDims =
llvm::to_vector(llvm::iota_range<int>(0, iterArgTensor.getRank(),
/*Inclusive=*/false));
if (mlir::failed(detail::verifyTypesMatchingDimensions(
getLoc(), "iter_args #" + istr, iterArgTensor, allDims,
"result #" + istr, resultTensor, allDims)))
return mlir::failure();
}
if (llvm::failed(detail::verifyTypesCompatible(
iterArgTensor, resultTensor, /*includeAddressSpace=*/true, getLoc(),
"operand iter_arg #" + istr, "result #" + istr))) {
return llvm::failure();
// Both are vectors - check exact type equality.
else if (isa<VectorType>(iterArg) && isa<VectorType>(result)) {
if (iterArg != result) {
return emitOpError()
<< "iter_args #" << i << " type (" << iterArg
<< ") must match result #" << i << " type (" << result << ")";
}
}
}
for (auto &&[i, capture, captureBlockArg] :
llvm::enumerate(captureTypes, captureBlockArgTypes)) {
auto captureTensor = llvm::cast<wave::WaveTensorType>(capture);
auto captureBlockArgTensor =
llvm::cast<wave::WaveTensorType>(captureBlockArg);
if (captureTensor != captureBlockArgTensor) {
return emitOpError() << "expects the same type for capture #" << i
<< " and block argument #"
<< (getIterArgs().size() + i);
// Mixed types are not allowed.
else {
return emitOpError() << "iter_args #" << i << " and result #" << i
<< " must be the same category of types (both wave "
"tensors or both vectors)";
}
}

Expand All @@ -310,14 +333,70 @@ llvm::LogicalResult wave::IterateOp::verifyRegions() {
llvm::enumerate(resultTypes, terminatorOperandTypes, iterArgTypes,
blockIterArgTypes)) {
auto istr = std::to_string(i);
if (llvm::failed(detail::verifyTypesCompatible(
llvm::cast<wave::WaveTensorType>(result),
llvm::cast<wave::WaveTensorType>(terminatorOperand),
/*includeAddressSpace=*/true, getLoc(), "result #" + istr,
"terminator operand #" + istr))) {
return llvm::failure();

auto iterArgTensor = llvm::dyn_cast<wave::WaveTensorType>(iterArg);
auto resultTensor = llvm::dyn_cast<wave::WaveTensorType>(result);
auto blockIterArgTensor =
llvm::dyn_cast<wave::WaveTensorType>(blockIterArg);
auto terminatorOperandTensor =
llvm::dyn_cast<wave::WaveTensorType>(terminatorOperand);

// Verify result type vs terminator operand type.
if (resultTensor && terminatorOperandTensor) {
if (llvm::failed(detail::verifyTypesCompatible(
resultTensor, terminatorOperandTensor,
/*includeAddressSpace=*/true, getLoc(), "result #" + istr,
"terminator operand #" + istr))) {
return llvm::failure();
}
} else if (isa<VectorType>(result) && isa<VectorType>(terminatorOperand)) {
// For vector types, just check that they are exactly equal.
if (result != terminatorOperand) {
return emitOpError() << "result #" << i << " type (" << result
<< ") does not match terminator operand #" << i
<< " type (" << terminatorOperand << ")";
}
} else if (result != terminatorOperand) {
return emitOpError() << "result #" << i << " type (" << result
<< ") and terminator operand #" << i << " type ("
<< terminatorOperand << ") are not compatible types";
}

// Verify iter arg type vs block arg type.
if (iterArgTensor && blockIterArgTensor) {
if (llvm::failed(detail::verifyTypesCompatible(
iterArgTensor, blockIterArgTensor,
/*includeAddressSpace=*/true, getLoc(), "iter arg #" + istr,
"block iter arg #" + istr))) {
return llvm::failure();
}
} else if (isa<VectorType>(iterArg) && isa<VectorType>(blockIterArg)) {
// For vector types, just check that they are exactly equal.
if (iterArg != blockIterArg) {
return emitOpError() << "iter arg #" << i << " type (" << iterArg
<< ") does not match block iter arg #" << i
<< " type (" << blockIterArg << ")";
}
} else if (iterArg != blockIterArg) {
return emitOpError() << "iter arg #" << i << " type (" << iterArg
<< ") and block iter arg #" << i << " type ("
<< blockIterArg << ") are not compatible types";
}
}

// Verify capture types match their corresponding block arguments.
TypeRange captureTypes = getCaptures().getTypes();
TypeRange captureBlockArgTypes = TypeRange(getLoopBody()->getArgumentTypes())
.take_back(captureTypes.size());
for (auto &&[i, capture, captureBlockArg] :
llvm::enumerate(captureTypes, captureBlockArgTypes)) {
if (capture != captureBlockArg) {
return emitOpError() << "expects the same type for capture #" << i
<< " and block argument #"
<< (getIterArgs().size() + i);
}
}

return llvm::success();
}

Expand Down
39 changes: 39 additions & 0 deletions water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Transforms/DialectConversion.h"

Expand Down Expand Up @@ -58,6 +59,15 @@ make1DTransferCommonAttrs(MemRefType memrefType, int64_t vectorizedDim,
return TransferOpCommonAttrs{permAttr, inTrue, inFalse};
}

/// Returns the closest operation of type `OpTy` enclosing `currentBlock`:
/// either the operation that directly owns the block, or the nearest ancestor
/// of that operation with the requested type.
///
/// Returns a null `OpTy` when no such operation exists. This includes the
/// cases where `currentBlock` is null or is not attached to any operation.
template <typename OpTy> static OpTy findParentOfType(Block *currentBlock) {
  if (!currentBlock)
    return OpTy();

  // A block that is not nested in an operation has no parent op; `dyn_cast`
  // and `getParentOfType` must not be invoked on a null pointer.
  auto *parentOp = currentBlock->getParentOp();
  if (!parentOp)
    return OpTy();

  // Check the immediate owner first, then fall back to its ancestors.
  if (auto op = dyn_cast<OpTy>(parentOp))
    return op;
  return parentOp->template getParentOfType<OpTy>();
}

/// Materialize affine.apply for expressions inside a `map` with `symbols`.
/// Each symbol is either a GPU id (thread/block) or a constant from `hyper`.
static FailureOr<SmallVector<Value>>
Expand Down Expand Up @@ -125,6 +135,34 @@ materializeAffine(Location loc, ArrayRef<Attribute> symbols, AffineMap map,
}
continue;
}

if (auto iterSymbol = dyn_cast<wave::WaveIterSymbolAttr>(attr)) {
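// Resolve the iteration symbol to the induction variable of the enclosing
// scf.for loop. NOTE: the nearest enclosing scf.for is used as-is; nothing
// here verifies that it is the loop created for this particular symbol.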
Block *currentBlock = rewriter.getInsertionBlock();

if (findParentOfType<wave::IterateOp>(currentBlock)) {
return rewriter.notifyMatchFailure(
loc, "iteration symbol found inside wave.iterate - "
"please run lower-wave-control-flow pass first");
}

scf::ForOp parentFor = findParentOfType<scf::ForOp>(currentBlock);
assert(parentFor &&
"iteration symbol found but no iteration context available");

// Get the induction variable from the scf.for loop.
Value inductionVar = parentFor.getInductionVar();

// Pass the induction variable directly to the affine map. The index
// expressions are designed as affine maps that already incorporate tile
// size scaling. Pre-multiplying here would cause double multiplication
// when the affine map applies its own scaling. For example, if the map
// is (s0 * 32), it expects s0 = iteration, not s0 = iteration *
// tile_size.
baseSymVals.push_back(inductionVar);
continue;
}
}

// In case map contains multiple results, create one apply per result.
Expand All @@ -134,6 +172,7 @@ materializeAffine(Location loc, ArrayRef<Attribute> symbols, AffineMap map,
AffineMap submap =
AffineMap::get(map.getNumDims(), map.getNumSymbols(), expr);
SmallVector<Value> symVals = baseSymVals;

affine::canonicalizeMapAndOperands(&submap, &symVals);

Value apply = affine::AffineApplyOp::create(rewriter, loc, submap, symVals);
Expand Down
2 changes: 1 addition & 1 deletion water/test/Dialect/Wave/infer-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ module {
module attributes {wave.normal_form = #wave.normal_form<full_func_boundary>} {
func.func @iterate_mismatching_results(%arg0: !wave.tensor<[@A] of f32>, %arg1: !wave.tensor<[@B] of f32>) {
%read = wave.read %arg1 : (!wave.tensor<[@B] of f32>) -> !wave.tensor<any of f32>
// expected-error @below {{type conflict was detected for result #0}}
// expected-error @below {{expected iter arg #0 dimension #0 (#wave.symbol<"A">) to match block iter arg #0 dimension #0 (#wave.symbol<"B">)}}
wave.iterate @I iter_args(%arg0, %read) {
^bb0(%arg2: !wave.tensor<[@B] of f32>, %arg3: !wave.tensor<any of f32>):
wave.yield %arg2, %arg3 : !wave.tensor<[@B] of f32>, !wave.tensor<any of f32>
Expand Down
58 changes: 58 additions & 0 deletions water/test/Dialect/Wave/lower-wave-to-mlir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,61 @@ module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_t
func.return %alloc : memref<32x32xf32>
}
}

// -----

// Lowering of wave.iterate whose single carried value is already a 1-D vector
// (vector<32xf16>) rather than a wave tensor. The loop is expected to become
// an scf.for that carries and returns the same vector type, with wave.add and
// wave.yield in the body lowered to arith.addf and scf.yield.
module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_types>} {
// CHECK-LABEL: @lower_iterate_with_vector_iter_args
func.func @lower_iterate_with_vector_iter_args() attributes {
wave.hyperparameters = #wave.hyperparameters<{K = 128, BLOCK_K = 32}>,
wave.constraints = [
#wave.tiling_constraint<dim = <"K">, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>>
]
} {
%cst = arith.constant 0.0 : f16
%init = wave.register %cst : vector<32xf16>

// Expected loop bounds: 0 to 4 with step 1.
// NOTE(review): the upper bound of 4 presumably comes from
// K / BLOCK_K = 128 / 32 -- confirm against the control-flow lowering.
// CHECK: %[[LB:.*]] = arith.constant 0 : index
// CHECK: %[[UB:.*]] = arith.constant 4 : index
// CHECK: %[[STEP:.*]] = arith.constant 1 : index
// CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%{{.*}} = %{{.*}}) -> (vector<32xf16>)
%result = wave.iterate @K iter_args(%init) {
^bb0(%arg: vector<32xf16>):
// Test vector operations within iterate body
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<32xf16>
%add = wave.add %arg, %arg : (vector<32xf16>, vector<32xf16>) -> vector<32xf16>
// CHECK: scf.yield %{{.*}} : vector<32xf16>
wave.yield %add : vector<32xf16>
} : (vector<32xf16>) -> vector<32xf16>
return
}
}

// -----

// Lowering of wave.iterate that carries two vector values with different
// shapes and element types (vector<4xf32> and vector<8xi32>). The resulting
// scf.for is expected to carry and return both vectors with their types
// preserved; wave.add is expected to lower to arith.addf for the float vector
// and to arith.addi for the integer vector.
module attributes {wave.normal_form = #wave.normal_form<full_types,memory_only_types>} {
// CHECK-LABEL: @lower_iterate_multiple_vector_iter_args
func.func @lower_iterate_multiple_vector_iter_args() attributes {
wave.hyperparameters = #wave.hyperparameters<{I = 8}>,
wave.constraints = [
#wave.tiling_constraint<dim = <"I">, tile_size = <[#wave.symbol<"I">] -> (I)>>
]
} {
%cst_f32 = arith.constant 1.0 : f32
%cst_i32 = arith.constant 42 : i32
%init1 = wave.register %cst_f32 : vector<4xf32>
%init2 = wave.register %cst_i32 : vector<8xi32>

// Loop bounds are intentionally not pinned here; only the carried types are.
// CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) -> (vector<4xf32>, vector<8xi32>)
%result:2 = wave.iterate @I iter_args(%init1, %init2) {
^bb0(%arg1: vector<4xf32>, %arg2: vector<8xi32>):
// CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32>
%add = wave.add %arg1, %arg1 : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
// CHECK: arith.addi %{{.*}}, %{{.*}} : vector<8xi32>
%add2 = wave.add %arg2, %arg2 : (vector<8xi32>, vector<8xi32>) -> vector<8xi32>
// CHECK: scf.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<8xi32>
wave.yield %add, %add2 : vector<4xf32>, vector<8xi32>
} : (vector<4xf32>, vector<8xi32>) -> (vector<4xf32>, vector<8xi32>)
return
}
}
Loading