LHS Registers Part 2 - Pipelining #19

Open · wants to merge 18 commits into base: llvm-head
908 changes: 908 additions & 0 deletions BUILD

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cmake/llvm-hash.txt
@@ -1 +1 @@
-df0864e761107b07e38f5503e0cbee0cebb4c5e8
+29b92d07746fac26cd64c914bc9c5c3833974f6d
25 changes: 11 additions & 14 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

-// ---- begin Ampere ----
-if (mmaEnc.isAmpere()) {
+// ---- begin Ampere & Hopper ----
+if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
@@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
llvm_unreachable("invalid operand index");
}

-// ---- begin version 3 ----
-if (mmaEnc.isHopper()) {
-llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
-" is Hopper has not been implemented yet");
-return $_get(context, 1, 1, 1, order, CTALayout, true);
-}
-
// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,
@@ -1222,7 +1215,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
SmallVector<int> getMMAv1Rep(int opIdx) const;
SmallVector<int> getMMAv1ShapePerWarp(int opIdx) const;
int getMMAv1Vec(int opIdx) const;
-SmallVector<int64_t> getMMAv2Rep(ArrayRef<int64_t> shape,
+SmallVector<int64_t> getMMAv2OrV3Rep(ArrayRef<int64_t> shape,
int bitwidth, int opIdx) const;

bool supportReduction() const {
@@ -1317,6 +1310,10 @@ The parent field is the layout of d.
kWidth defines number of consecutive elements stored by one thread along k dimension.
Some layouts do not use this parameter, either because they have a fixed number of
elements along the K dim, or they use all elements of the tensor along the K dim.

+We require kWidth to be provided for Hopper because the dtype at load time may
+differ from the dtype at WGMMA time due to casting. kWidth is determined by the
+dtype at WGMMA.
}];

let parameters = (
@@ -1327,16 +1324,16 @@ elements along the K dim, or they use all elements of the tensor along the K dim
);

let builders = [
-// Specially for MMAV1(Volta)
+// For MMAV2 and V3
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
-if (!parentAttr || !parentAttr.isAmpere())
+if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
-unsigned MMAv2kWidth = 32 / bitwidth;
-return $_get(context, opIdx, parent, MMAv2kWidth);
+unsigned kWidth = 32 / bitwidth;
+return $_get(context, opIdx, parent, kWidth);
}]>
];

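The kWidth note added to the DotOperandEncodingAttr description above is the contract this change relies on: on Hopper the operand may be loaded in one dtype and cast to another before feeding the WGMMA, and kWidth must follow the WGMMA-side dtype. A minimal C++ sketch of that rule; the helper name and the example types are assumptions for illustration, not code from this PR:

#include <cassert>

// One 32-bit register holds 32 / bitwidth consecutive K elements, so kWidth
// follows the element type consumed by the WGMMA instruction, not the type
// loaded from memory. Hypothetical helper, for illustration only.
unsigned kWidthForWGMMA(unsigned wgmmaEltBitwidth) {
  assert(wgmmaEltBitwidth != 0 && 32 % wgmmaEltBitwidth == 0);
  return 32 / wgmmaEltBitwidth;
}

// e.g. a kernel that loads fp8 (8 bits) but casts to fp16 (16 bits) before
// the WGMMA gets kWidthForWGMMA(16) == 2, not 32 / 8 == 4.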
17 changes: 13 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -33,14 +33,15 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
// If the parent of the dot operand is in block encoding, we don't need to
// reorder elements
auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
-if (!parentEncoding)
+if (!parentEncoding || parentEncoding.isHopper())
return values;
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
-if (inBitWidth == 16 && ouBitWidth == 32) {
+if ((inBitWidth == 16 && ouBitWidth == 32) ||
+(inBitWidth == 32 && ouBitWidth == 16)) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
@@ -86,8 +87,12 @@ SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
+if (!encoding)
return inValues;
+auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
+if (!parentEnc || parentEnc.isHopper())
+return inValues;
+
SmallVector<Value> outValues;
for (auto v : inValues) {
// cast i32 to appropriate eltType vector and extract elements
@@ -108,8 +113,12 @@ SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
-if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
+if (!encoding)
return inValues;
+auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
+if (!parentEnc || parentEnc.isHopper())
+return inValues;
+
SmallVector<Value> outValues;
auto eltType = typeConverter->convertType(tensorTy.getElementType());
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
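For non-Hopper MMA operands, packI32 and unpackI32 convert between per-element values and the packed i32 registers used by the MMA lowering; the early returns added above skip Hopper dot operands entirely. A sketch of just the packing arithmetic, assuming nothing beyond the 32 / bitwidth rule visible in the diff:

#include <cstddef>

// Each i32 register carries 32 / bitwidth elements of the narrow type, so a
// thread holding numElems elements packs them into numElems / vecWidth
// registers. Hopper operands skip this step per the early returns above.
std::size_t numPackedI32(std::size_t numElems, unsigned eltBitwidth) {
  unsigned vecWidth = 32 / eltBitwidth; // fp16 -> 2, fp8/int8 -> 4
  return numElems / vecWidth;
}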
32 changes: 24 additions & 8 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1022,13 +1022,17 @@ LogicalResult DotOperandEncodingAttr::verify(
return emitError() << "triton_gpu.dot_op parent parameter cannot be null";
}
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-if (kWidth != 0 && !parentAttr.isAmpere())
+if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "triton_gpu.dot_op kWidth parameter can only be "
-"non-zero for Ampere MMA parent";
+"non-zero for Ampere or Hopper MMA parent";
-if (kWidth == 0 && parentAttr.isAmpere())
+if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError()
<< "triton_gpu.dot_op kWidth parameter is mandatory for "
-"Ampere MMA parent";
+"Ampere or Hopper MMA parent";
+if (opIdx != 0 && parentAttr.isHopper())
+return emitError()
+<< "triton_gpu.dot_op opIdx parameter must be 0 for "
+"Hopper MMA parent";
return success();
}

@@ -1957,17 +1961,17 @@ SmallVector<int> NvidiaMmaEncodingAttr::getMMAv1ShapePerWarp(int opIdx) const {
int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
return 2 * getMMAv1Rep(opIdx)[opIdx];
}
-SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
+SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2OrV3Rep(ArrayRef<int64_t> shape,
int bitwidth,
int opIdx) const {
+assert(isAmpere() || isHopper());
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
: 1;
-assert(isAmpere());

if (opIdx == 0)
return {numRepBatch,
@@ -1982,18 +1986,25 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
warpsPerCTA[rank - 1]))};
}
}

unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
// H100
if (isHopper()) {
-return getTotalElemsPerThread(shape, eltTy);
+assert(opIdx == 0);
+auto instrMNK = getInstrShape();
+int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
+int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
+// For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
+// kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
+return 4 * kWidth * repM * repK;
}
// A100
if (isAmpere()) {
-auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx);
+auto rep = getMMAv2OrV3Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth(), opIdx);
if (opIdx == 0)
return 4 * rep[0] * rep[1] * rep[2];
if (opIdx == 1)
@@ -2720,6 +2731,11 @@ struct CanonicalizeConvertFromAlloc
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
+// LocalAllocOp lowering doesn't support going from DotOperandEncoding
+// to SharedEncoding, so we want to keep this layout conversion.
+if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+convert.getSrc().getType().getEncoding()))
+return failure();
rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
op, op->getResult(0).getType(), convert.getSrc());
return mlir::success();
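The new Hopper branch of getTotalElemsPerThreadForOperands replaces the generic getTotalElemsPerThread call with a closed-form count: four quadrants per WGMMA fragment, kWidth elements per quadrant, repeated repM * repK times. A worked re-derivation of that formula; every concrete number below is an assumption chosen for illustration, not taken from a real kernel:

#include <cstdio>

// Mirrors the ceil<unsigned> helper used in the diff above.
static unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

int main() {
  unsigned shapeM = 128, shapeK = 64;  // shapePerCTA[0], shapePerCTA[1]
  unsigned instrM = 16, instrK = 16;   // instrMNK[0], instrMNK[2] (assumed)
  unsigned warpsPerCTAM = 4, kWidth = 2;
  unsigned repM = ceilDiv(shapeM, instrM * warpsPerCTAM); // 2
  unsigned repK = ceilDiv(shapeK, instrK);                // 4
  // 2x2 fragment per WGMMA, kWidth elements per quadrant, per thread.
  printf("elems per thread: %u\n", 4 * kWidth * repM * repK); // 64
  return 0;
}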
24 changes: 24 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -153,6 +153,21 @@ static Value getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter,
auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
newLayout, SharedMemorySpace);
rewriter.setInsertionPointAfterValue(arg);

+// LocalAllocOp lowering doesn't support going from DotOperandEncoding
+// to SharedEncoding.
+if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+argType.getEncoding())) {
+// Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+// then pass it to the LocalAllocOp.
+auto newArgType = RankedTensorType::get(
+argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+auto dotOperandToBlockedCvt =
+rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+dotOperandToBlockedCvt);
+}
+
return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
}

@@ -162,6 +177,15 @@ class BlockedToMMA : public mlir::OpRewritePattern<DotOp> {
mutable llvm::DenseMap<Operation *, unsigned> dotOpInstNs;

static bool bwdFilter(Operation *op) {
+// Dot operand layout assignment to predicates (i1 values) is not currently
+// supported when lowering from TritonGPU to LLVM for MMA cases. This check
+// hides the original bit-width of such predicates from the backward scan so
+// they are never considered; hence kWidth can never be 32.
+if (isa<arith::UIToFPOp>(op)) {
+Type srcType = getElementTypeOrSelf(op->getOperand(0));
+if (srcType.isInteger(1))
+return false;
+}
return op->getNumOperands() == 1 &&
(isa<FpToFpOp, BitcastOp, ConvertLayoutOp>(op) ||
isPureUnaryInlineAsm(op) ||
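The uitofp guard added to bwdFilter above exists because BlockedToMMA scans backwards through unary ops to estimate the narrowest bit-width feeding the dot, and kWidth is derived from that width. A sketch of the failure mode the guard prevents, assuming the 32 / bitwidth derivation used elsewhere in this PR; the helper name is illustrative:

// With the guard, the scan stops before the i1 source and uses the fp dtype's
// width instead; without it, minBitwidth = 1 would yield kWidth = 32, which
// the TritonGPU-to-LLVM lowering does not support.
unsigned kWidthFromMinBitwidth(unsigned minBitwidth) {
  return 32 / minBitwidth; // minBitwidth = 1 (i1) would give kWidth = 32
}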