diff --git a/water/include/water/Dialect/Wave/IR/WaveAttrs.td b/water/include/water/Dialect/Wave/IR/WaveAttrs.td index 3b915f7dd..071f0b8b8 100644 --- a/water/include/water/Dialect/Wave/IR/WaveAttrs.td +++ b/water/include/water/Dialect/Wave/IR/WaveAttrs.td @@ -382,22 +382,34 @@ def WaveIterSymbolAttr : AttrDef { def WaveIndexMappingAttr : AttrDef { let mnemonic = "index_mapping"; let description = [{ - An affine map with named symbols for Wave indexing expressions. - - This attribute preserves meaningful symbol names (e.g., WG0, BLOCK_M, T0) - while storing affine maps internally for start, step, and stride. The - symbol_names array corresponds 1:1 to the symbols in the affine - expressions, where s0 maps to symbol_names[0], s1 to symbol_names[1], etc. - - Custom syntax: [symbol_name attributes] -> (start_expr, step_expr, stride_expr) - Example: [wave.index_symbol, wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + 42, 1, BLOCK_M) + An attribute capturing how symbols indexing a tensor are distributed + across concurrent threads, workgroups, devices, and loop tiles. + + In this attribute, the start expression corresponds to the offset from + the start of the tensor dimension, the step corresponds to the number + of consecutive elements taken (i.e., corresponding to the `vector.step` + operation to index into them) and the stride corresponds to the number + of elements to step over in this dimension to get to the next block + of consecutive elements that are accessed by the current thread. The + start expression is an affine expression where positional affine symbols + are mapped to specific wave symbols. It will typically include thread and + workgroup index symbols (the absence thereof indicates that all threads or + all workgroups access the same element, i.e., a broadcast), as well as + loop iteration symbols for operations nested in loops. Step and stride + are constant values. 
+ + The symbols array corresponds 1:1 to the symbols in the affine map, where + s0 maps to symbols[0], s1 to symbols[1], etc. + + Custom syntax: `[symbol_name attributes] -> (start_expr, step, stride)` + Example: `[wave.index_symbol, wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + 42, 1, 4)` }]; let parameters = (ins ArrayRefParameter<"::mlir::Attribute">:$symbols, "::mlir::AffineMap":$start, - "::mlir::AffineMap":$step, - "::mlir::AffineMap":$stride + "uint64_t":$step, + "uint64_t":$stride ); let hasCustomAssemblyFormat = 1; diff --git a/water/include/water/c/Dialects.h b/water/include/water/c/Dialects.h index f6a53d755..e1992d39f 100644 --- a/water/include/water/c/Dialects.h +++ b/water/include/water/c/Dialects.h @@ -100,13 +100,11 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAWaveIndexMappingAttr(MlirAttribute attr); /// Creates a new WaveIndexMappingAttr with the given start, step and stride -/// maps that are interpreted as accepting the symbols provided in the -/// `symbolNames` list. The list must have as many entries as maps have symbols, -/// and all maps must have the same number of symbols and zero dimensions. The -/// list is expected to only contain WaveSymbolAttr instances. +/// (start is an affine map; step and stride are constants). `symbolNames` has +/// `numSymbols` entries, expected to only contain WaveSymbolAttr instances. MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexMappingAttrGet( - MlirContext mlirCtx, MlirAttribute *symbolNames, MlirAffineMap start, - MlirAffineMap step, MlirAffineMap stride); + MlirContext mlirCtx, MlirAttribute *symbolNames, intptr_t numSymbols, + MlirAffineMap start, uint64_t step, uint64_t stride); /// Returns the typeID of a WaveIndexMappingAttr. MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexMappingAttrGetTypeID(); @@ -116,11 +114,10 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirWaveIndexMappingAttrGetStart(MlirAttribute attr); /// Get the step from a WaveIndexMappingAttr. 
-MLIR_CAPI_EXPORTED MlirAffineMap -mlirWaveIndexMappingAttrGetStep(MlirAttribute attr); +MLIR_CAPI_EXPORTED uint64_t mlirWaveIndexMappingAttrGetStep(MlirAttribute attr); /// Get the stride from a WaveIndexMappingAttr. -MLIR_CAPI_EXPORTED MlirAffineMap +MLIR_CAPI_EXPORTED uint64_t mlirWaveIndexMappingAttrGetStride(MlirAttribute attr); /// Get the number of (input) symbols. diff --git a/water/lib/CAPI/Dialects.cpp b/water/lib/CAPI/Dialects.cpp index 458db6fef..9eb45fbc5 100644 --- a/water/lib/CAPI/Dialects.cpp +++ b/water/lib/CAPI/Dialects.cpp @@ -108,17 +108,12 @@ bool mlirAttributeIsAWaveIndexMappingAttr(MlirAttribute attr) { MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx, MlirAttribute *symbolNames, - MlirAffineMap start, - MlirAffineMap step, - MlirAffineMap stride) { + intptr_t numSymbols, + MlirAffineMap start, uint64_t step, + uint64_t stride) { mlir::MLIRContext *ctx = unwrap(mlirCtx); // Convert C array of MlirAttribute to vector of WaveSymbolAttr. - unsigned numSymbols = mlirAffineMapGetNumSymbols(start); - assert(mlirAffineMapGetNumSymbols(step) == numSymbols && - "expected start and step to have the same number of dimensions"); - assert(mlirAffineMapGetNumSymbols(stride) == numSymbols && - "expected start and stride to have the same number of dimensions"); llvm::SmallVector symbolAttrs = llvm::map_to_vector( llvm::make_range(symbolNames, symbolNames + numSymbols), [](MlirAttribute attr) { return unwrap(attr); }); @@ -131,7 +126,7 @@ MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx, "WaveIndexSymbolAttr attributes"); return wrap(wave::WaveIndexMappingAttr::get(ctx, symbolAttrs, unwrap(start), - unwrap(step), unwrap(stride))); + step, stride)); } MlirTypeID mlirWaveIndexMappingAttrGetTypeID() { @@ -142,12 +137,12 @@ MlirAffineMap mlirWaveIndexMappingAttrGetStart(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getStart()); } -MlirAffineMap mlirWaveIndexMappingAttrGetStep(MlirAttribute attr) { - return 
wrap(llvm::cast(unwrap(attr)).getStep()); +uint64_t mlirWaveIndexMappingAttrGetStep(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getStep(); } -MlirAffineMap mlirWaveIndexMappingAttrGetStride(MlirAttribute attr) { - return wrap(llvm::cast(unwrap(attr)).getStride()); +uint64_t mlirWaveIndexMappingAttrGetStride(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getStride(); } intptr_t mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr) { diff --git a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp index 6301167ad..23cc6da00 100644 --- a/water/lib/Dialect/Wave/IR/WaveAttrs.cpp +++ b/water/lib/Dialect/Wave/IR/WaveAttrs.cpp @@ -147,29 +147,25 @@ Attribute WaveIndexMappingAttr::parse(AsmParser &parser, Type type) { } AffineExpr startExpr; - AffineExpr stepExpr; - AffineExpr strideExpr; - if (failed(parseExprWithNames(symbolNames, startExpr, parser)) || - parser.parseComma() || - failed(parseExprWithNames(symbolNames, stepExpr, parser)) || - parser.parseComma() || - failed(parseExprWithNames(symbolNames, strideExpr, parser)) || - parser.parseRParen()) { - parser.emitError( - parser.getCurrentLocation(), - "expected three affine expressions for '(start, step, stride)'"); + uint64_t step, stride; + llvm::SMLoc currentLocation = parser.getCurrentLocation(); + if (failed(parseExprWithNames(symbolNames, startExpr, parser))) { + parser.emitError(currentLocation, "expected affine expression"); + return {}; + } + if (failed(parser.parseComma()) || failed(parser.parseInteger(step)) || + failed(parser.parseComma()) || failed(parser.parseInteger(stride)) || + failed(parser.parseRParen())) { return {}; } - // Build maps - auto startMap = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/symbolNames.size(), startExpr, context); - auto stepMap = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/symbolNames.size(), stepExpr, context); - auto strideMap = AffineMap::get( - /*numDims=*/0, /*numSymbols=*/symbolNames.size(), strideExpr, 
context); + auto startMap = startExpr + ? AffineMap::get( + /*numDims=*/0, /*numSymbols=*/symbolNames.size(), + startExpr, context) + : AffineMap(); - return get(context, symbolNameAttrs, startMap, stepMap, strideMap); + return get(context, symbolNameAttrs, startMap, step, stride); } void WaveIndexMappingAttr::print(AsmPrinter &printer) const { @@ -196,16 +192,14 @@ void WaveIndexMappingAttr::print(AsmPrinter &printer) const { } // All three maps share the same symbol set and order. std::string startStr = stringifyWithNames(getStart(), names); - std::string stepStr = stringifyWithNames(getStep(), names); - std::string strideStr = stringifyWithNames(getStride(), names); - printer << "(" << startStr << ", " << stepStr << ", " << strideStr << ")"; + printer << "(" << startStr << ", " << getStep() << ", " << getStride() << ")"; } LogicalResult WaveIndexMappingAttr::verify(function_ref emitError, ArrayRef symbols, AffineMap start, - AffineMap step, AffineMap stride) { + uint64_t step, uint64_t stride) { if (!llvm::all_of(symbols, llvm::IsaPred)) { return emitError() << "expected all symbols to be a WaveSymbolAttr, " diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index d201002f4..457f70068 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -50,35 +50,20 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { << val; auto mapping = cast(val); - auto checkNoDims = [&](AffineMap map, StringRef which) -> LogicalResult { - if (map.getNumDims() != 0) - return op->emitError( - "wave indexing " + which + - " map should have no dimensions, only symbols, got ") - << map.getNumDims() << " dimensions for symbol " - << named.getName(); - return success(); - }; - - AffineMap startMap = mapping.getStart(); - AffineMap stepMap = mapping.getStep(); - AffineMap strideMap = mapping.getStride(); - if (failed(checkNoDims(startMap, "start")) || - 
failed(checkNoDims(stepMap, "step")) || - failed(checkNoDims(strideMap, "stride"))) - return failure(); + if (mapping.getStart().getNumDims() != 0) { + return op->emitError("wave indexing start map should have no " + "dimensions, only symbols, got ") + << mapping.getStart().getNumDims() << " dimensions for symbol " + << named.getName(); + } unsigned declared = mapping.getSymbols().size(); - if (startMap.getNumSymbols() != declared || - stepMap.getNumSymbols() != declared || - strideMap.getNumSymbols() != declared) { + if (mapping.getStart().getNumSymbols() != declared) { return op->emitError( "inconsistent symbol count between symbol_names and " "affine maps for index symbol '") - << named.getName() << "' (expected " << declared - << ", got start=" << startMap.getNumSymbols() - << ", step=" << stepMap.getNumSymbols() - << ", stride=" << strideMap.getNumSymbols() << ")"; + << named.getName() << "' (expected " << declared << ", got " + << mapping.getStart().getNumSymbols() << ")"; } for (auto symbol : mapping.getSymbols()) { @@ -691,13 +676,9 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( auto isThreadDependent = [&](wave::WaveIndexMappingAttr val) -> bool { llvm::SmallVector threadLikeSymbolPositions; getThreadLikeSymbolPositions(val.getSymbols(), threadLikeSymbolPositions); - return llvm::any_of( - llvm::ArrayRef{val.getStart(), val.getStep(), val.getStride()}, - [&](mlir::AffineMap map) { - return llvm::any_of(threadLikeSymbolPositions, [&](unsigned pos) { - return map.isFunctionOfSymbol(pos); - }); - }); + return llvm::any_of(threadLikeSymbolPositions, [&](unsigned pos) { + return val.getStart().isFunctionOfSymbol(pos); + }); }; // If both are thread-dependent or thread-independent, the only acceptable @@ -728,18 +709,6 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( alignMapSymbols(threadIndependentMapping.getStart(), threadIndependentSymbols, allSymbols); - mlir::AffineMap threadDependentStep = alignMapSymbols( - 
threadDependentMapping.getStep(), threadDependentSymbols, allSymbols); - mlir::AffineMap threadIndependentStep = - alignMapSymbols(threadIndependentMapping.getStep(), - threadIndependentSymbols, allSymbols); - - mlir::AffineMap threadDependentStride = alignMapSymbols( - threadDependentMapping.getStride(), threadDependentSymbols, allSymbols); - mlir::AffineMap threadIndependentStride = - alignMapSymbols(threadIndependentMapping.getStride(), - threadIndependentSymbols, allSymbols); - // Subtract the thread-independent from thread-dependent for each. auto subtractMaps = [&](mlir::AffineMap a, mlir::AffineMap b) -> mlir::AffineMap { @@ -754,10 +723,6 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( }; mlir::AffineMap newStart = subtractMaps(threadDependentStart, threadIndependentStart); - mlir::AffineMap newStep = - subtractMaps(threadDependentStep, threadIndependentStep); - mlir::AffineMap newStride = - subtractMaps(threadDependentStride, threadIndependentStride); llvm::SmallVector threadLikeSymbolPositions; getThreadLikeSymbolPositions(allSymbols, threadLikeSymbolPositions); @@ -775,8 +740,7 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join( return !walkResult.wasInterrupted(); }; - if (!isOnlyThreadDependent(newStart) || !isOnlyThreadDependent(newStep) || - !isOnlyThreadDependent(newStride)) + if (!isOnlyThreadDependent(newStart)) return top(); result[namedAttr.getName()] = threadDependentMapping; diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index 88d3fdf89..23d0175c0 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -597,8 +597,8 @@ struct MmaSingleIndexExprBuilder { // Set the parameter of the index expression for the currently selected // dimension. 
MmaSingleIndexExprBuilder &offset(mlir::AffineExpr expr); - MmaSingleIndexExprBuilder &size(int64_t value); - MmaSingleIndexExprBuilder &stride(int64_t value); + MmaSingleIndexExprBuilder &size(uint64_t value); + MmaSingleIndexExprBuilder &stride(uint64_t value); // Select the dimension. MmaSingleIndexExprBuilder &m(); @@ -609,7 +609,9 @@ struct MmaSingleIndexExprBuilder { void populate(llvm::SmallVectorImpl &attributes) const; MmaIndexingExprBuilder &parent; - mlir::AffineExpr offsetExpr, sizeExpr, strideExpr; + mlir::AffineExpr offsetExpr; + std::optional sizeValue; + std::optional strideValue; bool enabled; }; @@ -657,16 +659,18 @@ struct MmaIndexingExprBuilder { void populate(llvm::SmallVectorImpl &attributes) const { mlir::MLIRContext *ctx = getAnySymbolContext(mSymbol, nSymbol, kSymbol); - auto buildMap = [&](mlir::AffineExpr expr) { - assert(expr && - "expected offset/size/stride to be set up for all symbols"); - return mlir::AffineMap::get(/*dimCount=*/0, - /*symbolCount=*/symbols.size(), expr, ctx); - }; auto buildOne = [&](const MmaSingleIndexExprBuilder &builder) { + assert(builder.offsetExpr && + "expected offset to be set up for all symbols"); + assert(builder.sizeValue && "expected size to be set up for all symbols"); + assert(builder.strideValue && + "expected stride to be set up for all symbols"); return wave::WaveIndexMappingAttr::get( - ctx, symbols, buildMap(builder.offsetExpr), - buildMap(builder.sizeExpr), buildMap(builder.strideExpr)); + ctx, symbols, + mlir::AffineMap::get(/*dimCount=*/0, + /*symbolCount=*/symbols.size(), + builder.offsetExpr, ctx), + *builder.sizeValue, *builder.strideValue); }; if (mSymbol) @@ -691,19 +695,19 @@ MmaSingleIndexExprBuilder::offset(mlir::AffineExpr expr) { return *this; } -MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::size(int64_t value) { +MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::size(uint64_t value) { if (!enabled) return *this; - assert(!sizeExpr && "expected size to be set only once"); 
- sizeExpr = mlir::getAffineConstantExpr(value, offsetExpr.getContext()); + assert(!sizeValue && "expected size to be set only once"); + sizeValue = value; return *this; } -MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::stride(int64_t value) { +MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::stride(uint64_t value) { if (!enabled) return *this; - assert(!strideExpr && "expected stride to be set only once"); - strideExpr = mlir::getAffineConstantExpr(value, offsetExpr.getContext()); + assert(!strideValue && "expected stride to be set only once"); + strideValue = value; return *this; } @@ -922,9 +926,7 @@ applyConstraint(ConstraintAttrT constraint, /*dimCount=*/0, symbols.size(), symbolExpr * constraint.getTileSize().getMap().getResult(0)); if (baseMapping == nullptr) - return wave::WaveIndexMappingAttr::get( - context, symbols, map, mlir::AffineMap::getConstantMap(1, context), - mlir::AffineMap::getConstantMap(1, context)); + return wave::WaveIndexMappingAttr::get(context, symbols, map, 1, 1); llvm::SmallVector allSymbols; wave::aggregateAllSymbols( @@ -934,15 +936,11 @@ applyConstraint(ConstraintAttrT constraint, mlir::AffineMap baseStart = alignMapSymbols( baseMapping.getStart(), baseMapping.getSymbols(), allSymbols); - mlir::AffineMap baseStep = alignMapSymbols( - baseMapping.getStep(), baseMapping.getSymbols(), allSymbols); - mlir::AffineMap baseStride = alignMapSymbols( - baseMapping.getStride(), baseMapping.getSymbols(), allSymbols); map = alignMapSymbols(map, symbols, allSymbols); map = mlir::AffineMap::get(/*dimCount=*/0, allSymbols.size(), baseStart.getResult(0) + map.getResult(0)); - return wave::WaveIndexMappingAttr::get(context, allSymbols, map, baseStep, - baseStride); + return wave::WaveIndexMappingAttr::get( + context, allSymbols, map, baseMapping.getStep(), baseMapping.getStride()); } // Create an index mapping induced by the given constraint. 
Combine it with the diff --git a/water/lib/Dialect/Wave/IR/WaveUtils.cpp b/water/lib/Dialect/Wave/IR/WaveUtils.cpp index f45babff3..d7102a2a8 100644 --- a/water/lib/Dialect/Wave/IR/WaveUtils.cpp +++ b/water/lib/Dialect/Wave/IR/WaveUtils.cpp @@ -28,13 +28,7 @@ wave::getUncollapsedVectorShape(llvm::ArrayRef shape, Attribute entry = indexDict.get(symbol.getName()); assert(entry && "expected dictionary to contain indices for the shape"); auto mapAttr = cast(entry); - std::optional> folded = - wave::evaluateMapWithHyperparams(mapAttr.getStep(), - mapAttr.getSymbols(), hyper); - if (!folded) - return ShapedType::kDynamic; - assert(folded->size() == 1 && "expected single-result map"); - return (*folded)[0]; + return static_cast(mapAttr.getStep()); }); } diff --git a/water/python/WaterExtensionNanobind.cpp b/water/python/WaterExtensionNanobind.cpp index 814103117..a21c7054a 100644 --- a/water/python/WaterExtensionNanobind.cpp +++ b/water/python/WaterExtensionNanobind.cpp @@ -128,27 +128,20 @@ NB_MODULE(_waterDialects, m) { .def_classmethod( "get", [](const nb::object &cls, std::vector &symbols, - MlirAffineMap start, MlirAffineMap step, MlirAffineMap stride, + MlirAffineMap start, uint64_t step, uint64_t stride, // MlirContext should always come last to allow for being // automatically deduced from context. 
MlirContext context) { intptr_t numSymbols = symbols.size(); - intptr_t numResults = mlirAffineMapGetNumResults(start); - for (MlirAffineMap map : {start, step, stride}) { - if (numSymbols != mlirAffineMapGetNumSymbols(map)) { - throw nb::value_error("Expected symbols, start, step and " - "stride to be co-indexed."); - } - if (mlirAffineMapGetNumDims(map) != 0) { - throw nb::value_error("Maps should not involve dimensions."); - } - if (numResults != mlirAffineMapGetNumResults(map)) { - throw nb::value_error( - "Maps should have the same number of results."); - } + if (numSymbols != mlirAffineMapGetNumSymbols(start)) { + throw nb::value_error( + "Expected symbols and start to be co-indexed."); + } + if (mlirAffineMapGetNumDims(start) != 0) { + throw nb::value_error("Start map should not involve dimensions."); } - return cls(mlirWaveIndexMappingAttrGet(context, symbols.data(), - start, step, stride)); + return cls(mlirWaveIndexMappingAttrGet( + context, symbols.data(), numSymbols, start, step, stride)); }, nb::arg("cls"), nb::arg("symbols"), nb::arg("start"), nb::arg("step"), nb::arg("stride"), nb::arg("context") = nb::none(), diff --git a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir index 78fb4a6e3..6c0c7d73e 100644 --- a/water/test/Dialect/Wave/lower-wave-to-mlir.mlir +++ b/water/test/Dialect/Wave/lower-wave-to-mlir.mlir @@ -370,7 +370,7 @@ func.func @lower_read(%mem: !wave.tensor<[@M, @N] of f16, >) attributes // CHECK: %[[TIDX_Y:.*]] = gpu.thread_id y // CHECK: %[[BIDX_Y:.*]] = gpu.block_id y // CHECK: %[[COL:.*]] = affine.apply affine_map<()[s0, s1] -> (s1 * 64 + s0 * 8)>()[%[[TIDX_Y]], %[[BIDX_Y]]] - N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 8) * T1, BLOCK_N ceildiv 8, 1)}] + N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 8) * T1, 8, 1)}] : (!wave.tensor<[@M, @N] of f16, >) -> 
vector<8xf16> // CHECK: %[[VEC:.+]] = vector.load {{.*}}[%[[ROW]], %[[COL]]] : memref<{{.*}}xf16{{.*}}>, vector<8xf16> @@ -478,7 +478,7 @@ func.func @lower_write(%mem: !wave.tensor<[@M, @N] of f16, >) attributes // CHECK: %[[TIDX_Y:.*]] = gpu.thread_id y // CHECK: %[[BIDX_Y:.*]] = gpu.block_id y // CHECK: %[[COL:.*]] = affine.apply affine_map<()[s0, s1] -> (s1 * 64 + s0 * 8)>()[%[[TIDX_Y]], %[[BIDX_Y]]] - N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 8) * T1, BLOCK_N ceildiv 8, 1)}] + N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 8) * T1, 8, 1)}] : vector<8xf16>, !wave.tensor<[@M, @N] of f16, > // CHECK: vector.store {{.*}}[%[[ROW]], %[[COL]]] : memref<{{.*}}xf16{{.*}}>, vector<8xf16> // CHECK-NOT: vector.transfer_write diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index 2952c484a..2d96ff136 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -107,8 +107,7 @@ func.func @index_attr_wrong_attr_type(%arg0: f32) { // must provide the full triple (start, step, stride) func.func @index_attr_missing_step_stride(%arg0: f32) { - // expected-error @+2 {{expected ','}} - // expected-error @+1 {{custom op 'wave.register' expected three affine expressions for '(start, step, stride)'}} + // expected-error @below {{expected ','}} wave.register %arg0 index [{X : [#wave.index_symbol] -> (WG0)}] : !wave.tensor<[@M] of f32, > return } @@ -117,8 +116,7 @@ func.func @index_attr_missing_step_stride(%arg0: f32) { // must provide the full triple (start, step, stride) func.func @index_attr_missing_stride(%arg0: f32) { - // expected-error @+2 {{expected ','}} - // expected-error @+1 {{custom op 'wave.register' expected three affine expressions for '(start, step, stride)'}} + // expected-error @below {{expected ','}} wave.register %arg0 index [{X : [#wave.index_symbol] 
-> (WG0, 1)}] : !wave.tensor<[@M] of f32, > return } @@ -256,7 +254,7 @@ module attributes {wave.normal_form = #wave.normal_form} { // expected-note @below {{BLOCK_M, BLOCK_N, M}} %0 = wave.read %mem index [{ M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + (BLOCK_M floordiv 2) * (T0 floordiv 64) + T0 mod 64, 1, 64), - N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 2) * T1, BLOCK_N ceildiv 2, 1)}] + N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 2) * T1, 1, 1)}] : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > return } @@ -271,7 +269,7 @@ module attributes {wave.normal_form = #wave.normal_form} { // expected-note @below {{available symbols: M, N}} %0 = wave.read %mem index [{ M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + (BLOCK_M floordiv 2) * (T0 floordiv 64) + T0 mod 64, 1, 64), - N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 2) * T1, BLOCK_N ceildiv 2, 1)}] + N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 2) * T1, 1, 1)}] : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > return } @@ -357,8 +355,8 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri // expected-error @below {{'index' has more than one entry with non-unit step}} // expected-note @below {{second non-unit step dimension: 1}} wave.read %mem index [{ - M : [#wave.index_symbol, #wave.symbol<"X">] -> (T0, 2 * X, 1), - N : [#wave.index_symbol, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1) + M : [#wave.index_symbol, #wave.symbol<"X">] -> (T0, 2, 1), + N : [#wave.index_symbol, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, 2, 1) }] : (!wave.tensor<[@M, @N] of f32>) -> vector<4xf32> return } diff --git 
a/water/test/Dialect/Wave/ops.mlir b/water/test/Dialect/Wave/ops.mlir index 4022b6b82..15ddae701 100644 --- a/water/test/Dialect/Wave/ops.mlir +++ b/water/test/Dialect/Wave/ops.mlir @@ -108,7 +108,7 @@ func.func @register_with_symbols_complex_index() { // CHECK: wave.register %register = wave.register %0 index [{ - B : [#wave.index_symbol, #wave.symbol<"BLOCK_B">] -> (WG2 * (BLOCK_B+BLOCK_B), BLOCK_B * (WG2+WG2), WG2 * BLOCK_B), + B : [#wave.index_symbol, #wave.symbol<"BLOCK_B">] -> (WG2 * (BLOCK_B+BLOCK_B), 1, 1), M : [#wave.index_symbol, #wave.symbol<"BLOCK_M">, #wave.index_symbol] -> (WG0 * BLOCK_M + BLOCK_M * ((T0 floordiv 64) floordiv 2) + T0 mod 32, 1, 1), N : [#wave.index_symbol, #wave.symbol<"BLOCK_N">, #wave.index_symbol, #wave.index_symbol, #wave.index_symbol] -> (T1 * (BLOCK_N floordiv 2) + BLOCK_N * WG1 + GPR_NUM mod 4 + ((GPR_NUM floordiv 4) mod 4) * 8 + ((T0 mod 64) floordiv 32) * 4, 1, 1) }] @@ -157,7 +157,7 @@ attributes {wave.hyperparameters = #wave.hyperparameters<{BLOCK_M = 32, BLOCK_N // CHECK: #wave.index_symbol %0 = wave.read %mem index [{ M : [#wave.symbol<"BLOCK_M">, #wave.index_symbol, #wave.index_symbol] -> (BLOCK_M * WG0 + (BLOCK_M floordiv 2) * (T0 floordiv 64) + T0 mod 64, 1, 64), - N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 2) * T1, BLOCK_N ceildiv 2, 1)}] + N : [#wave.index_symbol, #wave.index_symbol, #wave.symbol<"BLOCK_N">] -> (WG1 * BLOCK_N + (BLOCK_N floordiv 2) * T1, 1, 1)}] : (!wave.tensor<[@M] of f16, >) -> !wave.tensor<[@M] of f16, > return } diff --git a/water/test/Dialect/Wave/python_bindings.py b/water/test/Dialect/Wave/python_bindings.py index 9b7376c70..9edd7f166 100644 --- a/water/test/Dialect/Wave/python_bindings.py +++ b/water/test/Dialect/Wave/python_bindings.py @@ -44,7 +44,7 @@ else: assert False, "Expected to fail with TypeError." 
- # CHECK: #wave, #wave.symbol<"BLOCK_M">, #wave.index_symbol] -> (WG0 * 3, WG0 + BLOCK_M, T0 mod WG0)> + # CHECK: #wave, #wave.symbol<"BLOCK_M">, #wave.index_symbol] -> (WG0 * 3, 1, 4)> symbols = [ wave.WaveIndexSymbolAttr.get(wave.WaveIndexSymbol.WORKGROUP_0), wave.WaveSymbolAttr.get("BLOCK_M"), @@ -54,20 +54,18 @@ s1 = ir.AffineSymbolExpr.get(1) s2 = ir.AffineSymbolExpr.get(2) start_map = ir.AffineMap.get(0, 3, [s0 * 3]) - step_map = ir.AffineMap.get(0, 3, [s0 + s1]) - stride_map = ir.AffineMap.get(0, 3, [s2 % s0]) - index_mapping_attr = wave.WaveIndexMappingAttr.get( - symbols, start_map, step_map, stride_map - ) + step = 1 + stride = 4 + index_mapping_attr = wave.WaveIndexMappingAttr.get(symbols, start_map, step, stride) print(index_mapping_attr) # CHECK: ()[s0, s1, s2] -> (s0 * 3) print(index_mapping_attr.start) - # CHECK: ()[s0, s1, s2] -> (s0 + s1) + # CHECK: 1 print(index_mapping_attr.step) - # CHECK: ()[s0, s1, s2] -> (s2 mod s0) + # CHECK: 4 print(index_mapping_attr.stride) # CHECK: 3 @@ -84,7 +82,7 @@ print(retrieved_symbols[2]) try: - wave.WaveIndexMappingAttr.get([], start_map, step_map, stride_map) + wave.WaveIndexMappingAttr.get([], start_map, step, stride) except ValueError as e: assert "co-indexed" in str(e) else: @@ -92,23 +90,15 @@ try: dimension_map = ir.AffineMap.get(1, 0, []) - wave.WaveIndexMappingAttr.get([], dimension_map, dimension_map, dimension_map) + wave.WaveIndexMappingAttr.get([], dimension_map, step, stride) except ValueError as e: assert "not involve dimensions" in str(e) else: assert False, "Expected to fail with ValueError." - try: - no_result_map = ir.AffineMap.get(0, 3, []) - wave.WaveIndexMappingAttr.get(symbols, start_map, no_result_map, stride_map) - except ValueError as e: - assert "same number of results" in str(e) - else: - assert False, "Expected to fail with ValueError." 
- try: wave.WaveIndexMappingAttr.get( - ["string", "instead", "of", "attrs"], start_map, step_map, stride_map + ["string", "instead", "of", "attrs"], start_map, step, stride ) except TypeError as e: assert "ir.Attribute" in str(e)