diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td
index 3ac37b7b3..0b3f74c6d 100644
--- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td
+++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td
@@ -476,4 +476,169 @@ def WaveReadWriteBoundsAttr : AttrDef {
   }];
 }
 
+def WaveMemoryAccessPatternAttr : AttrDef {
+  let mnemonic = "memory_access_pattern";
+  let description = [{
+    This attribute specifies how memory access should be handled during lowering,
+    particularly for operations that may require LDS (Local Data Store) promotion.
+
+    LDS promotion transforms inefficient scalar memory accesses into efficient
+    vectorized accesses by using LDS as an intermediate buffer. The transformation
+    converts:
+
+    **Original Pattern**: Register -> Global Memory (scalar, potentially uncoalesced)
+
+    **LDS Promotion Pattern**:
+    1. Register -> LDS (scalar stores, same as the original pattern but to LDS)
+    2. LDS -> Register (vectorized loads from LDS)
+    3. Register -> Global Memory (vectorized stores to global memory)
+
+    ## Index Transformation Logic
+
+    ### Step 1: Register -> LDS Store
+    The original global memory indices are transformed into LDS indices by
+    subtracting the LDS block's base address in global memory coordinates:
+    ```
+    lds_store_indices = original_global_indices - lds_block_global_base
+    ```
+
+    Example:
+    - Original: `global[WG0 * BLOCK_M + T0 * 4 + offset]`
+    - LDS base: `WG0 * BLOCK_M`
+    - LDS store: `lds[(WG0 * BLOCK_M + T0 * 4 + offset) - (WG0 * BLOCK_M)] = lds[T0 * 4 + offset]`
+
+    ### Step 2: LDS -> Register Load (Vectorized)
+    Uses `lds_load_indices` and `lds_load_vector_sizes` to define the vectorized
+    access pattern within the LDS:
+    ```
+    lds_load_start       = lds_load_indices(thread_indices)
+    lds_load_vector_size = lds_load_vector_sizes(thread_indices)
+    ```
+
+    ### Step 3: Register -> Global Store (Vectorized)
+    Uses `global_store_indices` to define where each thread stores its vector back
+    to global memory:
+    ```
+    global_store_start = global_store_indices(thread_and_workgroup_indices)
+    ```
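+
+    As an illustrative walk-through (numbers chosen arbitrarily): with
+    `BLOCK_M = 256`, a 64-element vector per thread, `WG0 = 2`, and `T0 = 3`,
+    the three steps evaluate to:
+    ```
+    lds_store_index    = (2 * 256 + 3 * 4 + offset) - (2 * 256) = 3 * 4 + offset
+    lds_load_start     = 3 * 64           = 192   // 64 contiguous elements from LDS
+    global_store_start = 2 * 256 + 3 * 64 = 704   // 64 contiguous elements to global
+    ```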
+
+    ## Complete Example
+    ```mlir
+    #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "mfma_result_0",
+
+      // LDS block placement: where in global memory this LDS block maps to
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>,
+
+      // LDS allocation size
+      lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>,
+
+      // Vectorized LDS -> Register: each thread loads 64 elements starting at T0 * 64
+      lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>,
+      lds_load_vector_sizes = #wave.expr_list<[] -> (64)>,
+
+      // Vectorized Register -> Global: each thread stores 64 elements
+      global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)>
+    >
+    ```
+
+    ## Parameters:
+    - **use_lds_promotion**: Whether to enable the LDS promotion transformation
+    - **group_id**: String identifier for grouping related operations that share the same LDS allocation
+    - **lds_block_global_base**: Global memory base address that this LDS block represents (for index transformation)
+    - **lds_block_shape**: Shape/size of the LDS block allocation in elements
+    - **lds_load_indices**: Per-thread start positions for vectorized loads from LDS
+    - **lds_load_vector_sizes**: Number of elements each thread loads along each dimension
+    - **global_store_indices**: Per-thread start positions for vectorized stores to global memory
+
+    ## Verification Requirements:
+    - lds_block_global_base, lds_block_shape, lds_load_indices, and global_store_indices must all have the same rank
+    - lds_load_indices and lds_load_vector_sizes must have the same rank
+    - All index expressions must have the same rank as the original global memory tensor (checked by the consuming operation's verifier)
+    - When LDS promotion is enabled, all LDS-related parameters must be specified
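+
+    A consumer pass is expected to query the attribute through its generated
+    accessors before rewriting. A minimal, illustrative sketch (the surrounding
+    pass code and the `writeOp` variable are placeholders):
+    ```cpp
+    if (auto pattern = writeOp->getAttrOfType<wave::WaveMemoryAccessPatternAttr>(
+            "memory_access_pattern")) {
+      // Only rewrite when the full LDS promotion description is present.
+      if (pattern.hasCompleteLdsPromotionInfo()) {
+        wave::WaveExprListAttr ldsShape = pattern.getLdsBlockShape();
+        wave::WaveExprListAttr loadIndices = pattern.getLdsLoadIndices();
+        // ... materialize the LDS allocation and the three access stages ...
+      }
+    }
+    ```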
+  }];
+
+  let parameters = (ins
+    "bool":$use_lds_promotion,
+    StringRefParameter<"group identifier for shared LDS allocation">:$group_id,
+
+    // LDS block placement information
+    OptionalParameter<"::wave::WaveExprListAttr">:$lds_block_global_base,
+    OptionalParameter<"::wave::WaveExprListAttr">:$lds_block_shape,
+
+    // Vectorized access patterns - split into indices and vector sizes for simplicity
+    OptionalParameter<"::wave::WaveExprListAttr">:$lds_load_indices,
+    OptionalParameter<"::wave::WaveExprListAttr">:$lds_load_vector_sizes,
+    OptionalParameter<"::wave::WaveExprListAttr">:$global_store_indices
+  );
+
+  let assemblyFormat = [{
+    `<` `use_lds_promotion` `=` $use_lds_promotion
+    `,` `group_id` `=` $group_id
+    (`,` `lds_block_global_base` `=` $lds_block_global_base^)?
+    (`,` `lds_block_shape` `=` $lds_block_shape^)?
+    (`,` `lds_load_indices` `=` $lds_load_indices^)?
+    (`,` `lds_load_vector_sizes` `=` $lds_load_vector_sizes^)?
+    (`,` `global_store_indices` `=` $global_store_indices^)? `>`
+  }];
+
+  let genVerifyDecl = 1;
+
+  let extraClassDeclaration = [{
+    /// Check if LDS promotion is enabled.
+    bool shouldUseLdsPromotion() const { return getUseLdsPromotion(); }
+
+    /// Check if complete LDS promotion information is provided.
+    bool hasCompleteLdsPromotionInfo() const {
+      return shouldUseLdsPromotion() &&
+             getLdsBlockGlobalBase() &&
+             getLdsBlockShape() &&
+             getLdsLoadIndices() &&
+             getLdsLoadVectorSizes() &&
+             getGlobalStoreIndices();
+    }
+
+    /// Check if any LDS promotion parameters are specified.
+    bool hasAnyLdsPromotionParams() const {
+      return getLdsBlockGlobalBase() ||
+             getLdsBlockShape() ||
+             getLdsLoadIndices() ||
+             getLdsLoadVectorSizes() ||
+             getGlobalStoreIndices();
+    }
+
+    /// Get the rank/dimensionality of the LDS block.
+    unsigned getLdsBlockRank() const {
+      if (!getLdsBlockShape()) return 0;
+      return getLdsBlockShape().getRank();
+    }
+
+    /// Get the rank of the global base address expression.
+    unsigned getGlobalBaseRank() const {
+      if (!getLdsBlockGlobalBase()) return 0;
+      return getLdsBlockGlobalBase().getRank();
+    }
+
+    /// Get the rank of the LDS load indices.
+    unsigned getLdsLoadIndicesRank() const {
+      if (!getLdsLoadIndices()) return 0;
+      return getLdsLoadIndices().getRank();
+    }
+
+    /// Get the rank of the global store indices.
+    unsigned getGlobalStoreIndicesRank() const {
+      if (!getGlobalStoreIndices()) return 0;
+      return getGlobalStoreIndices().getRank();
+    }
+  }];
+}
+
 #endif // WATER_DIALECT_WAVE_WAVEATTRS
diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp
index e7dd1b5d4..a2284364b 100644
--- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp
+++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp
@@ -456,6 +456,127 @@ DeviceConstraintAttr::verify(function_ref emitError,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// WaveMemoryAccessPatternAttr
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+WaveMemoryAccessPatternAttr::verify(function_ref<InFlightDiagnostic()> emitError,
+                                    bool use_lds_promotion, StringRef group_id,
+                                    WaveExprListAttr lds_block_global_base,
+                                    WaveExprListAttr lds_block_shape,
+                                    WaveExprListAttr lds_load_indices,
+                                    WaveExprListAttr lds_load_vector_sizes,
+                                    WaveExprListAttr global_store_indices) {
+  // Validate that group_id is not empty.
+  if (group_id.empty()) {
+    return emitError() << "group_id cannot be empty";
+  }
+
+  // When LDS promotion is disabled, no LDS-related parameters should be
+  // specified.
+  if (!use_lds_promotion) {
+    if (lds_block_global_base || lds_block_shape || lds_load_indices ||
+        lds_load_vector_sizes || global_store_indices) {
+      return emitError() << "LDS promotion parameters should not be specified "
+                            "when use_lds_promotion=false";
+    }
+    return success();
+  }
+
+  // When LDS promotion is enabled, validate completeness and consistency.
+  bool hasLdsBase = static_cast<bool>(lds_block_global_base);
+  bool hasLdsShape = static_cast<bool>(lds_block_shape);
+  bool hasLdsLoadIndices = static_cast<bool>(lds_load_indices);
+  bool hasLdsLoadVectorSizes = static_cast<bool>(lds_load_vector_sizes);
+  bool hasGlobalStoreIndices = static_cast<bool>(global_store_indices);
+
+  // Check for partial specification - either all or none should be provided.
+  if (hasLdsBase || hasLdsShape || hasLdsLoadIndices || hasLdsLoadVectorSizes ||
+      hasGlobalStoreIndices) {
+    if (!hasLdsBase || !hasLdsShape || !hasLdsLoadIndices ||
+        !hasLdsLoadVectorSizes || !hasGlobalStoreIndices) {
+      return emitError()
+             << "when LDS promotion is enabled, all LDS parameters must be "
+                "specified: lds_block_global_base, lds_block_shape, "
+                "lds_load_indices, lds_load_vector_sizes, global_store_indices";
+    }
+  }
+
+  // If all LDS parameters are provided, perform detailed validation.
+  if (hasLdsBase && hasLdsShape && hasLdsLoadIndices && hasLdsLoadVectorSizes &&
+      hasGlobalStoreIndices) {
+
+    // Validate that lds_block_global_base and lds_block_shape have consistent
+    // ranks.
+    unsigned ldsBaseRank = lds_block_global_base.getRank();
+    unsigned ldsShapeRank = lds_block_shape.getRank();
+
+    if (ldsBaseRank != ldsShapeRank) {
+      return emitError() << "lds_block_global_base rank (" << ldsBaseRank
+                         << ") must match lds_block_shape rank ("
+                         << ldsShapeRank << ")";
+    }
+
+    // Validate that load indices and vector sizes have consistent ranks.
+    unsigned ldsLoadIndicesRank = lds_load_indices.getRank();
+    unsigned ldsLoadVectorSizesRank = lds_load_vector_sizes.getRank();
+    unsigned globalStoreIndicesRank = global_store_indices.getRank();
+
+    if (ldsLoadIndicesRank != ldsLoadVectorSizesRank) {
+      return emitError() << "lds_load_indices rank (" << ldsLoadIndicesRank
+                         << ") must match lds_load_vector_sizes rank ("
+                         << ldsLoadVectorSizesRank << ")";
+    }
+
+    if (ldsBaseRank != ldsLoadIndicesRank) {
+      return emitError() << "LDS block rank (" << ldsBaseRank
+                         << ") must match LDS load indices rank ("
+                         << ldsLoadIndicesRank << ")";
+    }
+
+    if (ldsBaseRank != globalStoreIndicesRank) {
+      return emitError() << "LDS block rank (" << ldsBaseRank
+                         << ") must match global store indices rank ("
+                         << globalStoreIndicesRank << ")";
+    }
+
+    // Validate that all symbols are WaveSymbolAttr or WaveIndexSymbolAttr.
+    if (!llvm::all_of(lds_block_global_base.getSymbols(),
+                      llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
+      return emitError() << "lds_block_global_base must only contain "
+                            "WaveSymbolAttr or WaveIndexSymbolAttr";
+    }
+
+    if (!llvm::all_of(lds_block_shape.getSymbols(),
+                      llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
+      return emitError() << "lds_block_shape must only contain WaveSymbolAttr "
+                            "or WaveIndexSymbolAttr";
+    }
+
+    if (!llvm::all_of(lds_load_indices.getSymbols(),
+                      llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
+      return emitError() << "lds_load_indices must only contain WaveSymbolAttr "
+                            "or WaveIndexSymbolAttr";
+    }
+
+    if (!llvm::all_of(lds_load_vector_sizes.getSymbols(),
+                      llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
+      return emitError() << "lds_load_vector_sizes must only contain "
+                            "WaveSymbolAttr or WaveIndexSymbolAttr";
+    }
+
+    if (!llvm::all_of(global_store_indices.getSymbols(),
+                      llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr>)) {
+      return emitError() << "global_store_indices must only contain "
+                            "WaveSymbolAttr or WaveIndexSymbolAttr";
+    }
+
+    // Validate that mappings have at least one dimension.
+    if (ldsBaseRank == 0) {
+      return emitError() << "LDS block must have at least one dimension";
+    }
+
+    // Note: we cannot validate that the ranks match the original global memory
+    // tensor rank here because attribute verification does not have access to
+    // the WriteOp's memory operand. This validation should be performed in the
+    // WriteOp's verifier, where both the attribute and the memory operand type
+    // are available.
+    //
+    // Additionally, data coverage verification (ensuring that the collective
+    // workgroup access pattern covers exactly the same elements before and
+    // after LDS promotion) should be performed in the WriteOp verifier, where
+    // access to the original index mapping is available.
+  }
+
+  return success();
+}
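+
+// Illustrative sketch of the rank check that the note above defers to the
+// WriteOp verifier. It is parameterized on the memory operand rank so it does
+// not depend on how the WriteOp exposes its operand; it is not wired into any
+// verifier here.
+static LogicalResult
+verifyMemoryAccessPatternRank(WaveMemoryAccessPatternAttr pattern,
+                              unsigned memoryRank,
+                              function_ref<InFlightDiagnostic()> emitError) {
+  if (!pattern.hasCompleteLdsPromotionInfo())
+    return success();
+  if (pattern.getLdsBlockRank() != memoryRank)
+    return emitError() << "LDS block rank (" << pattern.getLdsBlockRank()
+                       << ") must match memory operand rank (" << memoryRank
+                       << ")";
+  return success();
+}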
+
 void wave::WaveDialect::registerAttributes() {
   addAttributes<
 #define GET_ATTRDEF_LIST
diff --git a/water/test/Dialect/Wave/attr-memory-access-pattern-invalid.mlir b/water/test/Dialect/Wave/attr-memory-access-pattern-invalid.mlir
new file mode 100644
index 000000000..06c264bdc
--- /dev/null
+++ b/water/test/Dialect/Wave/attr-memory-access-pattern-invalid.mlir
@@ -0,0 +1,84 @@
+// RUN: water-opt %s -split-input-file -verify-diagnostics
+
+// Test: empty group_id should fail
+func.func @memory_access_pattern_empty_group_id(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  wave.write %value, %mem {
+    // expected-error @+1 {{group_id cannot be empty}}
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = false,
+      group_id = ""
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
+
+// -----
+
+// Test: LDS parameters specified when use_lds_promotion=false should fail
+func.func @memory_access_pattern_lds_params_when_disabled(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  wave.write %value, %mem {
+    // expected-error @+1 {{LDS promotion parameters should not be specified when use_lds_promotion=false}}
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = false,
+      group_id = "test",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
+
+// -----
+
+// Test: Partial LDS specification when use_lds_promotion=true should fail
+func.func @memory_access_pattern_partial_lds_specification(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  wave.write %value, %mem {
+    // expected-error @+1 {{when LDS promotion is enabled, all LDS parameters must be specified: lds_block_global_base, lds_block_shape, lds_load_indices, lds_load_vector_sizes, global_store_indices}}
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "test",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
+
+// -----
+
+// Test: Mismatched ranks between lds_block_global_base and lds_block_shape should fail
+func.func @memory_access_pattern_mismatched_base_shape_ranks(%value: !wave.tensor>, %mem: !wave.tensor<[@M, @N] of f32, >) {
+  wave.write %value, %mem {
+    // expected-error @+1 {{lds_block_global_base rank (1) must match lds_block_shape rank (2)}}
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "test",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)>,
+      lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>,
+      lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)>,
+      lds_load_vector_sizes = #wave.expr_list<[] -> (64)>,
+      global_store_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M, @N] of f32, >
+  return
+}
+
+// -----
+
+// Test: Mismatched ranks between lds_load_indices and lds_load_vector_sizes should fail
+func.func @memory_access_pattern_mismatched_lds_load_ranks(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  wave.write %value, %mem {
+    // expected-error @+1 {{lds_load_indices rank (1) must match lds_load_vector_sizes rank (2)}}
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "test",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol] -> (WG0)>,
+      lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>,
+      lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)>,
+      lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_M">, #wave.symbol<"VEC_N">] -> (VEC_M, VEC_N)>,
+      global_store_indices = #wave.expr_list<[#wave.index_symbol] -> (T0)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
diff --git a/water/test/Dialect/Wave/attr-memory-access-pattern.mlir b/water/test/Dialect/Wave/attr-memory-access-pattern.mlir
new file mode 100644
index 000000000..0601e63cf
--- /dev/null
+++ b/water/test/Dialect/Wave/attr-memory-access-pattern.mlir
@@ -0,0 +1,106 @@
+// RUN: water-opt %s -split-input-file | water-opt | FileCheck %s
+// RUN: water-opt %s -split-input-file --mlir-print-op-generic | water-opt | FileCheck %s
+
+// Test basic memory access pattern parsing without LDS promotion
+// CHECK-LABEL: @memory_access_pattern_basic_no_promotion
+func.func @memory_access_pattern_basic_no_promotion(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  // CHECK: wave.write
+  // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern<
+  // CHECK-SAME: use_lds_promotion = false,
+  // CHECK-SAME: group_id = "basic_test"
+  // CHECK-SAME: >
+  wave.write %value, %mem {
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = false,
+      group_id = "basic_test"
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
+
+// -----
+
+// Test memory access pattern with complete LDS promotion - 1D case with symbolic vector size
+// CHECK-LABEL: @memory_access_pattern_1d_lds_promotion
+func.func @memory_access_pattern_1d_lds_promotion(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  // CHECK: wave.write
+  // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern<
+  // CHECK-SAME: use_lds_promotion = true,
+  // CHECK-SAME: group_id = "lds_1d",
+  // CHECK-SAME: lds_block_global_base = <[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>,
+  // CHECK-SAME: lds_block_shape = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>,
+  // CHECK-SAME: lds_load_indices = <[#wave.index_symbol, #wave.symbol<"VEC_SIZE">] -> (T0 * VEC_SIZE)>,
+  // CHECK-SAME: lds_load_vector_sizes = <[#wave.symbol<"VEC_SIZE">] -> (VEC_SIZE)>,
+  // CHECK-SAME: global_store_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"VEC_SIZE">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE)>
+  // CHECK-SAME: >
+  wave.write %value, %mem {
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "lds_1d",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>,
+      lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>,
+      lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"VEC_SIZE">] -> (T0 * VEC_SIZE)>,
+      lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_SIZE">] -> (VEC_SIZE)>,
+      global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"VEC_SIZE">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
+
+// -----
+
+// Test memory access pattern with constant vector size - 1D case
+// CHECK-LABEL: @memory_access_pattern_1d_constant_vector_size
+func.func @memory_access_pattern_1d_constant_vector_size(%value: !wave.tensor>, %mem: !wave.tensor<[@M] of f32, >) {
+  // CHECK: wave.write
+  // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern<
+  // CHECK-SAME: use_lds_promotion = true,
+  // CHECK-SAME: group_id = "lds_1d_const",
+  // CHECK-SAME: lds_block_global_base = <[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>,
+  // CHECK-SAME: lds_block_shape = <[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>,
+  // CHECK-SAME: lds_load_indices = <[#wave.index_symbol] -> (T0 * 64)>,
+  // CHECK-SAME: lds_load_vector_sizes = <[] -> (64)>,
+  // CHECK-SAME: global_store_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)>
+  // CHECK-SAME: >
+  wave.write %value, %mem {
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "lds_1d_const",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M)>,
+      lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">] -> (BLOCK_M)>,
+      lds_load_indices = #wave.expr_list<[#wave.index_symbol] -> (T0 * 64)>,
+      lds_load_vector_sizes = #wave.expr_list<[] -> (64)>,
+      global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + T0 * 64)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M] of f32, >
+  return
+}
+
+// -----
+
+// Test memory access pattern with complete LDS promotion - 2D case
+// CHECK-LABEL: @memory_access_pattern_2d_lds_promotion
+func.func @memory_access_pattern_2d_lds_promotion(%value: !wave.tensor>, %mem: !wave.tensor<[@M, @N] of f32, >) {
+  // CHECK: wave.write
+  // CHECK-SAME: memory_access_pattern = #wave.memory_access_pattern<
+  // CHECK-SAME: use_lds_promotion = true,
+  // CHECK-SAME: group_id = "lds_2d",
+  // CHECK-SAME: lds_block_global_base = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (WG0 * BLOCK_M, WG1 * BLOCK_N)>,
+  // CHECK-SAME: lds_block_shape = <[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>,
+  // CHECK-SAME: lds_load_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (T0 * VEC_SIZE_M, T1 * VEC_SIZE_N)>,
+  // CHECK-SAME: lds_load_vector_sizes = <[#wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (VEC_SIZE_M, VEC_SIZE_N)>,
+  // CHECK-SAME: global_store_indices = <[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE_M, WG1 * BLOCK_N + T1 * VEC_SIZE_N)>
+  // CHECK-SAME: >
+  wave.write %value, %mem {
+    memory_access_pattern = #wave.memory_access_pattern<
+      use_lds_promotion = true,
+      group_id = "lds_2d",
+      lds_block_global_base = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (WG0 * BLOCK_M, WG1 * BLOCK_N)>,
+      lds_block_shape = #wave.expr_list<[#wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">] -> (BLOCK_M, BLOCK_N)>,
+      lds_load_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (T0 * VEC_SIZE_M, T1 * VEC_SIZE_N)>,
+      lds_load_vector_sizes = #wave.expr_list<[#wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (VEC_SIZE_M, VEC_SIZE_N)>,
+      global_store_indices = #wave.expr_list<[#wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.symbol<"BLOCK_N">, #wave.symbol<"VEC_SIZE_M">, #wave.symbol<"VEC_SIZE_N">] -> (WG0 * BLOCK_M + T0 * VEC_SIZE_M, WG1 * BLOCK_N + T1 * VEC_SIZE_N)>
+    >
+  } : !wave.tensor>, !wave.tensor<[@M, @N] of f32, >
+  return
+}