diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index f64c4b0b0..30f1bb306 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -16,6 +16,16 @@ include "mlir/IR/BuiltinAttributeInterfaces.td" #ifndef WATER_DIALECT_WAVE_WAVEOPS #define WATER_DIALECT_WAVE_WAVEOPS +//----------------------------------------------------------------------------- +// Type constraints for Wave operations. +//----------------------------------------------------------------------------- + +// Named constraint for types supported by wave.iterate and wave.yield operations. +// Supports both WaveTensorType (before PropagateElementsPerThread pass) and +// 1D vectors (after PropagateElementsPerThread pass). +def WaveIterableType : AnyTypeOf<[WaveTensorType, VectorOfRank<[1]>], + "wave tensor or 1d vector type">; + //----------------------------------------------------------------------------- // Base class for all Wave operations. //----------------------------------------------------------------------------- @@ -168,12 +178,16 @@ def IterateOp : Op:$iterator, - Arg, "Carried values">:$iter_args, - Arg, "Captured values">:$captures + // Accept both WaveTensorType (before PropagateElementsPerThread) and 1D vectors (after). + // We cannot use Arg> because Variadic + // requires a Type and not TypeConstraint. + Arg, "Carried values">:$iter_args, + Arg, "Captured values">:$captures ); let results = (outs - Res, "Yielded values">:$results + // Results follow the same type constraints as inputs. + Res, "Yielded values">:$results ); let regions = (region @@ -214,7 +228,8 @@ def YieldOp : Op, "Yielded values">:$values + // Must match the type constraints of wave.iterate results. + Arg, "Yielded values">:$values ); let assemblyFormat = "$values attr-dict `:` type($values)"; diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 9d8458c7b..87d470dd4 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -193,11 +193,27 @@ void wave::IterateOp::makeNonIsolated(RewriterBase &rewriter) { rewriter.mergeBlocks(originalBlock, newBlock, replacementValues); } -bool wave::IterateOp::areTypesCompatible(Type lhs, Type rhs) { - return detail::verifyTypesCompatible(llvm::cast(lhs), - llvm::cast(rhs), - /*includeAddressSpace=*/true) - .succeeded(); +bool wave::IterateOp::areTypesCompatible(mlir::Type lhs, mlir::Type rhs) { + // Handle both WaveTensorType and VectorType combinations. + auto lhsTensor = llvm::dyn_cast(lhs); + auto rhsTensor = llvm::dyn_cast(rhs); + auto lhsVector = llvm::dyn_cast(lhs); + auto rhsVector = llvm::dyn_cast(rhs); + + // Both are wave tensors - check shape and address space compatibility. + if (lhsTensor && rhsTensor) { + return detail::verifyTypesCompatible(lhsTensor, rhsTensor, + /*includeAddressSpace=*/true) + .succeeded(); + } + + // Both are vectors - simple equality check. + if (lhsVector && rhsVector) { + return lhsVector == rhsVector; + } + + // Mixed types are not compatible. + return false; } OperandRange wave::IterateOp::getEntrySuccessorOperands(RegionSuccessor) { @@ -258,34 +274,41 @@ LogicalResult wave::IterateOp::verify() { << blockIterArgTypes.size() << ") and results (" << resultTypes.size() << ")"; } - for (auto &&[i, iterArg, blockIterArg, result] : - llvm::enumerate(iterArgTypes, blockIterArgTypes, resultTypes)) { - auto iterArgTensor = llvm::cast(iterArg); - auto blockIterArgTensor = llvm::cast(blockIterArg); - auto resultTensor = llvm::cast(result); + for (auto &&[i, iterArg, result] : + llvm::enumerate(iterArgTypes, resultTypes)) { + // Handle verification for both wave tensors and vectors. + auto iterArgTensor = llvm::dyn_cast(iterArg); + auto resultTensor = llvm::dyn_cast(result); auto istr = std::to_string(i); - if (llvm::failed(detail::verifyTypesCompatible( - iterArgTensor, blockIterArgTensor, /*includeAddressSpace=*/true, - getLoc(), "operand iter_arg #" + istr, - "block argument #" + istr))) { - return llvm::failure(); + + // Both are wave tensors - verify shapes match across all dimensions. + if (iterArgTensor && resultTensor) { + if (!iterArgTensor.getFullySpecified() || + !resultTensor.getFullySpecified()) + continue; + + auto allDims = + llvm::to_vector(llvm::iota_range(0, iterArgTensor.getRank(), + /*Inclusive=*/false)); + if (mlir::failed(detail::verifyTypesMatchingDimensions( + getLoc(), "iter_args #" + istr, iterArgTensor, allDims, + "result #" + istr, resultTensor, allDims))) + return mlir::failure(); } - if (llvm::failed(detail::verifyTypesCompatible( - iterArgTensor, resultTensor, /*includeAddressSpace=*/true, getLoc(), - "operand iter_arg #" + istr, "result #" + istr))) { - return llvm::failure(); + // Both are vectors - check exact type equality. + else if (isa(iterArg) && isa(result)) { + if (iterArg != result) { + return emitOpError() + << "iter_args #" << i << " type (" << iterArg + << ") must match result #" << i << " type (" << result << ")"; + } } - } - for (auto &&[i, capture, captureBlockArg] : - llvm::enumerate(captureTypes, captureBlockArgTypes)) { - auto captureTensor = llvm::cast(capture); - auto captureBlockArgTensor = - llvm::cast(captureBlockArg); - if (captureTensor != captureBlockArgTensor) { - return emitOpError() << "expects the same type for capture #" << i - << " and block argument #" - << (getIterArgs().size() + i); + // Mixed types are not allowed. + else { + return emitOpError() << "iter_args #" << i << " and result #" << i + << " must be the same category of types (both wave " + "tensors or both vectors)"; } } @@ -310,14 +333,70 @@ llvm::LogicalResult wave::IterateOp::verifyRegions() { llvm::enumerate(resultTypes, terminatorOperandTypes, iterArgTypes, blockIterArgTypes)) { auto istr = std::to_string(i); - if (llvm::failed(detail::verifyTypesCompatible( - llvm::cast(result), - llvm::cast(terminatorOperand), - /*includeAddressSpace=*/true, getLoc(), "result #" + istr, - "terminator operand #" + istr))) { - return llvm::failure(); + + auto iterArgTensor = llvm::dyn_cast(iterArg); + auto resultTensor = llvm::dyn_cast(result); + auto blockIterArgTensor = + llvm::dyn_cast(blockIterArg); + auto terminatorOperandTensor = + llvm::dyn_cast(terminatorOperand); + + // Verify result type vs terminator operand type. + if (resultTensor && terminatorOperandTensor) { + if (llvm::failed(detail::verifyTypesCompatible( + resultTensor, terminatorOperandTensor, + /*includeAddressSpace=*/true, getLoc(), "result #" + istr, + "terminator operand #" + istr))) { + return llvm::failure(); + } + } else if (isa(result) && isa(terminatorOperand)) { + // For vector types, just check that they are exactly equal. + if (result != terminatorOperand) { + return emitOpError() << "result #" << i << " type (" << result + << ") does not match terminator operand #" << i + << " type (" << terminatorOperand << ")"; + } + } else if (result != terminatorOperand) { + return emitOpError() << "result #" << i << " type (" << result + << ") and terminator operand #" << i << " type (" + << terminatorOperand << ") are not compatible types"; + } + + // Verify iter arg type vs block arg type. + if (iterArgTensor && blockIterArgTensor) { + if (llvm::failed(detail::verifyTypesCompatible( + iterArgTensor, blockIterArgTensor, + /*includeAddressSpace=*/true, getLoc(), "iter arg #" + istr, + "block iter arg #" + istr))) { + return llvm::failure(); + } + } else if (isa(iterArg) && isa(blockIterArg)) { + // For vector types, just check that they are exactly equal. + if (iterArg != blockIterArg) { + return emitOpError() << "iter arg #" << i << " type (" << iterArg + << ") does not match block iter arg #" << i + << " type (" << blockIterArg << ")"; + } + } else if (iterArg != blockIterArg) { + return emitOpError() << "iter arg #" << i << " type (" << iterArg + << ") and block iter arg #" << i << " type (" + << blockIterArg << ") are not compatible types"; } } + + // Verify capture types match their corresponding block arguments. + TypeRange captureTypes = getCaptures().getTypes(); + TypeRange captureBlockArgTypes = TypeRange(getLoopBody()->getArgumentTypes()) + .take_back(captureTypes.size()); + for (auto &&[i, capture, captureBlockArg] : + llvm::enumerate(captureTypes, captureBlockArgTypes)) { + if (capture != captureBlockArg) { + return emitOpError() << "expects the same type for capture #" << i + << " and block argument #" + << (getIterArgs().size() + i); + } + } + return llvm::success(); } diff --git a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp index 8200146b0..aa90e38da 100644 --- a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp +++ b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Transforms/DialectConversion.h" @@ -58,6 +59,15 @@ make1DTransferCommonAttrs(MemRefType memrefType, int64_t vectorizedDim, return TransferOpCommonAttrs{permAttr, inTrue, inFalse}; } +/// Find parent operation of type OpTy starting from the given block. +template static OpTy findParentOfType(Block *currentBlock) { + auto parentOp = currentBlock->getParentOp(); + if (auto op = dyn_cast(parentOp)) { + return op; + } + return parentOp->getParentOfType(); +} + /// Materialize affine.apply for expressions inside a `map` with `symbols`. /// Each symbol is either a GPU id (thread/block) or a constant from `hyper`. static FailureOr> @@ -125,6 +135,34 @@ materializeAffine(Location loc, ArrayRef symbols, AffineMap map, } continue; } + + if (auto iterSymbol = dyn_cast(attr)) { + // Check if we're inside an scf.for loop that corresponds to this + // iteration symbol. + Block *currentBlock = rewriter.getInsertionBlock(); + + if (findParentOfType(currentBlock)) { + return rewriter.notifyMatchFailure( + loc, "iteration symbol found inside wave.iterate - " + "please run lower-wave-control-flow pass first"); + } + + scf::ForOp parentFor = findParentOfType(currentBlock); + assert(parentFor && + "iteration symbol found but no iteration context available"); + + // Get the induction variable from the scf.for loop. + Value inductionVar = parentFor.getInductionVar(); + + // Pass the induction variable directly to the affine map. The index + // expressions are designed as affine maps that already incorporate tile + // size scaling. Pre-multiplying here would cause double multiplication + // when the affine map applies its own scaling. For example, if the map + // is (s0 * 32), it expects s0 = iteration, not s0 = iteration * + // tile_size. + baseSymVals.push_back(inductionVar); + continue; + } } // In case map contains multiple results, create one apply per result. @@ -134,6 +172,7 @@ materializeAffine(Location loc, ArrayRef symbols, AffineMap map, AffineMap submap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), expr); SmallVector symVals = baseSymVals; + affine::canonicalizeMapAndOperands(&submap, &symVals); Value apply = affine::AffineApplyOp::create(rewriter, loc, submap, symVals); diff --git a/water/test/Dialect/Wave/infer-types.mlir b/water/test/Dialect/Wave/infer-types.mlir index 92106a7e1..628c03502 100644 --- a/water/test/Dialect/Wave/infer-types.mlir +++ b/water/test/Dialect/Wave/infer-types.mlir @@ -206,7 +206,7 @@ module { module attributes {wave.normal_form = #wave.normal_form} { func.func @iterate_mismatching_results(%arg0: !wave.tensor<[@A] of f32>, %arg1: !wave.tensor<[@B] of f32>) { %read = wave.read %arg1 : (!wave.tensor<[@B] of f32>) -> !wave.tensor - // expected-error @below {{type conflict was detected for result #0}} + // expected-error @below {{expected iter arg #0 dimension #0 (#wave.symbol<"A">) to match block iter arg #0 dimension #0 (#wave.symbol<"B">)}} wave.iterate @I iter_args(%arg0, %read) { ^bb0(%arg2: !wave.tensor<[@B] of f32>, %arg3: !wave.tensor): wave.yield %arg2, %arg3 : !wave.tensor<[@B] of f32>, !wave.tensor diff --git a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir index fc1402b35..e0ce0ee86 100644 --- a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir +++ b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir @@ -787,3 +787,61 @@ module attributes {wave.normal_form = #wave.normal_form } } + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @lower_iterate_with_vector_iter_args + func.func @lower_iterate_with_vector_iter_args() attributes { + wave.hyperparameters = #wave.hyperparameters<{K = 128, BLOCK_K = 32}>, + wave.constraints = [ + #wave.tiling_constraint, tile_size = <[#wave.symbol<"BLOCK_K">] -> (BLOCK_K)>> + ] + } { + %cst = arith.constant 0.0 : f16 + %init = wave.register %cst : vector<32xf16> + + // CHECK: %[[LB:.*]] = arith.constant 0 : index + // CHECK: %[[UB:.*]] = arith.constant 4 : index + // CHECK: %[[STEP:.*]] = arith.constant 1 : index + // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%{{.*}} = %{{.*}}) -> (vector<32xf16>) + %result = wave.iterate @K iter_args(%init) { + ^bb0(%arg: vector<32xf16>): + // Test vector operations within iterate body + // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<32xf16> + %add = wave.add %arg, %arg : (vector<32xf16>, vector<32xf16>) -> vector<32xf16> + // CHECK: scf.yield %{{.*}} : vector<32xf16> + wave.yield %add : vector<32xf16> + } : (vector<32xf16>) -> vector<32xf16> + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @lower_iterate_multiple_vector_iter_args + func.func @lower_iterate_multiple_vector_iter_args() attributes { + wave.hyperparameters = #wave.hyperparameters<{I = 8}>, + wave.constraints = [ + #wave.tiling_constraint, tile_size = <[#wave.symbol<"I">] -> (I)>> + ] + } { + %cst_f32 = arith.constant 1.0 : f32 + %cst_i32 = arith.constant 42 : i32 + %init1 = wave.register %cst_f32 : vector<4xf32> + %init2 = wave.register %cst_i32 : vector<8xi32> + + // CHECK: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %{{.*}}) -> (vector<4xf32>, vector<8xi32>) + %result:2 = wave.iterate @I iter_args(%init1, %init2) { + ^bb0(%arg1: vector<4xf32>, %arg2: vector<8xi32>): + // CHECK: arith.addf %{{.*}}, %{{.*}} : vector<4xf32> + %add = wave.add %arg1, %arg1 : (vector<4xf32>, vector<4xf32>) -> vector<4xf32> + // CHECK: arith.addi %{{.*}}, %{{.*}} : vector<8xi32> + %add2 = wave.add %arg2, %arg2 : (vector<8xi32>, vector<8xi32>) -> vector<8xi32> + // CHECK: scf.yield %{{.*}}, %{{.*}} : vector<4xf32>, vector<8xi32> + wave.yield %add, %add2 : vector<4xf32>, vector<8xi32> + } : (vector<4xf32>, vector<8xi32>) -> (vector<4xf32>, vector<8xi32>) + return + } +} diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 9f74ebc5e..c9216bfd7 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -107,7 +107,7 @@ func.func @iterate_iter_args_block_iter_args_mismatch(%arg0: !wave.tensor) { - // expected-error @below {{expected operand iter_arg #0 and result #0 elemental types to match, got 'f32', 'bf16'}} + // expected-error @below {{expected iter arg #0 and block iter arg #0 elemental types to match, got 'f32', 'bf16'}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A] of bf16>): wave.yield %arg1 : !wave.tensor<[@A] of bf16> @@ -117,7 +117,7 @@ func.func @iterate_iter_arg_block_arg_element_type_mismatch(%arg0: !wave.tensor< // ----- func.func @iterate_iter_arg_block_arg_rank_mismatch(%arg0: !wave.tensor<[@A] of f32>) { - // expected-error @below {{rank mismatch between operand iter_arg #0 and result #0}} + // expected-error @below {{rank mismatch between iter arg #0 and block iter arg #0}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A, @B] of f32>): wave.yield %arg1 : !wave.tensor<[@A, @B] of f32> @@ -127,7 +127,7 @@ func.func @iterate_iter_arg_block_arg_rank_mismatch(%arg0: !wave.tensor<[@A] of // ----- func.func @iterate_iter_arg_block_arg_shape_mismatch(%arg0: !wave.tensor<[@A] of f32>) { - // expected-error @below {{expected operand iter_arg #0 dimension #0 (#wave.symbol<"A">) to match result #0 dimension #0 (#wave.symbol<"B">)}} + // expected-error @below {{expected iter_args #0 dimension #0 (#wave.symbol<"A">) to match result #0 dimension #0 (#wave.symbol<"B">)}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@B] of f32>): wave.yield %arg1 : !wave.tensor<[@B] of f32> @@ -137,7 +137,7 @@ func.func @iterate_iter_arg_block_arg_shape_mismatch(%arg0: !wave.tensor<[@A] of // ----- func.func @iterate_iter_arg_block_arg_address_space_mismatch(%arg0: !wave.tensor<[@A] of f32, >) { - // expected-error @below {{address space mismatch between operand iter_arg #0 and result #0}} + // expected-error @below {{address space mismatch between iter arg #0 and block iter arg #0}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A] of f32, >): wave.yield %arg1 : !wave.tensor<[@A] of f32, > @@ -147,7 +147,7 @@ func.func @iterate_iter_arg_block_arg_address_space_mismatch(%arg0: !wave.tensor // ----- func.func @iterate_iter_arg_result_element_type_mismatch(%arg0: !wave.tensor<[@A] of f32>) { - // expected-error @below {{expected operand iter_arg #0 and result #0 elemental types to match, got 'f32', 'bf16'}} + // expected-error @below {{expected result #0 and terminator operand #0 elemental types to match, got 'bf16', 'f32'}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A] of f32>): wave.yield %arg1 : !wave.tensor<[@A] of f32> @@ -157,7 +157,7 @@ func.func @iterate_iter_arg_result_element_type_mismatch(%arg0: !wave.tensor<[@A // ----- func.func @iterate_iter_arg_result_rank_mismatch(%arg0: !wave.tensor<[@A] of f32>) { - // expected-error @below {{rank mismatch between operand iter_arg #0 and result #0}} + // expected-error @below {{rank mismatch between result #0 and terminator operand #0}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A] of f32>): wave.yield %arg1 : !wave.tensor<[@A] of f32> @@ -167,7 +167,7 @@ func.func @iterate_iter_arg_result_rank_mismatch(%arg0: !wave.tensor<[@A] of f32 // ----- func.func @iterate_iter_arg_result_shape_mismatch(%arg0: !wave.tensor<[@A] of f32>) { - // expected-error @below {{expected operand iter_arg #0 dimension #0 (#wave.symbol<"A">) to match result #0 dimension #0 (#wave.symbol<"B">)}} + // expected-error @below {{expected iter_args #0 dimension #0 (#wave.symbol<"A">) to match result #0 dimension #0 (#wave.symbol<"B">)}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A] of f32>): wave.yield %arg1 : !wave.tensor<[@A] of f32> @@ -177,7 +177,7 @@ func.func @iterate_iter_arg_result_shape_mismatch(%arg0: !wave.tensor<[@A] of f3 // ----- func.func @iterate_iter_arg_result_address_space_mismatch(%arg0: !wave.tensor<[@A] of f32, >) { - // expected-error @below {{address space mismatch between operand iter_arg #0 and result #0}} + // expected-error @below {{address space mismatch between result #0 and terminator operand #0}} wave.iterate @I iter_args(%arg0) { ^bb0(%arg1: !wave.tensor<[@A] of f32, >): wave.yield %arg1 : !wave.tensor<[@A] of f32, > @@ -592,3 +592,62 @@ func.func @cast_underspecified_to_different_shape(%arg0: !wave.tensor<[@A, @B] o wave.cast %arg0 : !wave.tensor<[@A, @B] of f32> to !wave.tensor<[@X, @Y] of i32> return } + +// ----- + +// Test mixed wave tensor and vector types in iterate - should fail +func.func @iterate_mixed_tensor_vector_types() attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, I = 4}>} { + %tensor_input = wave.allocate {distributed_shape = #wave.expr_list<[] -> (128)>} : !wave.tensor<[@M] of f32, > + %vector_input = arith.constant dense<1.0> : vector<8xf32> + + // expected-error @below {{iter_args #0 and result #0 must be the same category of types (both wave tensors or both vectors)}} + %iter_result:2 = wave.iterate @I iter_args(%tensor_input, %vector_input) { + ^bb0(%in_arg0: !wave.tensor<[@M] of f32, >, %in_arg1: vector<8xf32>): + wave.yield %in_arg1, %in_arg0 : vector<8xf32>, !wave.tensor<[@M] of f32, > + } : (!wave.tensor<[@M] of f32, >, vector<8xf32>) -> (vector<8xf32>, !wave.tensor<[@M] of f32, >) + return +} + +// ----- + +// Test vector type mismatch in iterate +func.func @iterate_vector_type_mismatch() attributes {wave.hyperparameters = #wave.hyperparameters<{I = 4}>} { + %input = arith.constant dense<1.0> : vector<8xf32> + + // expected-error @below {{iter_args #0 type ('vector<8xf32>') must match result #0 type ('vector<4xf32>')}} + %result = wave.iterate @I iter_args(%input) { + ^bb0(%in_arg: vector<8xf32>): + %different = arith.constant dense<2.0> : vector<4xf32> + wave.yield %different : vector<4xf32> + } : (vector<8xf32>) -> (vector<4xf32>) + return +} + +// ----- + +// Test vector element type mismatch in iterate +func.func @iterate_vector_element_type_mismatch() attributes {wave.hyperparameters = #wave.hyperparameters<{I = 4}>} { + %input = arith.constant dense<1.0> : vector<8xf32> + + // expected-error @below {{iter_args #0 type ('vector<8xf32>') must match result #0 type ('vector<8xf16>')}} + %result = wave.iterate @I iter_args(%input) { + ^bb0(%in_arg: vector<8xf32>): + %different = arith.constant dense<2.0> : vector<8xf16> + wave.yield %different : vector<8xf16> + } : (vector<8xf32>) -> (vector<8xf16>) + return +} + +// ----- + +// Test that multidimensional vectors are rejected in iterate +func.func @iterate_multidim_vectors_rejected() attributes {wave.hyperparameters = #wave.hyperparameters<{I = 4}>} { + %input = arith.constant dense<1.0> : vector<4x8xf32> + + // expected-error @below {{'wave.iterate' op operand #0 must be variadic of wave tensor or 1d vector type, but got 'vector<4x8xf32>'}} + %result = wave.iterate @I iter_args(%input) { + ^bb0(%in_arg: vector<4x8xf32>): + wave.yield %in_arg : vector<4x8xf32> + } : (vector<4x8xf32>) -> (vector<4x8xf32>) + return +} diff --git a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index f13900608..5e0450478 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -412,3 +412,27 @@ func.func @cast_mixed_specified(%arg0: !wave.tensor<[@M, @N] of f32>) -> !wave.t %0 = wave.cast %arg0 : !wave.tensor<[@M, @N] of f32> to !wave.tensor return %0 : !wave.tensor } + +// ----- +// Test wave.iterate and wave.yield with vector types + +module attributes {wave.normal_form = #wave.normal_form, wave.hyperparameters = #wave.hyperparameters<{I = 4}>} { + +// Test that wave.iterate supports vector types in both iter_args and captures +// CHECK-LABEL: @iterate_vector_types +func.func @iterate_vector_types() { + // CHECK: %[[ITER_ARG:.*]] = arith.constant dense<1.000000e+00> : vector<8xf32> + %iter_arg = arith.constant dense<1.0> : vector<8xf32> + // CHECK: %[[CAPTURE:.*]] = arith.constant dense<2.000000e+00> : vector<4xf16> + %capture = arith.constant dense<2.0> : vector<4xf16> + + // CHECK: wave.iterate @I iter_args(%[[ITER_ARG]]) captures(%[[CAPTURE]]) + %result = wave.iterate @I iter_args(%iter_arg) captures(%capture) { + ^bb0(%in_arg: vector<8xf32>, %cap: vector<4xf16>): + // CHECK: wave.yield %{{.*}} : vector<8xf32> + wave.yield %in_arg : vector<8xf32> + } : (vector<8xf32>, vector<4xf16>) -> (vector<8xf32>) + return +} + +} diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index 9b592df8f..012179e2e 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -301,3 +301,35 @@ module attributes {wave.normal_form = #wave.normal_form} { return } } + +// ----- + +// Test iterate working with vectors after PropagateElementsPerThread conversion +module attributes {wave.normal_form = #wave.normal_form} { + + // CHECK-LABEL: @iterate_with_vectors_after_ept + func.func @iterate_with_vectors_after_ept(%mem: !wave.tensor<[@M] of f32, >) + attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, I = 4}>, + wave.constraints = [#wave.hardware_constraint, vector_shapes = {M = 1}>]} { + + // Read into register tensor - this will become a vector after PropagateElementsPerThread. + // CHECK: %[[INIT:.*]] = wave.read {{.*}} : (!wave.tensor<[@M] of f32, >) -> vector<8xf32> + %init = wave.read %mem {elements_per_thread = 8} : (!wave.tensor<[@M] of f32, >) -> !wave.tensor<[@M] of f32, > + + // Iterate should work with vectors after transformation. + // CHECK: wave.iterate @I iter_args(%[[INIT]]) { + %result = wave.iterate @I iter_args(%init) { + ^bb0(%arg: !wave.tensor<[@M] of f32, >): + // Wave operations should work within the loop body after type conversion + // CHECK: wave.add {{.*}} : (vector<8xf32>, vector<8xf32>) -> vector<8xf32> + %doubled = wave.add %arg, %arg : (!wave.tensor<[@M] of f32, >, !wave.tensor<[@M] of f32, >) -> !wave.tensor<[@M] of f32, > + // CHECK: wave.yield {{.*}} : vector<8xf32> + wave.yield %doubled : !wave.tensor<[@M] of f32, > + } : (!wave.tensor<[@M] of f32, >) -> (!wave.tensor<[@M] of f32, >) + + // Write should work with the vector result after iterate + // CHECK: wave.write {{.*}} : vector<8xf32>, !wave.tensor<[@M] of f32, > + wave.write %result, %mem {elements_per_thread = 8} : !wave.tensor<[@M] of f32, >, !wave.tensor<[@M] of f32, > + return + } +}