diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index 3ac37b7b3..0b3f74c6d 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -476,4 +476,169 @@ def WaveReadWriteBoundsAttr : AttrDef { }]; } +def WaveMemoryAccessPatternAttr : AttrDef { + let mnemonic = "memory_access_pattern"; + let description = [{ + This attribute specifies how memory access should be handled during lowering, + particularly for operations that may require LDS (Local Data Store) promotion. + + LDS promotion transforms inefficient scalar memory accesses into efficient vectorized + accesses by using Local Data Store (LDS) as an intermediate buffer. The transformation + converts: + + **Original Pattern**: Register -> Global Memory (scalar, potentially uncoalesced) + + **LDS Promotion Pattern**: + 1. Register -> LDS (scalar stores, same as original pattern but to LDS) + 2. LDS -> Register (vectorized loads from LDS) + 3. Register -> Global Memory (vectorized stores to global memory) + + ## Index Transformation Logic + + ### Step 1: Register -> LDS Store + The original global memory indices are transformed to LDS indices by subtracting + the LDS block's base address in global memory coordinates: + ``` + lds_store_indices = original_global_indices - lds_block_global_base + ``` + + Example: + - Original: `global[WG0 * BLOCK_M + T0 * 4 + offset]` + - LDS base: `WG0 * BLOCK_M` + - LDS store: `lds[(WG0 * BLOCK_M + T0 * 4 + offset) - (WG0 * BLOCK_M)] = lds[T0 * 4 + offset]` + + ### Step 2: LDS -> Register Load (Vectorized) + Uses `lds_load_indices` for the per-dimension start offsets and `lds_load_vector_sizes` + for the per-dimension vector lengths of the vectorized loads from LDS: + ``` + lds_load_start = lds_load_indices(thread_indices) + lds_load_vector_size = lds_load_vector_sizes() + ``` + + ### Step 3: Register -> Global Store (Vectorized) + Uses `global_store_indices` for the start offsets of the vectorized stores back to + global memory; the vector sizes are shared with the LDS loads (`lds_load_vector_sizes`): + ``` + global_store_start = global_store_indices(thread_indices) + global_store_vector_size = lds_load_vector_sizes() + ``` + + ## Complete Example + ```mlir + #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "mfma_result_0", + + // LDS block placement: where in global memory this LDS block maps to + lds_block_global_base = #wave.expr_list< + [#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M) + >, + + // LDS allocation size + lds_block_shape = #wave.expr_list< + [#wave.symbol<"BLOCK_M">] -> (BLOCK_M) + >, + + // Vectorized LDS -> Register: each thread loads 64 elements starting at T0*64 + lds_load_indices = #wave.expr_list< + [#wave.index_symbol] -> (T0 * 64) + >, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + + // Vectorized Register -> Global: each thread stores 64 elements + global_store_indices = #wave.expr_list< + [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] + -> (WG0 * BLOCK_M + T0 * 64) + > + > + ``` + + ## Parameters: + - **use_lds_promotion**: Whether to enable the LDS promotion transformation + - **group_id**: String identifier for grouping related operations that share the same LDS allocation + - **lds_block_global_base**: Global memory base address that this LDS block represents (for index transformation) + - **lds_block_shape**: Shape/size of the LDS block allocation in elements + - **lds_load_indices**: Per-dimension start offsets for vectorized loads from LDS + - **lds_load_vector_sizes**: Per-dimension vector lengths for the LDS loads (also reused for the global stores) + - **global_store_indices**: Per-dimension start offsets for vectorized stores to global memory
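+ + For operations that do not use LDS promotion, the attribute reduces to its two required parameters; the group name below is illustrative: + ```mlir + #wave.memory_access_pattern<use_lds_promotion = false, group_id = "no_promotion_example"> + ```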
+ + ## Verification Requirements: + - All index expression lists must have the same rank as the original global memory tensor + - lds_block_global_base and lds_block_shape must have consistent ranks + - When LDS promotion is enabled, all LDS-related parameters must be specified + }]; + + let parameters = (ins + "bool":$use_lds_promotion, + StringRefParameter<"group identifier for shared LDS allocation">:$group_id, + + // LDS block placement information + OptionalParameter<"::wave::WaveExprListAttr">:$lds_block_global_base, + OptionalParameter<"::wave::WaveExprListAttr">:$lds_block_shape, + + // Vectorized access patterns - split into indices and vector sizes for simplicity + OptionalParameter<"::wave::WaveExprListAttr">:$lds_load_indices, + OptionalParameter<"::wave::WaveExprListAttr">:$lds_load_vector_sizes, + OptionalParameter<"::wave::WaveExprListAttr">:$global_store_indices + ); + + let assemblyFormat = [{ + `<` `use_lds_promotion` `=` $use_lds_promotion + `,` `group_id` `=` $group_id + (`,` `lds_block_global_base` `=` $lds_block_global_base^)? + (`,` `lds_block_shape` `=` $lds_block_shape^)? + (`,` `lds_load_indices` `=` $lds_load_indices^)? + (`,` `lds_load_vector_sizes` `=` $lds_load_vector_sizes^)? + (`,` `global_store_indices` `=` $global_store_indices^)? `>` + }]; + + let genVerifyDecl = 1; + + let extraClassDeclaration = [{ + /// Check if LDS promotion is enabled. + bool shouldUseLdsPromotion() const { return getUseLdsPromotion(); } + + /// Check if complete LDS promotion information is provided. + bool hasCompleteLdsPromotionInfo() const { + return shouldUseLdsPromotion() && + getLdsBlockGlobalBase() && + getLdsBlockShape() && + getLdsLoadIndices() && + getLdsLoadVectorSizes() && + getGlobalStoreIndices(); + } + + /// Check if any LDS promotion parameters are specified. + bool hasAnyLdsPromotionParams() const { + return getLdsBlockGlobalBase() || + getLdsBlockShape() || + getLdsLoadIndices() || + getLdsLoadVectorSizes() || + getGlobalStoreIndices(); + } + + /// Get the rank/dimensionality of the LDS block. + unsigned getLdsBlockRank() const { + if (!getLdsBlockShape()) return 0; + return getLdsBlockShape().getRank(); + } + + /// Get the rank of the global base address expression. + unsigned getGlobalBaseRank() const { + if (!getLdsBlockGlobalBase()) return 0; + return getLdsBlockGlobalBase().getRank(); + } + + /// Get the rank of the LDS load indices. + unsigned getLdsLoadIndicesRank() const { + if (!getLdsLoadIndices()) return 0; + return getLdsLoadIndices().getRank(); + } + + /// Get the rank of the global store indices.
+ unsigned getGlobalStoreIndicesRank() const { + if (!getGlobalStoreIndices()) return 0; + return getGlobalStoreIndices().getRank(); + } + }]; +} + #endif // WATER_DIALECT_WAVE_WAVEATTRS diff --git a/water/include/water/Dialect/Wave/IR/WaveOps.td b/water/include/water/Dialect/Wave/IR/WaveOps.td index 1ea298ac7..1711ca69d 100644 --- a/water/include/water/Dialect/Wave/IR/WaveOps.td +++ b/water/include/water/Dialect/Wave/IR/WaveOps.td @@ -306,7 +306,9 @@ def WriteOp : WaveOp<"write", [ Arg, "Number of elements processed by each thread">:$elements_per_thread, Arg, - "Bound expressions for each symbolic dimension">:$bounds + "Bound expressions for each symbolic dimension">:$bounds, + Arg, + "Memory access pattern controlling LDS promotion">:$memory_access_pattern ), commonArguments); let assemblyFormat = diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index e7dd1b5d4..38182792b 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -456,6 +456,141 @@ DeviceConstraintAttr::verify(function_ref emitError, return success(); } +//===----------------------------------------------------------------------===// +// WaveMemoryAccessPatternAttr +//===----------------------------------------------------------------------===// + +LogicalResult WaveMemoryAccessPatternAttr::verify( + function_ref emitError, bool use_lds_promotion, + StringRef group_id, WaveExprListAttr lds_block_global_base, + WaveExprListAttr lds_block_shape, WaveExprListAttr lds_load_indices, + WaveExprListAttr lds_load_vector_sizes, + WaveExprListAttr global_store_indices) { + // Validate group_id is not empty + if (group_id.empty()) { + return emitError() << "group_id cannot be empty"; + } + + // When LDS promotion is disabled, no LDS-related parameters should be + // specified + if (!use_lds_promotion) { + if (lds_block_global_base || lds_block_shape || lds_load_indices || + lds_load_vector_sizes || global_store_indices) { + return emitError() << "LDS promotion parameters should not be specified " + "when use_lds_promotion=false"; + } + return success(); + } + + // When LDS promotion is enabled, validate completeness and consistency + bool hasLdsBase = static_cast(lds_block_global_base); + bool hasLdsShape = static_cast(lds_block_shape); + bool hasLdsLoadIndices = static_cast(lds_load_indices); + bool hasLdsLoadVectorSizes = static_cast(lds_load_vector_sizes); + bool hasGlobalStoreIndices = static_cast(global_store_indices); + + // Check for partial specification - either all or none should be provided + if (hasLdsBase || hasLdsShape || hasLdsLoadIndices || hasLdsLoadVectorSizes || + hasGlobalStoreIndices) { + if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices || + !hasLdsLoadVectorSizes || !hasGlobalStoreIndices) { + return emitError() << "when LDS promotion is enabled, all LDS parameters " + "must be specified: " + "lds_block_global_base, lds_block_shape, " + "lds_load_indices, lds_load_vector_sizes, " + "global_store_indices"; + } + } + + // If all LDS parameters are provided, perform detailed validation + if (hasLdsBase && hasLdsShape && hasLdsLoadIndices && hasLdsLoadVectorSizes && + hasGlobalStoreIndices) { + + // Validate that lds_block_global_base and lds_block_shape have consistent + // ranks + unsigned ldsBaseRank = lds_block_global_base.getRank(); + unsigned ldsShapeRank = lds_block_shape.getRank(); + + if (ldsBaseRank != ldsShapeRank) { + return emitError() << "lds_block_global_base rank (" << ldsBaseRank + << ") must match 
lds_block_shape rank (" + << ldsShapeRank << ")"; + } + + // Validate that load indices and vector sizes have consistent ranks + unsigned ldsLoadIndicesRank = lds_load_indices.getRank(); + unsigned ldsLoadVectorSizesRank = lds_load_vector_sizes.getRank(); + unsigned globalStoreIndicesRank = global_store_indices.getRank(); + + if (ldsLoadIndicesRank != ldsLoadVectorSizesRank) { + return emitError() << "lds_load_indices rank (" << ldsLoadIndicesRank + << ") must match lds_load_vector_sizes rank (" + << ldsLoadVectorSizesRank << ")"; + } + + if (ldsBaseRank != ldsLoadIndicesRank) { + return emitError() << "LDS block rank (" << ldsBaseRank + << ") must match LDS load indices rank (" + << ldsLoadIndicesRank << ")"; + } + + if (ldsBaseRank != globalStoreIndicesRank) { + return emitError() << "LDS block rank (" << ldsBaseRank + << ") must match global store indices rank (" + << globalStoreIndicesRank << ")"; + } + + // Validate that all symbols are WaveSymbolAttr or WaveIndexSymbolAttr + if (!llvm::all_of(lds_block_global_base.getSymbols(), + llvm::IsaPred)) { + return emitError() << "lds_block_global_base must only contain " + "WaveSymbolAttr or WaveIndexSymbolAttr"; + } + + if (!llvm::all_of(lds_block_shape.getSymbols(), + llvm::IsaPred)) { + return emitError() << "lds_block_shape must only contain WaveSymbolAttr " + "or WaveIndexSymbolAttr"; + } + + if (!llvm::all_of(lds_load_indices.getSymbols(), + llvm::IsaPred)) { + return emitError() << "lds_load_indices must only contain WaveSymbolAttr " + "or WaveIndexSymbolAttr"; + } + + if (!llvm::all_of(lds_load_vector_sizes.getSymbols(), + llvm::IsaPred)) { + return emitError() << "lds_load_vector_sizes must only contain " + "WaveSymbolAttr or WaveIndexSymbolAttr"; + } + + if (!llvm::all_of(global_store_indices.getSymbols(), + llvm::IsaPred)) { + return emitError() << "global_store_indices must only contain " + "WaveSymbolAttr or WaveIndexSymbolAttr"; + } + + // Validate that mappings have at least one dimension + if (ldsBaseRank == 0) { + return emitError() << "LDS block must have at least one dimension"; + } + + // Note: We cannot validate that the ranks match the original global memory + // tensor rank here because this attribute verification doesn't have access + // to the WriteOp's memory operand. This validation should be performed in + // the WriteOp's verifier where both the attribute and the memory operand + // type are available. + // + // Additionally, data coverage verification (ensuring that the collective + // workgroup access pattern covers exactly the same elements before and + // after LDS promotion) should be performed in the WriteOp verifier where + // access to the original index mapping is available. 
+ } + + return success(); +} + void wave::WaveDialect::registerAttributes() { addAttributes< #define GET_ATTRDEF_LIST diff --git a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp index 96ce0082a..50734d6e7 100644 --- a/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp +++ b/water/lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp @@ -11,9 +11,12 @@ #include "water/Dialect/Wave/IR/WaveUtils.h" #include "water/Dialect/Wave/Transforms/LoweringPatterns.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Transforms/DialectConversion.h" @@ -354,6 +357,7 @@ static void buildVectorWrite(Location loc, PatternRewriter &rewriter, Value mem, } else { vector::StoreOp::create(rewriter, loc, vecValue, mem, indices); } + return; } // vector.transfer_write (masked or unmasked) @@ -470,20 +474,333 @@ class WriteOpLoweringPattern : public OpConversionPattern { LogicalResult matchAndRewrite(wave::WriteOp op, wave::WriteOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value vec = adaptor.getValueToStore(); - auto vecType = cast(vec.getType()); + Value value = adaptor.getValueToStore(); + // All wave.write operations require index attributes + if (!op.getIndexAttr()) { + return rewriter.notifyMatchFailure( + op, "wave.write operations require index attributes"); + } + + // Check if LDS promotion is requested and has complete information. + auto memAccessPattern = op.getMemoryAccessPatternAttr(); + if (memAccessPattern && memAccessPattern.hasCompleteLdsPromotionInfo()) { + // Full LDS promotion: register->LDS, barrier, LDS->register, + // register->global + return lowerLdsPromotionGroup(op, adaptor, rewriter); + } + + // Use existing vector write logic. + auto vecType = cast(value.getType()); FailureOr memInfo = createMemoryIndicesAndMask( rewriter, getTypeConverter(), op, op.getMemory().getType(), vecType); if (failed(memInfo)) return failure(); buildVectorWrite(op.getLoc(), rewriter, adaptor.getMemory(), - memInfo->startIndices, vec, memInfo->mask, + memInfo->startIndices, value, memInfo->mask, memInfo->vectorizedDim); + rewriter.eraseOp(op); return success(); } + +private: + /// Implement group-based LDS promotion lowering. + /// Collects all wave.write operations in the same LDS promotion group, + /// then lowers them together with a single barrier per group. + /// + /// DESIGN CHOICE: All operations in the same LDS promotion group must be in + /// the same scope. This is required because: + /// 1. All register->LDS stores must complete before the barrier + /// 2. All LDS->register loads must happen after the barrier + /// 3. With operations in different scopes, we cannot guarantee this ordering + LogicalResult + lowerLdsPromotionGroup(wave::WriteOp firstOp, + wave::WriteOp::Adaptor firstAdaptor, + ConversionPatternRewriter &rewriter) const { + wave::WaveMemoryAccessPatternAttr memAccessPattern = + firstOp.getMemoryAccessPatternAttr(); + StringRef groupId = memAccessPattern.getGroupId(); + Location loc = firstOp.getLoc(); + + // Step 1: Collect all wave.write operations in this LDS promotion group + // from the entire function. 
+ SmallVector groupOps = + collectLdsPromotionGroup(firstOp, groupId); + + // Step 2: Verify that all operations in the group are in the same scope. + if (failed(verifyLdsPromotionGroupScope(groupOps, groupId))) { + return rewriter.notifyMatchFailure( + firstOp, "LDS promotion group operations must be in the same scope"); + } + + // Step 3: Create shared LDS allocation for this group. + Value ldsAlloc = createLdsAllocation(firstOp, rewriter, loc); + if (!ldsAlloc) { + return rewriter.notifyMatchFailure(firstOp, + "failed to create LDS allocation"); + } + + // Step 4: Lower all register->LDS stores first. + for (wave::WriteOp op : groupOps) { + if (failed(storeRegisterToLds(op, rewriter, op.getValueToStore(), + ldsAlloc, loc))) { + return rewriter.notifyMatchFailure(op, + "failed to store register to LDS"); + } + } + + // Step 5: Insert single barrier for the entire group. + amdgpu::LDSBarrierOp::create(rewriter, loc); + + // Step 6: Lower all LDS->global transfers. + for (wave::WriteOp op : groupOps) { + // Convert the memory operand manually for each operation in the group. + Value convertedMemory = getTypeConverter()->materializeTargetConversion( + rewriter, op.getLoc(), + getTypeConverter()->convertType(op.getMemory().getType()), + op.getMemory()); + if (!convertedMemory) { + return rewriter.notifyMatchFailure(op, + "failed to convert memory operand"); + } + + // Create a properly converted adaptor. + SmallVector convertedOperands = {op.getValueToStore(), + convertedMemory}; + wave::WriteOp::Adaptor adaptor(convertedOperands, + op->getAttrDictionary()); + if (failed(transferLdsToGlobal(op, adaptor, rewriter, ldsAlloc, loc))) { + return rewriter.notifyMatchFailure(op, + "failed to transfer LDS to global"); + } + } + + // Step 7: Erase all original wave.write operations. + for (wave::WriteOp op : groupOps) { + rewriter.eraseOp(op); + } + + return success(); + } + + /// Verify that all operations in the LDS promotion group are in the same + /// scope. Returns failure if operations span multiple scopes, which would + /// break LDS promotion. + LogicalResult verifyLdsPromotionGroupScope(ArrayRef groupOps, + StringRef groupId) const { + if (groupOps.empty()) + return success(); + + // Get the block of the first operation as the reference + Block *referenceBlock = groupOps[0]->getBlock(); + + // Verify all operations are in the same block + for (wave::WriteOp op : groupOps) { + if (op->getBlock() != referenceBlock) { + return op->emitError("LDS promotion group '") + << groupId + << "' contains operations in different scopes. All operations " + << "with the same group_id must be in the same block for " + << "correct LDS barrier semantics."; + } + } + + return success(); + } + + /// Collect all wave.write operations in the same LDS promotion group. + /// Search the entire function for all wave.write ops with matching group_id. + SmallVector collectLdsPromotionGroup(wave::WriteOp firstOp, + StringRef groupId) const { + SmallVector groupOps; + + // Find the function that contains this write op. + func::FuncOp funcOp = firstOp->getParentOfType(); + assert(funcOp && "wave.write operation must be within a function"); + + // Walk the entire function to find all operations in this group. 
+ funcOp.walk([&](wave::WriteOp writeOp) { + auto memPattern = writeOp.getMemoryAccessPatternAttr(); + if (memPattern && memPattern.hasCompleteLdsPromotionInfo() && + memPattern.getGroupId() == groupId) { + groupOps.push_back(writeOp); + } + }); + + return groupOps; + } + + /// Create LDS allocation based on the LDS block shape from the attribute. + Value createLdsAllocation(wave::WriteOp op, + ConversionPatternRewriter &rewriter, + Location loc) const { + wave::WaveMemoryAccessPatternAttr memAccessPattern = + op.getMemoryAccessPatternAttr(); + + // Get element type from the input value. + auto vecType = cast(op.getValueToStore().getType()); + Type elementType = vecType.getElementType(); + + // Get LDS block shape from the attribute. + auto ldsBlockShape = memAccessPattern.getLdsBlockShape(); + wave::WaveHyperparameterAttr hyper = + static_cast(*getTypeConverter()) + .getHyperparameters(); + + // Resolve the LDS block shape to concrete dimensions. + auto maybeResolvedShape = ldsBlockShape.getResolvedShape(hyper); + if (!maybeResolvedShape.has_value()) { + return nullptr; + } + SmallVector ldsShape = std::move(*maybeResolvedShape); + + // Create LDS allocation with workgroup address space. + auto addressSpaceAttr = gpu::AddressSpaceAttr::get( + rewriter.getContext(), gpu::AddressSpace::Workgroup); + auto ldsMemRefType = MemRefType::get( + ldsShape, elementType, MemRefLayoutAttrInterface{}, addressSpaceAttr); + + return memref::AllocOp::create(rewriter, loc, ldsMemRefType); + } + + /// Transform global memory indices to LDS indices and store register value to + /// LDS. + LogicalResult storeRegisterToLds(wave::WriteOp op, + ConversionPatternRewriter &rewriter, + Value inputValue, Value ldsAlloc, + Location loc) const { + wave::WaveMemoryAccessPatternAttr memAccessPattern = + op.getMemoryAccessPatternAttr(); + + // Get the original global memory access info. + auto vecType = cast(inputValue.getType()); + FailureOr memInfo = createMemoryIndicesAndMask( + rewriter, getTypeConverter(), op, op.getMemory().getType(), vecType); + if (failed(memInfo)) + return failure(); + + // Transform global indices to LDS indices using: lds_indices = + // global_indices - lds_block_global_base. + auto ldsBlockGlobalBase = memAccessPattern.getLdsBlockGlobalBase(); + wave::WaveHyperparameterAttr hyper = + static_cast(*getTypeConverter()) + .getHyperparameters(); + + FailureOr> ldsBaseValues = + materializeAffine(loc, ldsBlockGlobalBase.getSymbols(), + ldsBlockGlobalBase.getMap(), rewriter, hyper); + if (failed(ldsBaseValues)) + return failure(); + SmallVector ldsBase = std::move(*ldsBaseValues); + + SmallVector ldsIndices; + ldsIndices.reserve(memInfo->startIndices.size()); + for (size_t i = 0; i < memInfo->startIndices.size(); ++i) { + Value ldsIndex = arith::SubIOp::create( + rewriter, loc, memInfo->startIndices[i], ldsBase[i]); + ldsIndices.push_back(ldsIndex); + } + + // Store to LDS using the same vectorization dimension as the original + // operation. + buildVectorWrite(loc, rewriter, ldsAlloc, ldsIndices, inputValue, + memInfo->mask, memInfo->vectorizedDim); + + return success(); + } + + /// Load vectorized data from LDS and store vectorized data to global memory. 
+ LogicalResult transferLdsToGlobal(wave::WriteOp op, + wave::WriteOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter, + Value ldsAlloc, Location loc) const { + wave::WaveMemoryAccessPatternAttr memAccessPattern = + op.getMemoryAccessPatternAttr(); + + auto ldsLoadIndices = memAccessPattern.getLdsLoadIndices(); + auto ldsLoadVectorSizes = memAccessPattern.getLdsLoadVectorSizes(); + auto globalStoreIndices = memAccessPattern.getGlobalStoreIndices(); + wave::WaveHyperparameterAttr hyper = + static_cast(*getTypeConverter()) + .getHyperparameters(); + + // Get element type from LDS allocation. + auto ldsMemRefType = cast(ldsAlloc.getType()); + Type elementType = ldsMemRefType.getElementType(); + + // Step 1: Calculate LDS load indices using lds_load_indices. + FailureOr> ldsLoadStartValues = + materializeAffine(loc, ldsLoadIndices.getSymbols(), + ldsLoadIndices.getMap(), rewriter, hyper); + if (failed(ldsLoadStartValues)) + return failure(); + SmallVector ldsLoadStart = std::move(*ldsLoadStartValues); + + // Step 2: Resolve vector sizes from lds_load_vector_sizes. + auto maybeVectorSizes = ldsLoadVectorSizes.getResolvedShape(hyper); + if (!maybeVectorSizes.has_value()) { + return rewriter.notifyMatchFailure( + op, "failed to resolve LDS load vector sizes"); + } + SmallVector vectorSizes = std::move(*maybeVectorSizes); + + // Find dimension with largest vector size. + int64_t ldsVectorizedDim = 0; + int64_t maxVectorSize = 1; + for (auto [i, size] : llvm::enumerate(vectorSizes)) { + if (size >= maxVectorSize) { + maxVectorSize = size; + ldsVectorizedDim = i; + } + } + + // Step 3: Vectorized load from LDS. + auto ldsLoadVecType = VectorType::get({maxVectorSize}, elementType); + + Value loadedVec = + buildVectorRead(loc, rewriter, ldsAlloc, ldsLoadStart, ldsLoadVecType, + /*mask=*/nullptr, ldsVectorizedDim); + + // Step 4: Calculate global store indices using global_store_indices. + FailureOr> globalStoreStartValues = + materializeAffine(loc, globalStoreIndices.getSymbols(), + globalStoreIndices.getMap(), rewriter, hyper); + if (failed(globalStoreStartValues)) + return failure(); + SmallVector globalStoreStart = std::move(*globalStoreStartValues); + + // Step 5: Resolve global store vector sizes from lds_load_vector_sizes. + auto maybeGlobalVectorSizes = ldsLoadVectorSizes.getResolvedShape(hyper); + if (!maybeGlobalVectorSizes.has_value()) { + return rewriter.notifyMatchFailure( + op, "failed to resolve global store vector sizes"); + } + SmallVector globalVectorSizes = std::move(*maybeGlobalVectorSizes); + + // Find dimension with largest vector size. 
+ int64_t globalVectorizedDim = 0; + int64_t globalMaxVectorSize = 1; + for (auto [i, size] : llvm::enumerate(globalVectorSizes)) { + if (size >= globalMaxVectorSize) { + globalMaxVectorSize = size; + globalVectorizedDim = i; + } + } + + assert(maxVectorSize == globalMaxVectorSize && + "vector sizes are guaranteed to match by attribute verification"); + + Value vecToStore = loadedVec; + + // Step 6: Vectorized store to global memory + buildVectorWrite(loc, rewriter, adaptor.getMemory(), globalStoreStart, + vecToStore, + /*mask=*/nullptr, globalVectorizedDim); + + return success(); + } }; } // namespace diff --git a/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp b/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp index d0ebcc462..d408cda06 100644 --- a/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp +++ b/water/lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "water/Dialect/Wave/IR/WaveOps.h" @@ -70,8 +71,11 @@ struct LowerWaveToMLIRPass vector::VectorDialect // clang-format on >(); + // Allow unrealized conversion casts temporarily (should be eliminated by + // type converter) + target.addLegalOp(); target.addIllegalOp(); + wave::CastOp, wave::ReadOp, wave::WriteOp>(); ConversionConfig config; config.allowPatternRollback = false; diff --git a/water/test/Dialect/Wave/attr-memory-access-pattern-invalid.mlir b/water/test/Dialect/Wave/attr-memory-access-pattern-invalid.mlir new file mode 100644 index 000000000..29d6a35a4 --- /dev/null +++ b/water/test/Dialect/Wave/attr-memory-access-pattern-invalid.mlir @@ -0,0 +1,81 @@ +// RUN: water-opt %s -split-input-file -verify-diagnostics + +// Test: empty group_id should fail +func.func @memory_access_pattern_empty_group_id(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + wave.write %value, %mem { + // expected-error @+1 {{group_id cannot be empty}} + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = false, + group_id = "" + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} + +// ----- + +// Test: LDS parameters specified when use_lds_promotion=false should fail +func.func @memory_access_pattern_lds_params_when_disabled(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + wave.write %value, %mem { + // expected-error @+1 {{LDS promotion parameters should not be specified when use_lds_promotion=false}} + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = false, + group_id = "test", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)> + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} + +// ----- + +// Test: Partial LDS specification when use_lds_promotion=true should fail +func.func @memory_access_pattern_partial_lds_specification(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + wave.write %value, %mem { + // expected-error @+1 {{when LDS promotion is enabled, all LDS parameters must be specified: lds_block_global_base, lds_block_shape, lds_load_indices, lds_load_vector_sizes, global_store_indices}} + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "test", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)> + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} + +// ----- + 
+// Test: Mismatched ranks between lds_block_global_base and lds_block_shape should fail +func.func @memory_access_pattern_mismatched_base_shape_ranks(%value: !wave.tensor>, %mem: !wave.tensor<[@M, @N] of f32, >) { + wave.write %value, %mem { + // expected-error @+1 {{lds_block_global_base rank (1) must match lds_block_shape rank (2)}} + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "test", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)> + > + } : !wave.tensor>, !wave.tensor<[@M, @N] of f32, > + return +} + +// ----- + +// Test: Mismatched ranks between lds_load_indices and lds_load_vector_sizes should fail +func.func @memory_access_pattern_mismatched_lds_load_ranks(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + wave.write %value, %mem { + // expected-error @+1 {{lds_load_indices rank (1) must match lds_load_vector_sizes rank (2)}} + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "test", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)>, + lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_M">, #wave.symbol<"VEC_N">] -> (VEC_M, VEC_N)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)> + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} diff --git a/water/test/Dialect/Wave/attr-memory-access-pattern.mlir b/water/test/Dialect/Wave/attr-memory-access-pattern.mlir new file mode 100644 index 000000000..0601e63cf --- /dev/null +++ b/water/test/Dialect/Wave/attr-memory-access-pattern.mlir @@ -0,0 +1,106 @@ +// RUN: water-opt %s -split-input-file | water-opt | FileCheck %s +// RUN: water-opt %s -split-input-file --mlir-print-op-generic | water-opt | FileCheck %s + +// Test basic memory access pattern parsing without LDS promotion +// CHECK-LABEL: @memory_access_pattern_basic_no_promotion +func.func @memory_access_pattern_basic_no_promotion(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + // CHECK: wave.write + // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern< + // CHECK-SAME: use_lds_promotion = false, + // CHECK-SAME: group_id = "basic_test" + // CHECK-SAME: > + wave.write %value, %mem { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = false, + group_id = "basic_test" + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} + +// ----- + +// Test memory access pattern with complete LDS promotion - 1D case with symbolic vector size +// CHECK-LABEL: @memory_access_pattern_1d_lds_promotion +func.func @memory_access_pattern_1d_lds_promotion(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + // CHECK: wave.write + // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern< + // CHECK-SAME: use_lds_promotion = true, + // CHECK-SAME: group_id = "lds_1d", + // CHECK-SAME: lds_block_global_base = <[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + // CHECK-SAME: lds_block_shape = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + // CHECK-SAME: lds_load_indices = 
<[#wave.index_symbol, #wave.symbol<"VEC_SIZE">] -> (T0 * VEC_SIZE)>, + // CHECK-SAME: lds_load_vector_sizes = <[#wave.symbol<"VEC_SIZE">] -> (VEC_SIZE)>, + // CHECK-SAME: global_store_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"VEC_SIZE">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE)> + // CHECK-SAME: > + wave.write %value, %mem { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "lds_1d", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"VEC_SIZE">] -> (T0 * VEC_SIZE)>, + lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_SIZE">] -> (VEC_SIZE)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"VEC_SIZE">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE)> + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} + +// ----- + +// Test memory access pattern with constant vector size - 1D case +// CHECK-LABEL: @memory_access_pattern_1d_constant_vector_size +func.func @memory_access_pattern_1d_constant_vector_size(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) { + // CHECK: wave.write + // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern< + // CHECK-SAME: use_lds_promotion = true, + // CHECK-SAME: group_id = "lds_1d_const", + // CHECK-SAME: lds_block_global_base = <[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + // CHECK-SAME: lds_block_shape = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + // CHECK-SAME: lds_load_indices = <[#wave.index_symbol] -> (T0 * 64)>, + // CHECK-SAME: lds_load_vector_sizes = <[] -> (64)>, + // CHECK-SAME: global_store_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + // CHECK-SAME: > + wave.write %value, %mem { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "lds_1d_const", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : !wave.tensor>, !wave.tensor<[@M] of f32, > + return +} + +// ----- + +// Test memory access pattern with complete LDS promotion - 2D case +// CHECK-LABEL: @memory_access_pattern_2d_lds_promotion +func.func @memory_access_pattern_2d_lds_promotion(%value: !wave.tensor>, %mem: !wave.tensor<[@M, @N] of f32, >) { + // CHECK: wave.write + // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern< + // CHECK-SAME: use_lds_promotion = true, + // CHECK-SAME: group_id = "lds_2d", + // CHECK-SAME: lds_block_global_base = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (WG0 * BLOCK_M, WG1 * BLOCK_N)>, + // CHECK-SAME: lds_block_shape = <[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>, + // CHECK-SAME: lds_load_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (T0 * VEC_SIZE_M, T1 * VEC_SIZE_N)>, + // CHECK-SAME: 
lds_load_vector_sizes = <[#wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (VEC_SIZE_M, VEC_SIZE_N)>, + // CHECK-SAME: global_store_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE_M, WG1 * BLOCK_N + T1 * VEC_SIZE_N)> + // CHECK-SAME: > + wave.write %value, %mem { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "lds_2d", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (WG0 * BLOCK_M, WG1 * BLOCK_N)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (T0 * VEC_SIZE_M, T1 * VEC_SIZE_N)>, + lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (VEC_SIZE_M, VEC_SIZE_N)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE_M, WG1 * BLOCK_N + T1 * VEC_SIZE_N)> + > + } : !wave.tensor>, !wave.tensor<[@M, @N] of f32, > + return +} diff --git a/water/test/Dialect/Wave/lower-write-group-invalid.mlir b/water/test/Dialect/Wave/lower-write-group-invalid.mlir new file mode 100644 index 000000000..19c7a113a --- /dev/null +++ b/water/test/Dialect/Wave/lower-write-group-invalid.mlir @@ -0,0 +1,48 @@ +// RUN: water-opt %s --split-input-file --lower-wave-to-mlir --verify-diagnostics + +// Test: Operations with the same group_id in different scopes should fail lowering +module attributes {wave.normal_form = #wave.normal_form} { +// expected-error @+1 {{failed to convert starting at this operation}} +func.func @test_different_scope_same_group(%cond: i1, %mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> +} { + %cst = arith.constant 0.0 : f32 + %waveValue1 = wave.register %cst : vector<4xf32> + %waveValue2 = wave.register %cst : vector<4xf32> + + scf.if %cond { + // This operation is inside the scf.if scope + // expected-error @+1 {{failed to legalize operation 'wave.write' that was explicitly marked illegal}} + wave.write %waveValue1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "scope_violation_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 4)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (4)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 4)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + } + + // expected-error @+1 {{LDS promotion group 'scope_violation_group' contains operations in different scopes. 
All operations with the same group_id must be in the same block for correct LDS barrier semantics.}} + wave.write %waveValue2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4 + 256, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "scope_violation_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 4)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (4)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 4 + 256)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + return +} +} diff --git a/water/test/Dialect/Wave/lower-write-group.mlir b/water/test/Dialect/Wave/lower-write-group.mlir new file mode 100644 index 000000000..def33525c --- /dev/null +++ b/water/test/Dialect/Wave/lower-write-group.mlir @@ -0,0 +1,727 @@ +// RUN: water-opt %s --split-input-file --lower-wave-to-mlir | FileCheck %s + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_basic + func.func @test_basic(%mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> + } { + %cst = arith.constant 0.0 : f32 + %waveValue1 = wave.register %cst : vector<4xf32> + %waveValue2 = wave.register %cst : vector<4xf32> + + // Phase 1: Single LDS allocation for the shared group + // CHECK: %[[SHARED_ALLOC:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + + // Phase 2: Register→LDS stores with correct index calculations + // Verify affine maps for index calculations: original global index (WG0*1024+T0*4) and LDS base (WG0*1024) + // CHECK: %[[BLOCK_ID:.*]] = gpu.block_id x + // CHECK: %[[THREAD_ID:.*]] = gpu.thread_id x + // CHECK: %[[GLOBAL_IDX:.*]] = affine.apply #{{.*}}()[%[[BLOCK_ID]], %[[THREAD_ID]]] + // CHECK: %[[LDS_BASE:.*]] = affine.apply #{{.*}}()[%{{.*}}] + // CHECK: %[[LDS_STORE_IDX:.*]] = arith.subi %[[GLOBAL_IDX]], %[[LDS_BASE]] : index + // CHECK: vector.store %{{.*}}, %[[SHARED_ALLOC]][%[[LDS_STORE_IDX]]] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // Second operation with offset (+256) + // CHECK: vector.store %{{.*}}, %[[SHARED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Phase 3: Single shared barrier + // CHECK: amdgpu.lds_barrier + // Ensure only one barrier for the group + // CHECK-NOT: amdgpu.lds_barrier + + // Phase 4: LDS→Register loads using lds_load_indices (T0*64) + // CHECK: %[[LDS_LOAD_IDX:.*]] = affine.apply #{{.*}}()[%{{.*}}] + // CHECK: %[[LOADED_VEC1:.*]] = vector.load %[[SHARED_ALLOC]][%[[LDS_LOAD_IDX]]] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // Phase 5: Register→Global stores using global_store_indices (WG0*1024+T0*64) + // CHECK: %[[GLOBAL_STORE_IDX1:.*]] = affine.apply #{{.*}}()[%{{.*}}, %{{.*}}] + // CHECK: vector.store %[[LOADED_VEC1]], %{{.*}}[%[[GLOBAL_STORE_IDX1]]] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + // Second operation: LDS load and global store with offset (+256) + // CHECK: %[[LOADED_VEC2:.*]] = vector.load %[[SHARED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %[[GLOBAL_STORE_IDX2:.*]] = affine.apply 
#{{.*}}()[%{{.*}}, %{{.*}}] + // CHECK: vector.store %[[LOADED_VEC2]], %{{.*}}[%[[GLOBAL_STORE_IDX2]]] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + // Verify no wave.write operations remain after lowering + // CHECK-NOT: wave.write + + wave.write %waveValue1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "shared_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + wave.write %waveValue2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4 + 256, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "shared_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 256)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_different_groups + func.func @test_different_groups(%mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> + } { + %cst = arith.constant 0.0 : f32 + %waveValue1 = wave.register %cst : vector<4xf32> + %waveValue2 = wave.register %cst : vector<4xf32> + + // First group (group_a): allocation, register→LDS, barrier, LDS→register, register→global + // CHECK: %[[ALLOC_A:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + // CHECK: vector.store %{{.*}}, %[[ALLOC_A]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // CHECK: amdgpu.lds_barrier + // CHECK: %{{.*}} = vector.load %[[ALLOC_A]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + // Second group (group_b): separate allocation, register→LDS, barrier, LDS→register, register→global + // CHECK: %[[ALLOC_B:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + // CHECK: vector.store %{{.*}}, %[[ALLOC_B]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // CHECK: amdgpu.lds_barrier + // CHECK: %{{.*}} = vector.load %[[ALLOC_B]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + // Verify no wave.write operations remain + // CHECK-NOT: wave.write + + wave.write %waveValue1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> 
(BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "group_a", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + wave.write %waveValue2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "group_b", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_mixed_lds_and_regular_writes + func.func @test_mixed_lds_and_regular_writes(%mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> + } { + %cst = arith.constant 0.0 : f32 + %waveValue1 = wave.register %cst : vector<4xf32> + %waveValue2 = wave.register %cst : vector<4xf32> + + // First operation uses LDS promotion: allocation, register→LDS, barrier, LDS→register, register→global + // CHECK: %[[MIXED_ALLOC:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + // CHECK: vector.store %{{.*}}, %[[MIXED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // CHECK: amdgpu.lds_barrier + // CHECK: %{{.*}} = vector.load %[[MIXED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + // Second operation uses regular write - direct register→global store (no LDS) + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Verify no wave.write operations remain + // CHECK-NOT: wave.write + + wave.write %waveValue1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "lds_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + wave.write %waveValue2, %mem2 index [{ + M : 
[#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = false, + group_id = "regular_group" + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_nested_lds_promotion_group + func.func @test_nested_lds_promotion_group(%cond: i1, %mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> + } { + %cst = arith.constant 0.0 : f32 + %waveValue1 = wave.register %cst : vector<4xf32> + %waveValue2 = wave.register %cst : vector<4xf32> + + scf.if %cond { + // Nested LDS promotion group should generate: allocation, register→LDS stores, barrier, LDS→register loads, register→global stores + // CHECK: %[[NESTED_ALLOC:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + // CHECK: vector.store %{{.*}}, %[[NESTED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // CHECK: vector.store %{{.*}}, %[[NESTED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // CHECK: amdgpu.lds_barrier + // CHECK: %{{.*}} = vector.load %[[NESTED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[NESTED_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + wave.write %waveValue1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "nested_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + wave.write %waveValue2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4 + 256, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "nested_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 256)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + } + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_lds_promotion_group_same_scope + func.func @test_lds_promotion_group_same_scope(%cond: i1, %mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of 
f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> + } { + %cst = arith.constant 0.0 : f32 + %waveValue1 = wave.register %cst : vector<4xf32> + %waveValue2 = wave.register %cst : vector<4xf32> + + scf.if %cond { + // Both operations are in the same nested scope (scf.if block) - this should work correctly + // Single LDS allocation for the shared group + // CHECK: %[[SAME_SCOPE_ALLOC:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + + // Register→LDS stores + // CHECK: vector.store %{{.*}}, %[[SAME_SCOPE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + // CHECK: vector.store %{{.*}}, %[[SAME_SCOPE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + + // Single barrier for the group + // CHECK: amdgpu.lds_barrier + + // LDS→Register loads and Register→Global stores + // CHECK: %{{.*}} = vector.load %[[SAME_SCOPE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[SAME_SCOPE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<64xf32> + + wave.write %waveValue1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "same_scope_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + wave.write %waveValue2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4 + 256, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "same_scope_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 256)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + } + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_multidim_lds_promotion + func.func @test_multidim_lds_promotion(%mem: !wave.tensor<[@M, @N] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 64, BLOCK_N = 64, M = 1024, N = 1024}> + } { + %cst = arith.constant 0.0 : f32 + %value = wave.register %cst : vector<8xf32> + + // Tests 2D LDS allocation and lowering + // Phase 1: 2D LDS allocation (BLOCK_M × BLOCK_N = 64×64 = 4096 elements) + // CHECK: %[[ALLOC_2D:.*]] = memref.alloc() : memref<64x64xf32, #gpu.address_space> + + // Phase 2: Register→LDS store 
with 2D indexing + // Original 2D access: (WG0*BLOCK_M + T0*8, WG1*BLOCK_N + T1*1) + // LDS store: subtract 2D base (WG0*BLOCK_M, WG1*BLOCK_N) to get local LDS coordinates + // CHECK-DAG: vector.transfer_write %{{.*}}, %[[ALLOC_2D]][%{{.*}}, %{{.*}}] {{.*}} : vector<8xf32>, memref<64x64xf32, #gpu.address_space> + + // Phase 3: Barrier synchronization + // CHECK: amdgpu.lds_barrier + + // Phase 4: LDS→Register load with vectorized 2D pattern (T0*8, T1*4) + // CHECK-DAG: %[[REG:.*]] = vector.transfer_read %[[ALLOC_2D]][%{{.*}}, %{{.*}}], %{{.*}} {{.*}} : memref<64x64xf32, #gpu.address_space>, vector<8xf32> + + // Phase 5: Register→Global store with 2D coordinates (WG0*BLOCK_M + T0*8, WG1*BLOCK_N + T1*4) + // CHECK-DAG: vector.transfer_write %[[REG]], %{{.*}}[%{{.*}}, %{{.*}}] {{.*}} : vector<8xf32>, memref<1024x1024xf32, #gpu.address_space> + + // Verify no wave.write operations remain + // CHECK-NOT: wave.write + + wave.write %value, %mem index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8, 8, 1), + N : [#wave.symbol<"BLOCK_N">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_N * WG1 + T1, 1, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "multidim_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (WG0 * BLOCK_M, WG1 * BLOCK_N)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol] -> (T0 * 8, T1 * 4)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (8, 4)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (WG0 * BLOCK_M + T0 * 8, WG1 * BLOCK_N + T1 * 4)> + > + } : vector<8xf32>, !wave.tensor<[@M, @N] of f32, > + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_variable_vector_sizes + func.func @test_variable_vector_sizes(%mem: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, VEC_SIZE = 32, M = 4096}> + } { + %cst = arith.constant 0.0 : f32 + %value = wave.register %cst : vector<32xf32> + + // Variable vector size test: LDS allocation with symbolic VEC_SIZE (32) + // CHECK: %[[VAR_ALLOC:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + // CHECK: vector.store %{{.*}}, %[[VAR_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<32xf32> + // CHECK: amdgpu.lds_barrier + // CHECK: %{{.*}} = vector.load %[[VAR_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<32xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<4096xf32, #gpu.address_space>, vector<32xf32> + // CHECK-NOT: wave.write + + wave.write %value, %mem index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"VEC_SIZE">] -> (BLOCK_M * WG0 + T0 * VEC_SIZE, VEC_SIZE, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "variable_vec_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"VEC_SIZE">] -> (T0 * 
VEC_SIZE)>, + lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_SIZE">] -> (VEC_SIZE)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"VEC_SIZE">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE)> + > + } : vector<32xf32>, !wave.tensor<[@M] of f32, > + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_complex_expressions + func.func @test_complex_expressions(%mem: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, OFFSET = 128, STRIDE = 16, M = 8192}> + } { + %cst = arith.constant 0.0 : f32 + %value = wave.register %cst : vector<16xf32> + + // Complex expression test: LDS allocation with OFFSET and STRIDE symbolics + // CHECK: %[[COMPLEX_ALLOC:.*]] = memref.alloc() : memref<1152xf32, #gpu.address_space> + // CHECK: vector.store %{{.*}}, %[[COMPLEX_ALLOC]][%{{.*}}] : memref<1152xf32, #gpu.address_space>, vector<16xf32> + // CHECK: amdgpu.lds_barrier + // CHECK: %{{.*}} = vector.load %[[COMPLEX_ALLOC]][%{{.*}}] : memref<1152xf32, #gpu.address_space>, vector<16xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<8192xf32, #gpu.address_space>, vector<16xf32> + // CHECK-NOT: wave.write + + wave.write %value, %mem index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"OFFSET">, #wave.symbol<"STRIDE">] -> (BLOCK_M * WG0 + OFFSET + T0 * STRIDE + 8, STRIDE, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "complex_expr_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"OFFSET">] -> (WG0 * BLOCK_M + OFFSET)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"OFFSET">] -> (BLOCK_M + OFFSET)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"STRIDE">] -> (T0 * STRIDE + 8)>, + lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"STRIDE">] -> (STRIDE)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"OFFSET">, #wave.symbol<"STRIDE">] -> (WG0 * BLOCK_M + OFFSET + T0 * STRIDE + 8)> + > + } : vector<16xf32>, !wave.tensor<[@M] of f32, > + + return + } +} + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_no_lds_promo + func.func @test_no_lds_promo(%mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 512, M = 2048}> + } { + %cst = arith.constant 0.0 : f32 + %value1 = wave.register %cst : vector<4xf32> + %value2 = wave.register %cst : vector<4xf32> + + // First operation uses regular write (no LDS promotion) + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<4xf32> + + // Second operation has no memory access pattern - should use default vector write + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<4xf32> + + // Verify no wave.write operations remain + // CHECK-NOT: wave.write + wave.write %value1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = false, + group_id = "regular_group" + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + // Second 
operation has no memory access pattern at all - should use default vector write + // CHECK-NOT: memory_access_pattern + wave.write %value2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] : vector<4xf32>, !wave.tensor<[@M] of f32, > + + return + } + } + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // Verify the correct affine maps are generated with concrete stride values + // CHECK: #[[MAP_GLOBAL_256:.*]] = affine_map<()[s0, s1] -> (s0 * 256 + s1 * 8)> + // CHECK: #[[MAP_LDS_BASE_256:.*]] = affine_map<()[s0] -> (s0 * 256)> + // CHECK: #[[MAP_LDS_LOAD_16:.*]] = affine_map<()[s0] -> (s0 * 16)> + // CHECK: #[[MAP_GLOBAL_STORE_16:.*]] = affine_map<()[s0, s1] -> (s0 * 256 + s1 * 16)> + + func.func @verify_index_math(%mem: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 256, M = 2048}> + } { + %cst = arith.constant 0.0 : f32 + %value = wave.register %cst : vector<8xf32> + + // LDS block size should be BLOCK_M = 256 + // CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<256xf32, #gpu.address_space> + + // Original index: BLOCK_M * WG0 + T0 * 8 = 256 * WG0 + T0 * 8 + // CHECK: %[[GLOBAL_IDX:.*]] = affine.apply #[[MAP_GLOBAL_256]]()[%{{.*}}, %{{.*}}] + + // LDS base: WG0 * BLOCK_M = WG0 * 256 + // CHECK: %[[LDS_BASE:.*]] = affine.apply #[[MAP_LDS_BASE_256]]()[%{{.*}}] + + // LDS store index: (256*WG0 + T0*8) - (256*WG0) = T0*8 (local offset) + // CHECK: %[[LDS_STORE_IDX:.*]] = arith.subi %[[GLOBAL_IDX]], %[[LDS_BASE]] : index + // CHECK: vector.store %{{.*}}, %[[ALLOC]][%[[LDS_STORE_IDX]]] : memref<256xf32, #gpu.address_space>, vector<8xf32> + + // CHECK: amdgpu.lds_barrier + + // LDS load index: T0 * 16 (different stride for vectorized access) + // CHECK: %[[LDS_LOAD_IDX:.*]] = affine.apply #[[MAP_LDS_LOAD_16]]()[%{{.*}}] + // CHECK: %[[LOADED:.*]] = vector.load %[[ALLOC]][%[[LDS_LOAD_IDX]]] : memref<256xf32, #gpu.address_space>, vector<16xf32> + + // Verify constants are properly materialized during materialization + // CHECK: %{{.*}} = arith.constant 256 : index + + // Global store index: WG0 * 256 + T0 * 16 + // CHECK: %[[GLOBAL_STORE_IDX:.*]] = affine.apply #[[MAP_GLOBAL_STORE_16]]()[%{{.*}}, %{{.*}}] + // CHECK: vector.store %[[LOADED]], %{{.*}}[%[[GLOBAL_STORE_IDX]]] : memref<2048xf32, #gpu.address_space>, vector<16xf32> + + wave.write %value, %mem index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "math_verification_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 16)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (16)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 16)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + return + } + } + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_mixed_vector_sizes_same_group + func.func @test_mixed_vector_sizes_same_group(%mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 1024, M = 1024}> + } { 
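+    // Two writes with different register widths (vector<4xf32> and vector<8xf32>)
+    // share the "mixed_vector_group" group_id: the test checks that they still get a
+    // single LDS allocation, a single barrier, and a common vector<32xf32>
+    // LDS-side vectorization.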
+ %cst = arith.constant 0.0 : f32 + %value1 = wave.register %cst : vector<4xf32> // 4-element vector + %value2 = wave.register %cst : vector<8xf32> // 8-element vector (different size) + + // Single LDS allocation for the mixed-size group + // CHECK: %[[MIXED_SIZE_ALLOC:.*]] = memref.alloc() : memref<1024xf32, #gpu.address_space> + + // First operation with 4-element vector + // CHECK: vector.store %{{.*}}, %[[MIXED_SIZE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<4xf32> + wave.write %value1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 4, 4, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "mixed_vector_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 32)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (32)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 32)> + > + } : vector<4xf32>, !wave.tensor<[@M] of f32, > + + // Second operation with 8-element vector (different input vector size, but same LDS vectorization) + // CHECK: vector.store %{{.*}}, %[[MIXED_SIZE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<8xf32> + wave.write %value2, %mem2 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8 + 256, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "mixed_vector_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 32)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (32)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 32 + 256)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Single barrier for the entire mixed-size group + // CHECK: amdgpu.lds_barrier + + // LDS→register loads and register→global stores (same vectorization for both operations) + // CHECK: %{{.*}} = vector.load %[[MIXED_SIZE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<32xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<32xf32> + // CHECK: %{{.*}} = vector.load %[[MIXED_SIZE_ALLOC]][%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<32xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<1024xf32, #gpu.address_space>, vector<32xf32> + + // Verify no wave.write operations remain + // CHECK-NOT: wave.write + + return + } + } + +// ----- + +module attributes {wave.normal_form = #wave.normal_form} { + // CHECK-LABEL: @test_large_group + func.func @test_large_group(%mem1: !wave.tensor<[@M] of f32, >, %mem2: !wave.tensor<[@M] of f32, >, %mem3: !wave.tensor<[@M] of f32, >, %mem4: !wave.tensor<[@M] of f32, >, %mem5: !wave.tensor<[@M] of f32, >, %mem6: !wave.tensor<[@M] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 2048, M = 2048}> + } { + %cst = arith.constant 0.0 : f32 + %value1 = wave.register %cst : vector<8xf32> + %value2 = wave.register 
%cst : vector<8xf32> + %value3 = wave.register %cst : vector<8xf32> + %value4 = wave.register %cst : vector<8xf32> + %value5 = wave.register %cst : vector<8xf32> + %value6 = wave.register %cst : vector<8xf32> + + // Single LDS allocation for the large stress test group (6 operations) + // CHECK: %[[STRESS_ALLOC:.*]] = memref.alloc() : memref<2048xf32, #gpu.address_space> + + // All 6 register→LDS stores should happen first + // CHECK: vector.store %{{.*}}, %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<8xf32> + // CHECK: vector.store %{{.*}}, %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<8xf32> + // CHECK: vector.store %{{.*}}, %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<8xf32> + // CHECK: vector.store %{{.*}}, %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<8xf32> + // CHECK: vector.store %{{.*}}, %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<8xf32> + // CHECK: vector.store %{{.*}}, %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<8xf32> + + // Single barrier for the entire large group + // CHECK: amdgpu.lds_barrier + // Ensure only one barrier despite large number of operations + // CHECK-NOT: amdgpu.lds_barrier + + // All 6 LDS→register loads and register→global stores + // CHECK: %{{.*}} = vector.load %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: %{{.*}} = vector.load %[[STRESS_ALLOC]][%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + // CHECK: vector.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<2048xf32, #gpu.address_space>, vector<64xf32> + + // Operation 1 + wave.write %value1, %mem1 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "stress_test_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Operation 2 + wave.write %value2, %mem2 index [{ + M : 
[#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8 + 256, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "stress_test_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 256)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Operation 3 + wave.write %value3, %mem3 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8 + 512, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "stress_test_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 512)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Operation 4 + wave.write %value4, %mem4 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8 + 768, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "stress_test_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 768)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Operation 5 + wave.write %value5, %mem5 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8 + 1024, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "stress_test_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 1024)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Operation 6 + wave.write %value6, %mem6 index [{ + M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + T0 * 8 + 1280, 8, 1) + }] { + memory_access_pattern = #wave.memory_access_pattern< + use_lds_promotion = true, + group_id = "stress_test_group", + lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>, + 
lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>, + lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>, + lds_load_vector_sizes = #wave.expr_list<[] -> (64)>, + global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64 + 1280)> + > + } : vector<8xf32>, !wave.tensor<[@M] of f32, > + + // Verify no wave.write operations remain + // CHECK-NOT: wave.write + + return + } +}
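The FileCheck expectations above encode the LDS-promotion index arithmetic only implicitly. For reference, below is a minimal Python sketch of that arithmetic, modeled on the @verify_index_math case; the workgroup size (`NUM_THREADS`), the chosen `WG0` value, and the plain Python lists standing in for registers, LDS, and global memory are assumptions made purely for illustration, not part of the actual lowering.

```python
# Illustrative model (not the lowering itself) of the index arithmetic checked by
# @verify_index_math: scalar register->LDS stores at (global_index - lds_block_global_base),
# then vectorized LDS->register loads and register->global stores.

BLOCK_M = 256          # hyperparameter from the test
NUM_THREADS = 32       # assumed workgroup size: 32 threads * 8 elements = BLOCK_M
WG0 = 1                # assumed workgroup id, for illustration only

global_mem = [0.0] * 2048   # M = 2048
lds = [0.0] * BLOCK_M       # lds_block_shape = (BLOCK_M)
registers = {t: [float(t)] * 8 for t in range(NUM_THREADS)}  # vector<8xf32> per thread

# Step 1: register -> LDS, scalar-side indexing (8 elements per thread).
for t0 in range(NUM_THREADS):
    global_index = BLOCK_M * WG0 + t0 * 8          # original wave.write index
    lds_base = WG0 * BLOCK_M                       # lds_block_global_base
    lds_store_index = global_index - lds_base      # == t0 * 8, a local LDS offset
    lds[lds_store_index:lds_store_index + 8] = registers[t0]

# (amdgpu.lds_barrier separates the two phases in the real lowering.)

# Steps 2-3: vectorized LDS -> register -> global, 16 elements per thread.
for t0 in range(NUM_THREADS // 2):                 # 16 threads * 16 elements = BLOCK_M
    lds_load_index = t0 * 16                       # lds_load_indices
    vec = lds[lds_load_index:lds_load_index + 16]  # lds_load_vector_sizes = 16
    global_store_index = WG0 * BLOCK_M + t0 * 16   # global_store_indices
    global_mem[global_store_index:global_store_index + 16] = vec

# The region written to global memory equals the LDS contents, i.e. the same data
# the original scalar pattern would have written directly.
assert global_mem[WG0 * BLOCK_M:(WG0 + 1) * BLOCK_M] == lds
```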