Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions water/include/water/Dialect/Wave/IR/WaveAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -382,22 +382,34 @@ def WaveIterSymbolAttr : AttrDef<WaveDialect, "WaveIterSymbol"> {
def WaveIndexMappingAttr : AttrDef<WaveDialect, "WaveIndexMapping"> {
let mnemonic = "index_mapping";
let description = [{
An affine map with named symbols for Wave indexing expressions.

This attribute preserves meaningful symbol names (e.g., WG0, BLOCK_M, T0)
while storing affine maps internally for start, step, and stride. The
symbol_names array corresponds 1:1 to the symbols in the affine
expressions, where s0 maps to symbol_names[0], s1 to symbol_names[1], etc.

Custom syntax: [symbol_name attributes] -> (start_expr, step_expr, stride_expr)
Example: [wave.index_symbol<WG0>, wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + 42, 1, BLOCK_M)
An attribute capturing how symbols indexing a tensor are distributed
across concurrent threads, workgroups, devices, and loop tiles.

In this attribute, the start expression corresponds to the offset from
the start of the tensor dimension, the step corresponds to the number
of consecutive elements taken (i.e., corresponding to the `vector.step`
operation to index into them) and the stride corresponds to the number
of elements to step over in this dimension to get to the next block
of consecutive elements that are accessed by the current thread. The
start expression is an affine expression where positional affine symbols
are mapped to specific wave symbols. It will typically include thread and
workgroup index symbols (the absence thereof indicates that all threads or
all workgroups access the same element, i.e., a broadcast), as well as
loop iteration symbols for operations nested in loops. Start and stride
are constant values.

The symbols array corresponds 1:1 to the symbols in the affine map, where
s0 maps to symbols[0], s1 to symbols[1], etc.

Custom syntax: `[symbol_name attributes] -> (start_expr, step, stride)`
Example: `[wave.index_symbol<WG0>, wave.symbol<"BLOCK_M">] -> (WG0 * BLOCK_M + 42, 1, 4)`
}];

let parameters = (ins
ArrayRefParameter<"::mlir::Attribute">:$symbols,
"::mlir::AffineMap":$start,
"::mlir::AffineMap":$step,
"::mlir::AffineMap":$stride
"uint64_t":$step,
"uint64_t":$stride
);

let hasCustomAssemblyFormat = 1;
Expand Down
15 changes: 6 additions & 9 deletions water/include/water/c/Dialects.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,13 +100,11 @@ MLIR_CAPI_EXPORTED bool
mlirAttributeIsAWaveIndexMappingAttr(MlirAttribute attr);

/// Creates a new WaveIndexMappingAttr with the given start, step and stride
/// maps that are interpreted as accepting the symbols provided in the
/// `symbolNames` list. The list must have as many entries as maps have symbols,
/// and all maps must have the same number of symbols and zero dimensions. The
/// list is expected to only contain WaveSymbolAttr instances.
/// values that are interpreted as constant offsets. The `symbolNames` list
/// is expected to only contain WaveSymbolAttr instances.
MLIR_CAPI_EXPORTED MlirAttribute mlirWaveIndexMappingAttrGet(
MlirContext mlirCtx, MlirAttribute *symbolNames, MlirAffineMap start,
MlirAffineMap step, MlirAffineMap stride);
MlirContext mlirCtx, MlirAttribute *symbolNames, intptr_t numSymbols,
MlirAffineMap start, uint64_t step, uint64_t stride);

/// Returns the typeID of a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirTypeID mlirWaveIndexMappingAttrGetTypeID();
Expand All @@ -116,11 +114,10 @@ MLIR_CAPI_EXPORTED MlirAffineMap
mlirWaveIndexMappingAttrGetStart(MlirAttribute attr);

/// Get the step from a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirAffineMap
mlirWaveIndexMappingAttrGetStep(MlirAttribute attr);
MLIR_CAPI_EXPORTED uint64_t mlirWaveIndexMappingAttrGetStep(MlirAttribute attr);

/// Get the stride from a WaveIndexMappingAttr.
MLIR_CAPI_EXPORTED MlirAffineMap
MLIR_CAPI_EXPORTED uint64_t
mlirWaveIndexMappingAttrGetStride(MlirAttribute attr);

/// Get the number of (input) symbols.
Expand Down
21 changes: 8 additions & 13 deletions water/lib/CAPI/Dialects.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,17 +108,12 @@ bool mlirAttributeIsAWaveIndexMappingAttr(MlirAttribute attr) {

MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx,
MlirAttribute *symbolNames,
MlirAffineMap start,
MlirAffineMap step,
MlirAffineMap stride) {
intptr_t numSymbols,
MlirAffineMap start, uint64_t step,
uint64_t stride) {
mlir::MLIRContext *ctx = unwrap(mlirCtx);

// Convert C array of MlirAttribute to vector of WaveSymbolAttr.
unsigned numSymbols = mlirAffineMapGetNumSymbols(start);
assert(mlirAffineMapGetNumSymbols(step) == numSymbols &&
"expected start and step to have the same number of dimensions");
assert(mlirAffineMapGetNumSymbols(stride) == numSymbols &&
"expected start and stride to have the same number of dimensions");
llvm::SmallVector<mlir::Attribute> symbolAttrs = llvm::map_to_vector(
llvm::make_range(symbolNames, symbolNames + numSymbols),
[](MlirAttribute attr) { return unwrap(attr); });
Expand All @@ -131,7 +126,7 @@ MlirAttribute mlirWaveIndexMappingAttrGet(MlirContext mlirCtx,
"WaveIndexSymbolAttr attributes");

return wrap(wave::WaveIndexMappingAttr::get(ctx, symbolAttrs, unwrap(start),
unwrap(step), unwrap(stride)));
step, stride));
}

