Commit ac6492f

Merge OpenAI Triton commit a85fab0 (#3755)
This PR changes the Triton base from e196446 to a85fab0 (Mar 24). Pass rate: 89.76%.
2 parents 77b8626 + 8203b1d commit ac6492f


47 files changed (+1802, -911 lines)

README.md (+4)

@@ -138,6 +138,10 @@ arbitrary LLVM version.
   during the build. By default, this is the user's home directory. It
   can be changed anytime.

+- If you're running out of memory when building Triton, specify the `MAX_JOBS`
+  environment variable (to the `pip install -e python` command) to limit the
+  number of jobs.
+
 - Pass `--no-build-isolation` to `pip install` to make nop builds faster.
   Without this, every invocation of `pip install` uses a different symlink to
   cmake, and this forces ninja to rebuild most of the `.a` files.
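
The new bullet documents an escape hatch for out-of-memory builds: `MAX_JOBS` caps the parallel compile jobs that the `pip install -e python` build spawns. A hypothetical invocation on a POSIX shell would be `MAX_JOBS=8 pip install -e python`.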

bin/RegisterTritonDialects.h (+1)

@@ -97,6 +97,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   mlir::registerTritonAMDGPUCanonicalizePointers();
   mlir::registerTritonAMDGPUConvertToBufferOps();
   mlir::registerTritonAMDGPUInThreadTranspose();
+  mlir::registerTritonAMDGPUCoalesceAsyncCopy();
   mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
   mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();

cmake/llvm-hash.txt (+1, -1)

@@ -1 +1 @@
-2619c2ed584cdf3b38e6743ed3c785223f06e3f7
+0ea4fb92648b2aa7cbab486bb493e122b4dcc062

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td (+1, -1)

@@ -430,7 +430,7 @@ def NVMMASharedEncodingAttr :
     } else if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0) {
       swizzlingByteWidth = 32;
     } else {
-      llvm_unreachable("unsupported shared memory layout for MMAv3");
+      llvm_unreachable("unsupported NVMMA layout (MMAv3 or TMA)");
     }
     bool transposed = order[0] == 0;
     return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, fp4Padded, CTALayout);
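
For orientation, the surrounding cascade picks the widest swizzling width that evenly divides the contiguous-dim byte size; only the 32-byte branch and the fallthrough are visible in this hunk, so the wider branches in the sketch below are assumptions:

```cpp
#include <cstdio>
#include <cstdlib>

// Sketch of the swizzling-byte-width selection around the hunk above.
static int swizzlingByteWidthFor(int contigDimSizeInByte) {
  if (contigDimSizeInByte >= 128 && contigDimSizeInByte % 128 == 0)
    return 128; // assumed branch, by analogy with the visible 32B case
  if (contigDimSizeInByte >= 64 && contigDimSizeInByte % 64 == 0)
    return 64;  // assumed branch
  if (contigDimSizeInByte >= 32 && contigDimSizeInByte % 32 == 0)
    return 32;  // the branch shown in the hunk
  std::fprintf(stderr, "unsupported NVMMA layout (MMAv3 or TMA)\n");
  std::abort(); // stand-in for llvm_unreachable
}
```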

include/triton/Dialect/TritonGPU/Transforms/PipelineExpander.h (+2, -2)

@@ -25,7 +25,7 @@ namespace triton {

 /// Options to dictate how loops should be pipelined.
 struct PipeliningOption {
-  /// Lambda returning all the operation in the forOp, with their stage, in the
+  /// Lambda returning all the operations in the forOp, with their stage, in the
   /// order picked for the pipelined loop.
   using GetScheduleFnType = std::function<void(
       scf::ForOp, std::vector<std::pair<Operation *, unsigned>> &)>;
@@ -54,7 +54,7 @@ struct PipeliningOption {
   /// Control whether the transformation checks that the number of iterations is
   /// greater or equal to the number of stages and skip the transformation if
   /// this is not the case. If the loop is dynamic and this is set to true the
-  /// pipeliner will have to predicate operations in the the prologue/epilogue.
+  /// pipeliner will have to predicate operations in the prologue/epilogue.
   bool supportDynamicLoops = false;

   // Callback to predicate operations when the prologue or epilogue are not
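
The `GetScheduleFnType` alias above is the whole contract between a scheduler and the pipeline expander. A minimal sketch of a callback, assuming the option member is named `getScheduleFn` as in the upstream MLIR pipeliner this file derives from:

```cpp
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"

// Hypothetical two-stage schedule: tt.load ops are assigned stage 0 so the
// expander hoists them toward the prologue; everything else runs at stage 1.
static void configurePipeliner(mlir::triton::PipeliningOption &options) {
  options.getScheduleFn =
      [](mlir::scf::ForOp forOp,
         std::vector<std::pair<mlir::Operation *, unsigned>> &schedule) {
        for (mlir::Operation &op : forOp.getBody()->without_terminator()) {
          unsigned stage = mlir::isa<mlir::triton::LoadOp>(op) ? 0 : 1;
          schedule.emplace_back(&op, stage);
        }
      };
}
```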

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h (+8, -3)

@@ -156,11 +156,16 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
           "elem type .b4x16_p64 supports only 128B swizzling");
     }
   } else {
-    op->emitError() << "Unhandled encoding type";
-    return failure();
+    auto swizzledEnc = dyn_cast<gpu::SwizzledSharedEncodingAttr>(
+        op.getType().getBlockType().getEncoding());
+    if (!swizzledEnc || swizzledEnc.getVec() != 1 ||
+        swizzledEnc.getPerPhase() != 1 || swizzledEnc.getMaxPhase() != 1) {
+      op->emitError() << "Unhandled encoding type";
+      return failure();
+    }
   }

-  int32_t swizzle_mode;
+  int32_t swizzle_mode = 0;
   if (swizzleBytes == 128) {
     swizzle_mode = 3;
   } else if (swizzleBytes == 64) {
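
The hunk is truncated after the 64-byte case. For orientation, a standalone sketch of the full `swizzleBytes` to `swizzle_mode` mapping; the values match CUDA's `CUtensorMapSwizzle` enum (NONE=0, 32B=1, 64B=2, 128B=3), and every branch other than the visible 128-byte one is an assumption:

```cpp
#include <cstdint>

// Sketch: map a swizzle byte width to the TMA descriptor's swizzle mode.
static int32_t swizzleModeFor(int swizzleBytes) {
  int32_t swizzle_mode = 0; // no swizzling (the new default above)
  if (swizzleBytes == 128)
    swizzle_mode = 3;       // shown in the hunk
  else if (swizzleBytes == 64)
    swizzle_mode = 2;       // assumed
  else if (swizzleBytes == 32)
    swizzle_mode = 1;       // assumed
  return swizzle_mode;
}
```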

include/triton/Tools/Sys/GetEnv.hpp (+1)

@@ -33,6 +33,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_GLOBAL_PREFETCH",
     "TRITON_HIP_LOCAL_PREFETCH",
+    "TRITON_HIP_USE_ASYNC_COPY",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
     "TRITON_HIP_USE_IN_THREAD_TRANSPOSE",
     "TRITON_LLVM_DEBUG_ONLY",

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp (+45, -7)

@@ -392,8 +392,10 @@ struct MemDescSubviewOpConversion
     Location loc = op->getLoc();
     auto b = TritonLLVMOpBuilder(loc, rewriter);
     auto srcTy = op.getSrc().getType();
+    auto destTy = op.getResult().getType();
     auto llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());
     auto layoutOrder = getOrder(srcTy);
+    auto enc = srcTy.getEncoding();

     // newBase = base + offset
     auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
@@ -408,13 +410,49 @@ struct MemDescSubviewOpConversion
     for (int i = rankReduced; i < opOffsetVals.size(); i++) {
       offsetVals.push_back(b.add(opOffsetVals[i], smemObj.getOffsets()[i]));
     }
-    // Compute the offset based on the original strides of the shared memory
-    // object
-    auto offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
-    auto elemPtrTy = smemObj.getBase().getType();
-    smemObj = SharedMemoryObject(
-        b.gep(elemPtrTy, llvmElemTy, smemObj.getBase(), offset), llvmElemTy,
-        offsetVals);
+    Value offset = b.undef(i32_ty);
+    auto allocShape = srcTy.getAllocShape();
+    bool isSimpleSubview =
+        allocShape.take_back(destRank) == destTy.getShape() ||
+        !isa<NVMMASharedEncodingAttr>(enc);
+    if (!isSimpleSubview) {
+      auto nvmmaEnc = cast<NVMMASharedEncodingAttr>(enc);
+      assert(destRank >= 2 &&
+             "Shape size should be >= 2 when using NVMMAShared encoding");
+      auto swizzleStride = b.i32_val((nvmmaEnc.getSwizzlingByteWidth() * 8) /
+                                     llvmElemTy.getIntOrFloatBitWidth());
+      offset = b.i32_val(0);
+      for (auto i = 0; i < opOffsetVals.size() - 2; ++i) {
+        offset = b.add(offset, b.mul(opOffsetVals[i], opSmemStrides[i]));
+      }
+      // newOffset = offset - (stridedOff * swizzledStride + contigOff /
+      // swizzledStride * tileSize + contigOff % swizzledStride)
+      // + stridedInc * swizzledStride + contigInc / swizzledStride *
+      // tileSize + contigInc % swizzledStride
+      auto stridedDim = destRank - 1 - layoutOrder[0];
+      auto contigDim = destRank - 1 - layoutOrder[1];
+      auto stridedOff = smemObj.getOffsets()[stridedDim];
+      auto contigOff = smemObj.getOffsets()[contigDim];
+      auto stridedInc = offsetVals[stridedDim];
+      auto contigInc = offsetVals[contigDim];
+      int allocStridedDim = allocShape.size() - 1 - layoutOrder[0];
+      auto tileSize =
+          b.mul(b.i32_val(allocShape[allocStridedDim]), swizzleStride);
+      offset = b.sub(offset, b.mul(stridedOff, swizzleStride));
+      offset = b.sub(offset, b.mul(b.udiv(contigOff, swizzleStride), tileSize));
+      offset = b.sub(offset, b.urem(contigOff, swizzleStride));
+      offset = b.add(offset, b.mul(stridedInc, swizzleStride));
+      offset = b.add(offset, b.mul(b.udiv(contigInc, swizzleStride), tileSize));
+      offset = b.add(offset, b.urem(contigInc, swizzleStride));
+    } else {
+      // Compute the offset based on the original strides of the shared memory
+      // object
+      offset = dot(rewriter, loc, opOffsetVals, opSmemStrides);
+    }
+    auto base = smemObj.getBase();
+    auto elemPtrTy = base.getType();
+    smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),
+                                 llvmElemTy, offsetVals);
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
     return success();
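
The arithmetic in the non-simple branch is easier to follow on scalars. A minimal sketch of the rebasing formula from the comment above, with plain integers standing in for the LLVM IR values (the helper and its name are hypothetical, not part of the commit):

```cpp
#include <cstdint>

// Rebase a linear shared-memory offset when the subview lives inside a
// swizzled NVMMA tile. `offset` starts as the batch-dimension contribution;
// the old (stridedOff/contigOff) terms are peeled off and the subview's
// (stridedInc/contigInc) terms are added back, mirroring the commented
// formula in the hunk.
static int64_t rebaseSwizzledOffset(int64_t offset, int64_t stridedOff,
                                    int64_t contigOff, int64_t stridedInc,
                                    int64_t contigInc, int64_t swizzleStride,
                                    int64_t tileSize) {
  offset -= stridedOff * swizzleStride;
  offset -= (contigOff / swizzleStride) * tileSize;
  offset -= contigOff % swizzleStride;
  offset += stridedInc * swizzleStride;
  offset += (contigInc / swizzleStride) * tileSize;
  offset += contigInc % swizzleStride;
  return offset;
}
```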

lib/Dialect/TritonGPU/IR/Dialect.cpp (+1, -1)

@@ -1203,7 +1203,7 @@ LinearEncodingAttr::orderPerDim(StringAttr dimName,
 // [Note. Divergence of methods wrt. legacy layouts]
 // For smaller shapes where the CTATile is larger than the output
 // tensor, some methods return different values than the legacy layouts. I think
-// this is benign tho. An example: what is the the vector of `warpsPerCTA` if
+// this is benign tho. An example: what is the vector of `warpsPerCTA` if
 // all the warps hold the same data? I think it should be [1, 1], even if we
 // have 4 warps. But perhaps for this we have to add some masking in some
 // places... We'll see

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp (+69, -26)

@@ -1,22 +1,16 @@
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/Pass/PassManager.h"
-#include "mlir/Transforms/Passes.h"
-#include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Types.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
-#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonGPU/Transforms/Utility.h"
 #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
 #include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
-#include "triton/Tools/Sys/GetEnv.hpp"
 #include "llvm/ADT/PriorityWorklist.h"
-#include "llvm/ADT/Sequence.h"
-#include "llvm/Support/Casting.h"
-#include "llvm/Support/VersionTuple.h"
+#include <algorithm>
 #include <memory>
 #include <unordered_set>

@@ -35,6 +29,7 @@ struct UseInfo {
   TypedValue<tt::TensorDescType> descriptor;
   Operation *use;
   Attribute desiredSharedEncoding;
+  SmallVector<int64_t> shape;
   ttg::CTALayoutAttr ctaLayout;
 };

@@ -72,6 +67,14 @@ ttg::CTALayoutAttr getCtaLayoutFromEncoding(Attribute encoding) {
                                layout.getCTASplitNum(), layout.getCTAOrder());
 }

+SmallVector<int64_t> expandToRank(ArrayRef<int64_t> shape, int rank) {
+  SmallVector<int64_t> result(rank, 1);
+  assert(shape.size() <= rank);
+  auto rankDiff = rank - shape.size();
+  std::copy(shape.begin(), shape.end(), result.begin() + rankDiff);
+  return result;
+}
+
 std::optional<UseInfo> getUseInfo(Operation *op) {
   UseInfo info;
   info.use = op;
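
`expandToRank` right-aligns a use's shape against the descriptor rank, padding leading dimensions with 1. A standalone restatement (hypothetical test harness, `std::vector` in place of `SmallVector`):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Right-align `shape` into a rank-sized vector, filling the missing
// leading dimensions with 1; behavior matches the helper above.
static std::vector<int64_t> expandToRank(const std::vector<int64_t> &shape,
                                         size_t rank) {
  assert(shape.size() <= rank);
  std::vector<int64_t> result(rank, 1);
  std::copy(shape.begin(), shape.end(),
            result.begin() + (rank - shape.size()));
  return result;
}

int main() {
  // A 2-D tile accessed through a 3-D descriptor: {64, 32} -> {1, 64, 32}.
  assert(expandToRank({64, 32}, 3) == std::vector<int64_t>({1, 64, 32}));
}
```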
@@ -81,6 +84,9 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
     auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
                                                : load.getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = load.getResult().getType().getShape();
+    auto rank = load.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto gather = dyn_cast<tt::DescriptorGatherOp>(op)) {
@@ -89,18 +95,27 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
     auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
                                                : gather.getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = gather.getResult().getType().getShape();
+    auto rank = gather.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto store = dyn_cast<tt::DescriptorStoreOp>(op)) {
     info.descriptor = store.getDesc();
     auto encoding = store.getSrc().getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = store.getSrc().getType().getShape();
+    auto rank = store.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   if (auto scatter = dyn_cast<tt::DescriptorScatterOp>(op)) {
     info.descriptor = scatter.getDesc();
     auto encoding = scatter.getSrc().getType().getEncoding();
     info.ctaLayout = ttg::getCTALayout(encoding);
+    auto shape = scatter.getSrc().getType().getShape();
+    auto rank = scatter.getDesc().getType().getBlockType().getRank();
+    info.shape = expandToRank(shape, rank);
     return info;
   }
   return std::nullopt;
@@ -109,12 +124,15 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
 struct EncodingInfo {
   Attribute desiredEncoding;
   ttg::CTALayoutAttr ctaLayout;
+  // Shape may be different from the descriptor block shape for gather/scatter
+  // use case
+  SmallVector<int64_t> shape;
   bool forcedToDefault = false;

   bool operator==(const EncodingInfo &other) const {
     return desiredEncoding == other.desiredEncoding &&
            ctaLayout == other.ctaLayout &&
-           forcedToDefault == other.forcedToDefault;
+           forcedToDefault == other.forcedToDefault && shape == other.shape;
   }
 };

@@ -123,7 +141,8 @@ struct EncodingInfo {
 template <> struct std::hash<EncodingInfo> {
   size_t operator()(const EncodingInfo &einfo) const {
     return llvm::hash_combine(einfo.desiredEncoding, einfo.ctaLayout,
-                              einfo.forcedToDefault);
+                              einfo.forcedToDefault,
+                              ArrayRef<int64_t>(einfo.shape));
   }
 };

@@ -172,6 +191,21 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
   // Always propagate forcedToDefault
   result.forcedToDefault = lhs.forcedToDefault || rhs.forcedToDefault;

+  if (result.forcedToDefault)
+    return result;
+
+  if (lhs.shape.empty() || lhs.shape == rhs.shape)
+    result.shape = rhs.shape;
+  else if (rhs.shape.empty())
+    result.shape = lhs.shape;
+  else {
+    assert(lhs.shape.size() == rhs.shape.size());
+    auto rank = lhs.shape.size();
+    result.shape.reserve(rank);
+    for (int i = 0; i < rank; ++i)
+      result.shape.push_back(std::min(lhs.shape[i], rhs.shape[i]));
+  }
+
   SetVector<ttg::CTALayoutAttr> ctaLayouts;
   if (lhs.ctaLayout)
     ctaLayouts.insert(lhs.ctaLayout);
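
The shape merge is an elementwise minimum over non-empty shapes: two uses of the same descriptor with shapes {1, 128, 64} and {1, 64, 64} combine to {1, 64, 64}, and an empty shape simply defers to the other side, so the eventual fallback encoding is sized for the smallest tile any use actually touches.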
@@ -190,9 +224,6 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
     break;
   }

-  if (result.forcedToDefault)
-    return result;
-
   SetVector<Attribute> desiredEncodings;
   if (lhs.desiredEncoding)
     desiredEncodings.insert(lhs.desiredEncoding);
@@ -213,23 +244,32 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
 }

 Attribute getFallbackSharedEncoding(RankedTensorType tensorType,
-                                    ttg::CTALayoutAttr ctaLayout) {
+                                    ttg::CTALayoutAttr ctaLayout,
+                                    ArrayRef<int64_t> usageShape) {
   auto ctx = tensorType.getContext();
   SmallVector<unsigned> order;
   for (int i = tensorType.getRank() - 1; i >= 0; --i)
     order.push_back(i);

+  ArrayRef<int64_t> shape =
+      usageShape.empty() ? tensorType.getShape() : usageShape;
   if (!ctaLayout)
     ctaLayout = ttg::CTALayoutAttr::getDefault(ctx, tensorType.getRank());
   else if (ctaLayout.getRank() != tensorType.getRank())
-    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, tensorType.getShape());
+    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, shape);
+
+  auto elemTy = tensorType.getElementType();
+  auto shapePerCTA = ttg::getShapePerCTA(ctaLayout.getCTASplitNum(), shape);
+  unsigned eleBitWidth = tensorType.getElementType().getIntOrFloatBitWidth();

-  if (tensorType.getRank() == 1) {
+  auto contigDimSizeInBytes = shapePerCTA.back() * eleBitWidth / 8;
+  auto rank = tensorType.getRank();
+  if (rank == 1 || contigDimSizeInBytes < 32 || shape[rank - 2] < 8) {
     return ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, order, ctaLayout);
   }
-  return ttg::NVMMASharedEncodingAttr::get(
-      ctx, tensorType.getShape(), order, ctaLayout, tensorType.getElementType(),
-      /*fp4Padded*/ false);
+  return ttg::NVMMASharedEncodingAttr::get(ctx, shape, order, ctaLayout,
+                                           tensorType.getElementType(),
+                                           /*fp4Padded*/ false);
 }

 tt::TensorDescType getTensorDescTypeWithEncoding(Operation *op,
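
Worked example of the new fallback guard: for an fp16 descriptor used with shape {16, 8}, contigDimSizeInBytes = 8 * 16 / 8 = 16, which is below the 32-byte floor, so the pass now returns the trivial `SwizzledSharedEncodingAttr`; previously such a shape would reach `NVMMASharedEncodingAttr` construction and, presumably, the `llvm_unreachable` path retitled in the TritonGPUAttrDefs.td hunk above.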
@@ -274,17 +314,19 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   // fallback to default encoding
   for (auto blockArg : func.getBlocks().front().getArguments())
     if (auto desc = dyn_cast<TypedValue<tt::TensorDescType>>(blockArg))
-      updateEncoding({desc}, EncodingInfo{{}, {}, /*forcedToDefault=*/true});
+      updateEncoding({desc},
+                     EncodingInfo{{}, {}, {}, /*forcedToDefault=*/true});

   func.walk([&](Operation *op) {
     if (auto info = getUseInfo(op)) {
-      updateEncoding(info->descriptor, EncodingInfo{info->desiredSharedEncoding,
-                                                    info->ctaLayout});
+      updateEncoding(info->descriptor,
+                     EncodingInfo{info->desiredSharedEncoding, info->ctaLayout,
+                                  info->shape});
     } else {
       bool forcedToDefault =
           isa<tt::CallOp, tt::ReturnOp, tt::ReinterpretTensorDescOp>(op);
       auto einfo =
-          internEncoding(encodings, EncodingInfo{{}, {}, forcedToDefault});
+          internEncoding(encodings, EncodingInfo{{}, {}, {}, forcedToDefault});

       auto setEncoding = [&](Value v) {
         auto typedVal = cast<TypedValue<tt::TensorDescType>>(v);
@@ -344,9 +386,10 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   if (einfo->desiredEncoding) {
     newEncoding = einfo->desiredEncoding;
   } else if (einfo->forcedToDefault) {
-    newEncoding = getFallbackSharedEncoding(existingTy, {});
+    newEncoding = getFallbackSharedEncoding(existingTy, {}, {});
   } else {
-    newEncoding = getFallbackSharedEncoding(existingTy, einfo->ctaLayout);
+    newEncoding =
+        getFallbackSharedEncoding(existingTy, einfo->ctaLayout, einfo->shape);
   }
   desc.setType(getTensorDescTypeWithEncoding(desc.getDefiningOp(), existingTy,
                                              newEncoding));
@@ -356,14 +399,14 @@ void assignMemoryLayouts(tt::FuncOp &func) {
   SmallVector<Type> resultTys(func.getResultTypes());
   for (auto [i, argTy] : llvm::enumerate(argTys)) {
     if (auto descTy = dyn_cast<tt::TensorDescType>(argTy)) {
-      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {});
+      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {}, {});
       argTys[i] = getTensorDescTypeWithEncoding(nullptr, descTy.getBlockType(),
                                                 encoding);
     }
   }
   for (auto [i, resultTy] : llvm::enumerate(resultTys)) {
     if (auto descTy = dyn_cast<tt::TensorDescType>(resultTy)) {
-      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {});
+      auto encoding = getFallbackSharedEncoding(descTy.getBlockType(), {}, {});
       resultTys[i] = getTensorDescTypeWithEncoding(
           nullptr, descTy.getBlockType(), encoding);
     }
