1
1
#include " mlir/IR/TypeUtilities.h"
2
2
#include " mlir/Pass/PassManager.h"
3
- #include " mlir/Transforms/Passes.h"
4
- #include " triton/Analysis/AxisInfo.h"
5
3
#include " triton/Dialect/Triton/IR/Dialect.h"
6
4
#include " triton/Dialect/Triton/IR/Types.h"
7
5
#include " triton/Dialect/TritonGPU/IR/Attributes.h"
8
6
#include " triton/Dialect/TritonGPU/IR/Dialect.h"
9
7
#include " triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
10
- #include " triton/Dialect/TritonGPU/Transforms/Passes.h"
11
8
#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
12
9
#include " triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
13
10
#include " triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
14
11
#include " triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
15
- #include " triton/Tools/Sys/GetEnv.hpp"
16
12
#include " llvm/ADT/PriorityWorklist.h"
17
- #include " llvm/ADT/Sequence.h"
18
- #include " llvm/Support/Casting.h"
19
- #include " llvm/Support/VersionTuple.h"
13
+ #include < algorithm>
20
14
#include < memory>
21
15
#include < unordered_set>
22
16
@@ -35,6 +29,7 @@ struct UseInfo {
35
29
TypedValue<tt::TensorDescType> descriptor;
36
30
Operation *use;
37
31
Attribute desiredSharedEncoding;
32
+ SmallVector<int64_t > shape;
38
33
ttg::CTALayoutAttr ctaLayout;
39
34
};
40
35
@@ -72,6 +67,14 @@ ttg::CTALayoutAttr getCtaLayoutFromEncoding(Attribute encoding) {
72
67
layout.getCTASplitNum (), layout.getCTAOrder ());
73
68
}
74
69
70
+ SmallVector<int64_t > expandToRank (ArrayRef<int64_t > shape, int rank) {
71
+ SmallVector<int64_t > result (rank, 1 );
72
+ assert (shape.size () <= rank);
73
+ auto rankDiff = rank - shape.size ();
74
+ std::copy (shape.begin (), shape.end (), result.begin () + rankDiff);
75
+ return result;
76
+ }
77
+
75
78
std::optional<UseInfo> getUseInfo (Operation *op) {
76
79
UseInfo info;
77
80
info.use = op;
@@ -81,6 +84,9 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
81
84
auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
82
85
: load.getType ().getEncoding ();
83
86
info.ctaLayout = ttg::getCTALayout (encoding);
87
+ auto shape = load.getResult ().getType ().getShape ();
88
+ auto rank = load.getDesc ().getType ().getBlockType ().getRank ();
89
+ info.shape = expandToRank (shape, rank);
84
90
return info;
85
91
}
86
92
if (auto gather = dyn_cast<tt::DescriptorGatherOp>(op)) {
@@ -89,18 +95,27 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
89
95
auto encoding = info.desiredSharedEncoding ? info.desiredSharedEncoding
90
96
: gather.getType ().getEncoding ();
91
97
info.ctaLayout = ttg::getCTALayout (encoding);
98
+ auto shape = gather.getResult ().getType ().getShape ();
99
+ auto rank = gather.getDesc ().getType ().getBlockType ().getRank ();
100
+ info.shape = expandToRank (shape, rank);
92
101
return info;
93
102
}
94
103
if (auto store = dyn_cast<tt::DescriptorStoreOp>(op)) {
95
104
info.descriptor = store.getDesc ();
96
105
auto encoding = store.getSrc ().getType ().getEncoding ();
97
106
info.ctaLayout = ttg::getCTALayout (encoding);
107
+ auto shape = store.getSrc ().getType ().getShape ();
108
+ auto rank = store.getDesc ().getType ().getBlockType ().getRank ();
109
+ info.shape = expandToRank (shape, rank);
98
110
return info;
99
111
}
100
112
if (auto scatter = dyn_cast<tt::DescriptorScatterOp>(op)) {
101
113
info.descriptor = scatter.getDesc ();
102
114
auto encoding = scatter.getSrc ().getType ().getEncoding ();
103
115
info.ctaLayout = ttg::getCTALayout (encoding);
116
+ auto shape = scatter.getSrc ().getType ().getShape ();
117
+ auto rank = scatter.getDesc ().getType ().getBlockType ().getRank ();
118
+ info.shape = expandToRank (shape, rank);
104
119
return info;
105
120
}
106
121
return std::nullopt;
@@ -109,12 +124,15 @@ std::optional<UseInfo> getUseInfo(Operation *op) {
109
124
struct EncodingInfo {
110
125
Attribute desiredEncoding;
111
126
ttg::CTALayoutAttr ctaLayout;
127
+ // Shape may be different from the descriptor block shape for gather/scatter
128
+ // use case
129
+ SmallVector<int64_t > shape;
112
130
bool forcedToDefault = false ;
113
131
114
132
bool operator ==(const EncodingInfo &other) const {
115
133
return desiredEncoding == other.desiredEncoding &&
116
134
ctaLayout == other.ctaLayout &&
117
- forcedToDefault == other.forcedToDefault ;
135
+ forcedToDefault == other.forcedToDefault && shape == other. shape ;
118
136
}
119
137
};
120
138
@@ -123,7 +141,8 @@ struct EncodingInfo {
123
141
template <> struct std ::hash<EncodingInfo> {
124
142
size_t operator ()(const EncodingInfo &einfo) const {
125
143
return llvm::hash_combine (einfo.desiredEncoding , einfo.ctaLayout ,
126
- einfo.forcedToDefault );
144
+ einfo.forcedToDefault ,
145
+ ArrayRef<int64_t >(einfo.shape ));
127
146
}
128
147
};
129
148
@@ -172,6 +191,21 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
172
191
// Always propagate forcedToDefault
173
192
result.forcedToDefault = lhs.forcedToDefault || rhs.forcedToDefault ;
174
193
194
+ if (result.forcedToDefault )
195
+ return result;
196
+
197
+ if (lhs.shape .empty () || lhs.shape == rhs.shape )
198
+ result.shape = rhs.shape ;
199
+ else if (rhs.shape .empty ())
200
+ result.shape = lhs.shape ;
201
+ else {
202
+ assert (lhs.shape .size () == rhs.shape .size ());
203
+ auto rank = lhs.shape .size ();
204
+ result.shape .reserve (rank);
205
+ for (int i = 0 ; i < rank; ++i)
206
+ result.shape .push_back (std::min (lhs.shape [i], rhs.shape [i]));
207
+ }
208
+
175
209
SetVector<ttg::CTALayoutAttr> ctaLayouts;
176
210
if (lhs.ctaLayout )
177
211
ctaLayouts.insert (lhs.ctaLayout );
@@ -190,9 +224,6 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
190
224
break ;
191
225
}
192
226
193
- if (result.forcedToDefault )
194
- return result;
195
-
196
227
SetVector<Attribute> desiredEncodings;
197
228
if (lhs.desiredEncoding )
198
229
desiredEncodings.insert (lhs.desiredEncoding );
@@ -213,23 +244,32 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
213
244
}
214
245
215
246
// Picks a shared-memory encoding for a descriptor that has no explicit
// desired encoding.
//
// \param tensorType  block type of the descriptor.
// \param ctaLayout   CTA layout collected from uses; may be null, in which
//                    case the default layout for the tensor rank is used.
// \param usageShape  shape observed at the uses; when empty, falls back to
//                    the tensor type's own shape (it can differ for the
//                    gather/scatter use case).
Attribute getFallbackSharedEncoding(RankedTensorType tensorType,
                                    ttg::CTALayoutAttr ctaLayout,
                                    ArrayRef<int64_t> usageShape) {
  auto ctx = tensorType.getContext();
  // Row-major order: fastest-varying dimension last.
  SmallVector<unsigned> order;
  for (int i = tensorType.getRank() - 1; i >= 0; --i)
    order.push_back(i);

  ArrayRef<int64_t> shape =
      usageShape.empty() ? tensorType.getShape() : usageShape;
  if (!ctaLayout)
    ctaLayout = ttg::CTALayoutAttr::getDefault(ctx, tensorType.getRank());
  else if (ctaLayout.getRank() != tensorType.getRank())
    ctaLayout = ttng::updateCTALayoutForShape(ctaLayout, shape);

  auto elemTy = tensorType.getElementType();
  auto shapePerCTA = ttg::getShapePerCTA(ctaLayout.getCTASplitNum(), shape);
  // Reuse elemTy instead of re-querying the element type (was fetched twice,
  // leaving the local unused).
  unsigned eleBitWidth = elemTy.getIntOrFloatBitWidth();

  auto contigDimSizeInBytes = shapePerCTA.back() * eleBitWidth / 8;
  auto rank = tensorType.getRank();
  // NVMMA shared layouts need a sufficiently large contiguous dimension and
  // at least 8 rows; otherwise fall back to a trivial swizzled layout.
  // The rank == 1 check also guards the shape[rank - 2] access below.
  if (rank == 1 || contigDimSizeInBytes < 32 || shape[rank - 2] < 8) {
    return ttg::SwizzledSharedEncodingAttr::get(ctx, 1, 1, 1, order, ctaLayout);
  }
  return ttg::NVMMASharedEncodingAttr::get(ctx, shape, order, ctaLayout,
                                           elemTy,
                                           /*fp4Padded*/ false);
}
234
274
235
275
tt::TensorDescType getTensorDescTypeWithEncoding (Operation *op,
@@ -274,17 +314,19 @@ void assignMemoryLayouts(tt::FuncOp &func) {
274
314
// fallback to default encoding
275
315
for (auto blockArg : func.getBlocks ().front ().getArguments ())
276
316
if (auto desc = dyn_cast<TypedValue<tt::TensorDescType>>(blockArg))
277
- updateEncoding ({desc}, EncodingInfo{{}, {}, /* forcedToDefault=*/ true });
317
+ updateEncoding ({desc},
318
+ EncodingInfo{{}, {}, {}, /* forcedToDefault=*/ true });
278
319
279
320
func.walk ([&](Operation *op) {
280
321
if (auto info = getUseInfo (op)) {
281
- updateEncoding (info->descriptor , EncodingInfo{info->desiredSharedEncoding ,
282
- info->ctaLayout });
322
+ updateEncoding (info->descriptor ,
323
+ EncodingInfo{info->desiredSharedEncoding , info->ctaLayout ,
324
+ info->shape });
283
325
} else {
284
326
bool forcedToDefault =
285
327
isa<tt::CallOp, tt::ReturnOp, tt::ReinterpretTensorDescOp>(op);
286
328
auto einfo =
287
- internEncoding (encodings, EncodingInfo{{}, {}, forcedToDefault});
329
+ internEncoding (encodings, EncodingInfo{{}, {}, {}, forcedToDefault});
288
330
289
331
auto setEncoding = [&](Value v) {
290
332
auto typedVal = cast<TypedValue<tt::TensorDescType>>(v);
@@ -344,9 +386,10 @@ void assignMemoryLayouts(tt::FuncOp &func) {
344
386
if (einfo->desiredEncoding ) {
345
387
newEncoding = einfo->desiredEncoding ;
346
388
} else if (einfo->forcedToDefault ) {
347
- newEncoding = getFallbackSharedEncoding (existingTy, {});
389
+ newEncoding = getFallbackSharedEncoding (existingTy, {}, {} );
348
390
} else {
349
- newEncoding = getFallbackSharedEncoding (existingTy, einfo->ctaLayout );
391
+ newEncoding =
392
+ getFallbackSharedEncoding (existingTy, einfo->ctaLayout , einfo->shape );
350
393
}
351
394
desc.setType (getTensorDescTypeWithEncoding (desc.getDefiningOp (), existingTy,
352
395
newEncoding));
@@ -356,14 +399,14 @@ void assignMemoryLayouts(tt::FuncOp &func) {
356
399
SmallVector<Type> resultTys (func.getResultTypes ());
357
400
for (auto [i, argTy] : llvm::enumerate (argTys)) {
358
401
if (auto descTy = dyn_cast<tt::TensorDescType>(argTy)) {
359
- auto encoding = getFallbackSharedEncoding (descTy.getBlockType (), {});
402
+ auto encoding = getFallbackSharedEncoding (descTy.getBlockType (), {}, {} );
360
403
argTys[i] = getTensorDescTypeWithEncoding (nullptr , descTy.getBlockType (),
361
404
encoding);
362
405
}
363
406
}
364
407
for (auto [i, resultTy] : llvm::enumerate (resultTys)) {
365
408
if (auto descTy = dyn_cast<tt::TensorDescType>(resultTy)) {
366
- auto encoding = getFallbackSharedEncoding (descTy.getBlockType (), {});
409
+ auto encoding = getFallbackSharedEncoding (descTy.getBlockType (), {}, {} );
367
410
resultTys[i] = getTensorDescTypeWithEncoding (
368
411
nullptr , descTy.getBlockType (), encoding);
369
412
}
0 commit comments