MlirTypeID mlirWaveIndexMappingAttrGetTypeID() {
Expand All @@ -142,12 +137,12 @@ MlirAffineMap mlirWaveIndexMappingAttrGetStart(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStart());
}

MlirAffineMap mlirWaveIndexMappingAttrGetStep(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStep());
uint64_t mlirWaveIndexMappingAttrGetStep(MlirAttribute attr) {
return llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStep();
}

MlirAffineMap mlirWaveIndexMappingAttrGetStride(MlirAttribute attr) {
return wrap(llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStride());
uint64_t mlirWaveIndexMappingAttrGetStride(MlirAttribute attr) {
return llvm::cast<wave::WaveIndexMappingAttr>(unwrap(attr)).getStride();
}

intptr_t mlirWaveIndexMappingAttrGetNumSymbols(MlirAttribute attr) {
Expand Down
40 changes: 17 additions & 23 deletions water/lib/Dialect/Wave/IR/WaveAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,29 +147,25 @@ Attribute WaveIndexMappingAttr::parse(AsmParser &parser, Type type) {
}

AffineExpr startExpr;
AffineExpr stepExpr;
AffineExpr strideExpr;
if (failed(parseExprWithNames(symbolNames, startExpr, parser)) ||
parser.parseComma() ||
failed(parseExprWithNames(symbolNames, stepExpr, parser)) ||
parser.parseComma() ||
failed(parseExprWithNames(symbolNames, strideExpr, parser)) ||
parser.parseRParen()) {
parser.emitError(
parser.getCurrentLocation(),
"expected three affine expressions for '(start, step, stride)'");
uint64_t step, stride;
llvm::SMLoc currentLocation = parser.getCurrentLocation();
if (failed(parseExprWithNames(symbolNames, startExpr, parser))) {
parser.emitError(currentLocation, "expected affine expression");
return {};
}
if (failed(parser.parseComma()) || failed(parser.parseInteger(step)) ||
failed(parser.parseComma()) || failed(parser.parseInteger(stride)) ||
failed(parser.parseRParen())) {
return {};
}

// Build maps
auto startMap = AffineMap::get(
/*numDims=*/0, /*numSymbols=*/symbolNames.size(), startExpr, context);
auto stepMap = AffineMap::get(
/*numDims=*/0, /*numSymbols=*/symbolNames.size(), stepExpr, context);
auto strideMap = AffineMap::get(
/*numDims=*/0, /*numSymbols=*/symbolNames.size(), strideExpr, context);
auto startMap = startExpr
? AffineMap::get(
/*numDims=*/0, /*numSymbols=*/symbolNames.size(),
startExpr, context)
: AffineMap();

return get(context, symbolNameAttrs, startMap, stepMap, strideMap);
return get(context, symbolNameAttrs, startMap, step, stride);
}

void WaveIndexMappingAttr::print(AsmPrinter &printer) const {
Expand All @@ -196,16 +192,14 @@ void WaveIndexMappingAttr::print(AsmPrinter &printer) const {
}
// All three maps share the same symbol set and order.
std::string startStr = stringifyWithNames(getStart(), names);
std::string stepStr = stringifyWithNames(getStep(), names);
std::string strideStr = stringifyWithNames(getStride(), names);

printer << "(" << startStr << ", " << stepStr << ", " << strideStr << ")";
printer << "(" << startStr << ", " << getStep() << ", " << getStride() << ")";
}

LogicalResult
WaveIndexMappingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<Attribute> symbols, AffineMap start,
AffineMap step, AffineMap stride) {
uint64_t step, uint64_t stride) {
if (!llvm::all_of(symbols, llvm::IsaPred<WaveSymbolAttr, WaveIndexSymbolAttr,
WaveIterSymbolAttr>)) {
return emitError() << "expected all symbols to be a WaveSymbolAttr, "
Expand Down
62 changes: 13 additions & 49 deletions water/lib/Dialect/Wave/IR/WaveInterfaces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,35 +50,20 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) {
<< val;

auto mapping = cast<wave::WaveIndexMappingAttr>(val);
auto checkNoDims = [&](AffineMap map, StringRef which) -> LogicalResult {
if (map.getNumDims() != 0)
return op->emitError(
"wave indexing " + which +
" map should have no dimensions, only symbols, got ")
<< map.getNumDims() << " dimensions for symbol "
<< named.getName();
return success();
};

AffineMap startMap = mapping.getStart();
AffineMap stepMap = mapping.getStep();
AffineMap strideMap = mapping.getStride();
if (failed(checkNoDims(startMap, "start")) ||
failed(checkNoDims(stepMap, "step")) ||
failed(checkNoDims(strideMap, "stride")))
return failure();
if (mapping.getStart().getNumDims() != 0) {
return op->emitError("wave indexing start map should have no "
"dimensions, only symbols, got ")
<< mapping.getStart().getNumDims() << " dimensions for symbol "
<< named.getName();
}

unsigned declared = mapping.getSymbols().size();
if (startMap.getNumSymbols() != declared ||
stepMap.getNumSymbols() != declared ||
strideMap.getNumSymbols() != declared) {
if (mapping.getStart().getNumSymbols() != declared) {
return op->emitError(
"inconsistent symbol count between symbol_names and "
"affine maps for index symbol '")
<< named.getName() << "' (expected " << declared
<< ", got start=" << startMap.getNumSymbols()
<< ", step=" << stepMap.getNumSymbols()
<< ", stride=" << strideMap.getNumSymbols() << ")";
<< named.getName() << "' (expected " << declared << ", got "
<< mapping.getStart().getNumSymbols() << ")";
}

for (auto symbol : mapping.getSymbols()) {
Expand Down Expand Up @@ -691,13 +676,9 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join(
auto isThreadDependent = [&](wave::WaveIndexMappingAttr val) -> bool {
llvm::SmallVector<unsigned> threadLikeSymbolPositions;
getThreadLikeSymbolPositions(val.getSymbols(), threadLikeSymbolPositions);
return llvm::any_of(
llvm::ArrayRef{val.getStart(), val.getStep(), val.getStride()},
[&](mlir::AffineMap map) {
return llvm::any_of(threadLikeSymbolPositions, [&](unsigned pos) {
return map.isFunctionOfSymbol(pos);
});
});
return llvm::any_of(threadLikeSymbolPositions, [&](unsigned pos) {
return val.getStart().isFunctionOfSymbol(pos);
});
};

// If both are thread-dependent or thread-independent, the only acceptable
Expand Down Expand Up @@ -728,18 +709,6 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join(
alignMapSymbols(threadIndependentMapping.getStart(),
threadIndependentSymbols, allSymbols);

mlir::AffineMap threadDependentStep = alignMapSymbols(
threadDependentMapping.getStep(), threadDependentSymbols, allSymbols);
mlir::AffineMap threadIndependentStep =
alignMapSymbols(threadIndependentMapping.getStep(),
threadIndependentSymbols, allSymbols);

mlir::AffineMap threadDependentStride = alignMapSymbols(
threadDependentMapping.getStride(), threadDependentSymbols, allSymbols);
mlir::AffineMap threadIndependentStride =
alignMapSymbols(threadIndependentMapping.getStride(),
threadIndependentSymbols, allSymbols);

// Subtract the thread-independent from thread-dependent for each.
auto subtractMaps = [&](mlir::AffineMap a,
mlir::AffineMap b) -> mlir::AffineMap {
Expand All @@ -754,10 +723,6 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join(
};
mlir::AffineMap newStart =
subtractMaps(threadDependentStart, threadIndependentStart);
mlir::AffineMap newStep =
subtractMaps(threadDependentStep, threadIndependentStep);
mlir::AffineMap newStride =
subtractMaps(threadDependentStride, threadIndependentStride);

llvm::SmallVector<unsigned> threadLikeSymbolPositions;
getThreadLikeSymbolPositions(allSymbols, threadLikeSymbolPositions);
Expand All @@ -775,8 +740,7 @@ wave::IndexExprsLatticeStorage wave::IndexExprsLatticeStorage::join(
return !walkResult.wasInterrupted();
};

if (!isOnlyThreadDependent(newStart) || !isOnlyThreadDependent(newStep) ||
!isOnlyThreadDependent(newStride))
if (!isOnlyThreadDependent(newStart))
return top();

result[namedAttr.getName()] = threadDependentMapping;
Expand Down
50 changes: 24 additions & 26 deletions water/lib/Dialect/Wave/IR/WaveOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -597,8 +597,8 @@ struct MmaSingleIndexExprBuilder {
// Set the parameter of the index expression for the currently selected
// dimension.
MmaSingleIndexExprBuilder &offset(mlir::AffineExpr expr);
MmaSingleIndexExprBuilder &size(int64_t value);
MmaSingleIndexExprBuilder &stride(int64_t value);
MmaSingleIndexExprBuilder &size(uint64_t value);
MmaSingleIndexExprBuilder &stride(uint64_t value);

// Select the dimension.
MmaSingleIndexExprBuilder &m();
Expand All @@ -609,7 +609,9 @@ struct MmaSingleIndexExprBuilder {
void populate(llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes) const;

MmaIndexingExprBuilder &parent;
mlir::AffineExpr offsetExpr, sizeExpr, strideExpr;
mlir::AffineExpr offsetExpr;
std::optional<uint64_t> sizeValue;
std::optional<uint64_t> strideValue;
bool enabled;
};

Expand Down Expand Up @@ -657,16 +659,18 @@ struct MmaIndexingExprBuilder {
void populate(llvm::SmallVectorImpl<mlir::NamedAttribute> &attributes) const {
mlir::MLIRContext *ctx = getAnySymbolContext(mSymbol, nSymbol, kSymbol);

auto buildMap = [&](mlir::AffineExpr expr) {
assert(expr &&
"expected offset/size/stride to be set up for all symbols");
return mlir::AffineMap::get(/*dimCount=*/0,
/*symbolCount=*/symbols.size(), expr, ctx);
};
auto buildOne = [&](const MmaSingleIndexExprBuilder &builder) {
assert(builder.offsetExpr &&
"expected offset to be set up for all symbols");
assert(builder.sizeValue && "expected size to be set up for all symbols");
assert(builder.strideValue &&
"expected stride to be set up for all symbols");
return wave::WaveIndexMappingAttr::get(
ctx, symbols, buildMap(builder.offsetExpr),
buildMap(builder.sizeExpr), buildMap(builder.strideExpr));
ctx, symbols,
mlir::AffineMap::get(/*dimCount=*/0,
/*symbolCount=*/symbols.size(),
builder.offsetExpr, ctx),
*builder.sizeValue, *builder.strideValue);
};

if (mSymbol)
Expand All @@ -691,19 +695,19 @@ MmaSingleIndexExprBuilder::offset(mlir::AffineExpr expr) {
return *this;
}

MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::size(int64_t value) {
MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::size(uint64_t value) {
if (!enabled)
return *this;
assert(!sizeExpr && "expected size to be set only once");
sizeExpr = mlir::getAffineConstantExpr(value, offsetExpr.getContext());
assert(!sizeValue && "expected size to be set only once");
sizeValue = value;
return *this;
}

MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::stride(int64_t value) {
MmaSingleIndexExprBuilder &MmaSingleIndexExprBuilder::stride(uint64_t value) {
if (!enabled)
return *this;
assert(!strideExpr && "expected stride to be set only once");
strideExpr = mlir::getAffineConstantExpr(value, offsetExpr.getContext());
assert(!strideValue && "expected stride to be set only once");
strideValue = value;
return *this;
}

Expand Down Expand Up @@ -922,9 +926,7 @@ applyConstraint(ConstraintAttrT constraint,
/*dimCount=*/0, symbols.size(),
symbolExpr * constraint.getTileSize().getMap().getResult(0));
if (baseMapping == nullptr)
return wave::WaveIndexMappingAttr::get(
context, symbols, map, mlir::AffineMap::getConstantMap(1, context),
mlir::AffineMap::getConstantMap(1, context));
return wave::WaveIndexMappingAttr::get(context, symbols, map, 1, 1);

llvm::SmallVector<mlir::Attribute> allSymbols;
wave::aggregateAllSymbols(
Expand All @@ -934,15 +936,11 @@ applyConstraint(ConstraintAttrT constraint,

mlir::AffineMap baseStart = alignMapSymbols(
baseMapping.getStart(), baseMapping.getSymbols(), allSymbols);
mlir::AffineMap baseStep = alignMapSymbols(
baseMapping.getStep(), baseMapping.getSymbols(), allSymbols);
mlir::AffineMap baseStride = alignMapSymbols(
baseMapping.getStride(), baseMapping.getSymbols(), allSymbols);
map = alignMapSymbols(map, symbols, allSymbols);
map = mlir::AffineMap::get(/*dimCount=*/0, allSymbols.size(),
baseStart.getResult(0) + map.getResult(0));
return wave::WaveIndexMappingAttr::get(context, allSymbols, map, baseStep,
baseStride);
return wave::WaveIndexMappingAttr::get(
context, allSymbols, map, baseMapping.getStep(), baseMapping.getStride());
}

// Create an index mapping induced by the given constraint. Combine it with the
Expand Down
Loading
Loading