Skip to content

Commit eb8c81a

Browse files
authored
Reimplement the blocking pass with backward dataflow analysis framework. (#848)
1 parent 7348112 commit eb8c81a

26 files changed

+2567
-44
lines changed

include/imex/Dialect/XeTile/IR/XeTileOps.td

+6-3
Original file line numberDiff line numberDiff line change
@@ -447,6 +447,9 @@ def XeTile_TileMMAOp : XeTile_Op<"tile_mma", []> {
447447
mlir::Type getElementType() {
448448
return getA().getType().getElementType();
449449
}
450+
mlir::VectorType getOutputType() {
451+
return getOutput().getType();
452+
}
450453
}];
451454

452455
let hasVerifier = 1;
@@ -581,7 +584,7 @@ def XeTile_TransposeOp: XeTile_Op<"transpose", []> {
581584
let hasVerifier = 1;
582585
}
583586

584-
def XeTile_ReduceOp: XeTile_Op<"reduce", []> {
587+
def XeTile_ReductionOp: XeTile_Op<"reduction", []> {
585588
let summary = "performs a reduction operation over a 2D vector.";
586589
let description = [{
587590
It has the same semantics as the `vector.multi_reduction`,
@@ -591,10 +594,10 @@ def XeTile_ReduceOp: XeTile_Op<"reduce", []> {
591594

592595
let arguments = (ins Vector_CombiningKindAttr: $kind,
593596
XeTile_2DOr4DVector: $source,
594-
DenseI64ArrayAttr: $reduction_dim);
597+
DenseI64ArrayAttr: $reduction_dims);
595598
let results = (outs XeTile_2DOr4DVector: $result);
596599
let assemblyFormat = [{
597-
$kind `,` $source $reduction_dim attr-dict `:` type($source) `->` type($result)
600+
$kind `,` $source $reduction_dims attr-dict `:` type($source) `->` type($result)
598601
}];
599602

600603
let hasVerifier = 1;

include/imex/Dialect/XeTile/Transforms/Passes.h

+2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ std::unique_ptr<mlir::Pass> createXeTileInitDuplicatePass();
4040

4141
std::unique_ptr<mlir::Pass>
4242
createXeTileBlockingPass(const std::string &device = "pvc");
43+
std::unique_ptr<mlir::Pass>
44+
createNewXeTileBlockingPass(const std::string &device = "pvc");
4345
std::unique_ptr<mlir::Pass> createXeTileBlockAligningPass();
4446
std::unique_ptr<mlir::Pass> createXeTileWgToSgPass();
4547
std::unique_ptr<mlir::Pass> createXeTileOptimizeTransposePass();

include/imex/Dialect/XeTile/Transforms/Passes.td

+26
Original file line numberDiff line numberDiff line change
@@ -130,5 +130,31 @@ def XeTileCanonicalization : Pass<"xetile-canonicalization", "::mlir::gpu::GPUMo
130130
];
131131
}
132132

133+
def NewXeTileBlocking : Pass<"new-xetile-blocking", "::mlir::gpu::GPUModuleOp">{
134+
let summary = "transform XeTile large tiles(input) into arrays of smaller "
135+
"blocks with appropriate size, such that the operator on each "
136+
"of the blocks can be mapped into one hardware instruction.";
137+
138+
let description = [{
139+
This transform pass preprocesses the xetile program by decomposing large XeTile tiles
140+
into smaller ones that can be handled by a hardware instruction. It is going to replace
141+
the xetile-blocking pass.
142+
}];
143+
144+
let constructor = "imex::createNewXeTileBlockingPass()";
145+
let dependentDialects = ["imex::xetile::XeTileDialect",
146+
"mlir::arith::ArithDialect",
147+
"mlir::math::MathDialect",
148+
"mlir::gpu::GPUDialect",
149+
"mlir::memref::MemRefDialect",
150+
"mlir::vector::VectorDialect"];
151+
152+
let options = [
153+
Option<"device", "device", "std::string",
154+
/*default=*/"\"pvc\"",
155+
"gpu platform architecture where these ops are running">
156+
];
157+
}
158+
133159

134160
#endif // _XeTile_PASSES_TD_INCLUDED_

lib/Conversion/XeTileToXeGPU/XeTileOpConversion.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -736,15 +736,16 @@ extern llvm::SmallVector<mlir::Value> lowerInnerReductionWithVectorReduction(
736736
mlir::vector::CombiningKind kind, mlir::Location loc, mlir::Type elemTy,
737737
XeOneToNPatternRewriter &rewriter);
738738

739-
struct SgTileReduceOpPattern : public XeOneToNConversion<xetile::ReduceOp> {
740-
using XeOneToNConversion<xetile::ReduceOp>::XeOneToNConversion;
739+
struct SgTileReductionOpPattern
740+
: public XeOneToNConversion<xetile::ReductionOp> {
741+
using XeOneToNConversion<xetile::ReductionOp>::XeOneToNConversion;
741742

742743
mlir::LogicalResult
743-
matchAndRewrite(xetile::ReduceOp op, OpAdaptor adaptor,
744+
matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor,
744745
XeOneToNPatternRewriter &rewriter) const override {
745746
auto srcTy = op.getSource().getType();
746747
auto elemTy = srcTy.getElementType();
747-
auto dims = op.getReductionDim();
748+
auto dims = op.getReductionDims();
748749
// its input should be a 4D vector, and has 2 reduction dims,
749750
// otherwise run the blocking pass first.
750751
if (dims.size() != 2 || srcTy.getRank() != 4)
@@ -1092,8 +1093,8 @@ void populateXeTileOpConversionPatterns(imex::XeOneToNTypeConverter &converter,
10921093
SgTileMMAOpPattern, SgUpdateTileOffsetOpPattern,
10931094
SgTransposeOpPattern<mlir::vector::TransposeOp>,
10941095
SgTransposeOpPattern<xetile::TransposeOp>, SgBroadcastOpPattern,
1095-
SgTileReduceOpPattern, SgVectorCreateMaskOpPattern>(patterns.getContext(),
1096-
converter, analysis);
1096+
SgTileReductionOpPattern, SgVectorCreateMaskOpPattern>(
1097+
patterns.getContext(), converter, analysis);
10971098
patterns.insert<ElementWiseOpPattern<mlir::arith::NegFOp, 1>,
10981099
ElementWiseOpPattern<mlir::math::ExpOp, 1>,
10991100
ElementWiseOpPattern<mlir::math::SinOp, 1>,

lib/Dialect/XeTile/IR/XeTileOps.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -859,8 +859,8 @@ mlir::LogicalResult TransposeOp::verify() {
859859
return mlir::success();
860860
}
861861

862-
mlir::LogicalResult ReduceOp::verify() {
863-
auto dims = getReductionDim();
862+
mlir::LogicalResult ReductionOp::verify() {
863+
auto dims = getReductionDims();
864864
auto resShape = getResult().getType().getShape();
865865
for (auto i : dims)
866866
if (resShape[i] != 1)

lib/Dialect/XeTile/Transforms/Blocking.cpp

+12-11
Original file line numberDiff line numberDiff line change
@@ -556,29 +556,30 @@ struct VectorMultiDimReductionOpPattern
556556
}
557557
};
558558

559-
struct TileReduceOpPattern
560-
: public XeTileConversion<xetile::ReduceOp, TileUsageAnalysis> {
559+
struct TileReductionOpPattern
560+
: public XeTileConversion<xetile::ReductionOp, TileUsageAnalysis> {
561561

562-
using XeTileConversion<xetile::ReduceOp, TileUsageAnalysis>::XeTileConversion;
562+
using XeTileConversion<xetile::ReductionOp,
563+
TileUsageAnalysis>::XeTileConversion;
563564

564-
TileReduceOpPattern(mlir::MLIRContext *context,
565-
imex::XeTypeConverter &converter,
566-
TileUsageAnalysis &analysis,
567-
std::shared_ptr<XeuArchInterface> ptruArch)
565+
TileReductionOpPattern(mlir::MLIRContext *context,
566+
imex::XeTypeConverter &converter,
567+
TileUsageAnalysis &analysis,
568+
std::shared_ptr<XeuArchInterface> ptruArch)
568569
: XeTileConversion(context, converter, analysis) {
569570
this->uArchInterface = ptruArch;
570571
}
571572

572573
std::shared_ptr<XeuArchInterface> uArchInterface = nullptr;
573574

574575
mlir::LogicalResult
575-
matchAndRewrite(xetile::ReduceOp op, OpAdaptor adaptor,
576+
matchAndRewrite(xetile::ReductionOp op, OpAdaptor adaptor,
576577
OpPatternRewriter &rewriter) const override {
577578
auto loc = op.getLoc();
578579
auto srcTy = op.getSource().getType();
579580
auto elemTy = srcTy.getElementType();
580581
auto shape = srcTy.getShape();
581-
auto reductionDims = op.getReductionDim();
582+
auto reductionDims = op.getReductionDims();
582583

583584
if (srcTy.getRank() != 2 || reductionDims.size() != 1)
584585
return rewriter.notifyMatchFailure(
@@ -611,7 +612,7 @@ struct TileReduceOpPattern
611612

612613
auto newSource =
613614
addPackOp(adaptor.getSource(), {blkSizes[0], blkSizes[1]}, rewriter);
614-
auto newDest = rewriter.create<xetile::ReduceOp>(
615+
auto newDest = rewriter.create<xetile::ReductionOp>(
615616
loc, newDestType, op.getKind(), newSource, newReductionDims);
616617
auto unpack = addUnpackOp(newDest.getResult(), rewriter);
617618
rewriter.replaceOp(op, unpack);
@@ -1161,7 +1162,7 @@ void populateXeTileBlockingPatterns(
11611162
VectorizableOpPattern, SCFForOpPattern, SCFYieldOpPattern,
11621163
InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern,
11631164
TileMMAOpPattern, UpdateTileOffsetOpPattern,
1164-
VectorMultiDimReductionOpPattern, TileReduceOpPattern,
1165+
VectorMultiDimReductionOpPattern, TileReductionOpPattern,
11651166
TileBroadcastOpPattern>(patterns.getContext(), converter,
11661167
analysis, ptruArch);
11671168
patterns.insert<TransposeOpPattern<mlir::vector::TransposeOp>,

0 commit comments

Comments
 (0)