@@ -556,29 +556,30 @@ struct VectorMultiDimReductionOpPattern
556
556
}
557
557
};
558
558
559
- struct TileReduceOpPattern
560
- : public XeTileConversion<xetile::ReduceOp , TileUsageAnalysis> {
559
+ struct TileReductionOpPattern
560
+ : public XeTileConversion<xetile::ReductionOp , TileUsageAnalysis> {
561
561
562
- using XeTileConversion<xetile::ReduceOp, TileUsageAnalysis>::XeTileConversion;
562
+ using XeTileConversion<xetile::ReductionOp,
563
+ TileUsageAnalysis>::XeTileConversion;
563
564
564
- TileReduceOpPattern (mlir::MLIRContext *context,
565
- imex::XeTypeConverter &converter,
566
- TileUsageAnalysis &analysis,
567
- std::shared_ptr<XeuArchInterface> ptruArch)
565
+ TileReductionOpPattern (mlir::MLIRContext *context,
566
+ imex::XeTypeConverter &converter,
567
+ TileUsageAnalysis &analysis,
568
+ std::shared_ptr<XeuArchInterface> ptruArch)
568
569
: XeTileConversion(context, converter, analysis) {
569
570
this ->uArchInterface = ptruArch;
570
571
}
571
572
572
573
std::shared_ptr<XeuArchInterface> uArchInterface = nullptr ;
573
574
574
575
mlir::LogicalResult
575
- matchAndRewrite (xetile::ReduceOp op, OpAdaptor adaptor,
576
+ matchAndRewrite (xetile::ReductionOp op, OpAdaptor adaptor,
576
577
OpPatternRewriter &rewriter) const override {
577
578
auto loc = op.getLoc ();
578
579
auto srcTy = op.getSource ().getType ();
579
580
auto elemTy = srcTy.getElementType ();
580
581
auto shape = srcTy.getShape ();
581
- auto reductionDims = op.getReductionDim ();
582
+ auto reductionDims = op.getReductionDims ();
582
583
583
584
if (srcTy.getRank () != 2 || reductionDims.size () != 1 )
584
585
return rewriter.notifyMatchFailure (
@@ -611,7 +612,7 @@ struct TileReduceOpPattern
611
612
612
613
auto newSource =
613
614
addPackOp (adaptor.getSource (), {blkSizes[0 ], blkSizes[1 ]}, rewriter);
614
- auto newDest = rewriter.create <xetile::ReduceOp >(
615
+ auto newDest = rewriter.create <xetile::ReductionOp >(
615
616
loc, newDestType, op.getKind (), newSource, newReductionDims);
616
617
auto unpack = addUnpackOp (newDest.getResult (), rewriter);
617
618
rewriter.replaceOp (op, unpack);
@@ -1161,7 +1162,7 @@ void populateXeTileBlockingPatterns(
1161
1162
VectorizableOpPattern, SCFForOpPattern, SCFYieldOpPattern,
1162
1163
InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern,
1163
1164
TileMMAOpPattern, UpdateTileOffsetOpPattern,
1164
- VectorMultiDimReductionOpPattern, TileReduceOpPattern ,
1165
+ VectorMultiDimReductionOpPattern, TileReductionOpPattern ,
1165
1166
TileBroadcastOpPattern>(patterns.getContext (), converter,
1166
1167
analysis, ptruArch);
1167
1168
patterns.insert <TransposeOpPattern<mlir::vector::TransposeOp>,
0 commit comments