Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions lib/Conversion/TorchToLinalg/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,64 @@ using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
// Emits an inclusive (Hillis-Steele) prefix sum over dimension 1 of a rank-3
// tensor shaped [outer, scan, inner]. The shift doubles every stage, so the
// number of emitted add stages is logarithmic in the scan length; callers
// reshape arbitrary-rank aten.cumsum inputs into this canonical form first.
static Value emitInclusiveScanByPowersOfTwo(Value running,
                                            ConversionPatternRewriter &rewriter,
                                            Location loc) {
  auto accType = cast<RankedTensorType>(running.getType());
  SmallVector<int64_t> dims = makeShapeTorchCompatible(accType.getShape());
  int64_t outerSize = dims[0];
  int64_t scanSize = dims[1];
  int64_t innerSize = dims[2];
  Type elemTy = accType.getElementType();

  // Neutral element used to pad the shifted copy of the accumulator.
  Value padValue = arith::ConstantOp::create(rewriter, loc,
                                             rewriter.getZeroAttr(elemTy));

  // Slice parameters that recover the original extent from a padded tensor;
  // they are loop-invariant, so build them once up front.
  SmallVector<OpFoldResult> zeroOffsets(3, rewriter.getIndexAttr(0));
  SmallVector<OpFoldResult> fullSizes = {rewriter.getIndexAttr(outerSize),
                                         rewriter.getIndexAttr(scanSize),
                                         rewriter.getIndexAttr(innerSize)};
  SmallVector<OpFoldResult> unitStrides(3, rewriter.getIndexAttr(1));

  for (int64_t shift = 1; shift < scanSize; shift *= 2) {
    // Pad the scan dimension at the front by `shift`, then slice the leading
    // [outer, scan, inner] window: element i of the slice holds running[i -
    // shift], or zero when i < shift.
    SmallVector<int64_t> staticLow = {0, shift, 0};
    SmallVector<int64_t> staticHigh = {0, 0, 0};
    Type paddedType =
        tensor::PadOp::inferResultType(accType, staticLow, staticHigh);
    SmallVector<OpFoldResult> lowValues = {rewriter.getIndexAttr(0),
                                           rewriter.getIndexAttr(shift),
                                           rewriter.getIndexAttr(0)};
    SmallVector<OpFoldResult> highValues(3, rewriter.getIndexAttr(0));
    Value padded = tensor::PadOp::create(rewriter, loc, paddedType, running,
                                         lowValues, highValues, padValue);

    Value lagged =
        tensor::ExtractSliceOp::create(rewriter, loc, accType, padded,
                                       zeroOffsets, fullSizes, unitStrides);

    // running[i] += running[i - shift] for every position with a predecessor.
    running = torch_to_linalg::createElementwiseLinalgGeneric(
        rewriter, loc, {running, lagged}, elemTy,
        [&](OpBuilder &builder, Location loc, ValueRange payloadArgs) {
          Value sum;
          if (isa<mlir::FloatType>(elemTy))
            sum = arith::AddFOp::create(builder, loc, payloadArgs[0],
                                        payloadArgs[1]);
          else if (isa<mlir::IntegerType>(elemTy))
            sum = arith::AddIOp::create(builder, loc, payloadArgs[0],
                                        payloadArgs[1]);
          else
            llvm_unreachable("unsupported cumsum element type");
          linalg::YieldOp::create(builder, loc, sum);
        });
  }

  return running;
}

// Aten max.dim (min.dim) lowering represents the MaxDimOp (MinDimOp) as a
// linalg.indexed_generic op, producing two output buffers.
//
Expand Down Expand Up @@ -812,6 +870,89 @@ class ConvertReductionOp : public ConversionPattern {
};
} // namespace

namespace {
// Lowers aten.cumsum to linalg/tensor ops. The input is reshaped into a
// canonical rank-3 [outer, scan, inner] tensor, an O(log n) inclusive scan is
// run along the middle dimension, and the result is reshaped back to the
// original extents. Only statically shaped tensors and a constant `dim`
// operand are supported.
class ConvertAtenCumsumOp : public OpConversionPattern<AtenCumsumOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AtenCumsumOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = adaptor.getSelf();
    auto selfType = dyn_cast<RankedTensorType>(self.getType());
    if (!selfType || !selfType.hasStaticShape())
      return rewriter.notifyMatchFailure(op,
                                         "only static tensor shapes supported");

    int64_t dim;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
      return rewriter.notifyMatchFailure(op, "dim must be constant");
    dim = toPositiveDim(dim, selfType.getRank());
    if (!isValidDim(dim, selfType.getRank()))
      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

    auto resultType = dyn_cast<RankedTensorType>(
        getTypeConverter()->convertType(op.getType()));
    if (!resultType || !resultType.hasStaticShape())
      return rewriter.notifyMatchFailure(op, "expected static ranked result");

    Type resultElementType = resultType.getElementType();
    if (!isa<mlir::FloatType, mlir::IntegerType>(resultElementType))
      return rewriter.notifyMatchFailure(
          op, "only floating point and integer element types supported");

    // The converted result type already reflects any explicit `dtype`
    // operand, so converting the input to it also honors that operand.
    if (selfType.getElementType() != resultElementType)
      self = torch_to_linalg::convertTensorToElementType(rewriter, loc, self,
                                                         resultElementType);

    // Materializes a 1-D i64 shape tensor for `shape` and reshapes `source`
    // to `type`. Shared by the rank-3 canonicalization and the final restore
    // below so the two sites cannot drift apart.
    auto reshapeTo = [&](Value source, RankedTensorType type,
                         ArrayRef<int64_t> shape) -> Value {
      auto shapeTensorType = RankedTensorType::get(
          {static_cast<int64_t>(shape.size())}, rewriter.getI64Type());
      SmallVector<Value> shapeValues;
      for (int64_t size : shape)
        shapeValues.push_back(arith::ConstantOp::create(
            rewriter, loc, rewriter.getI64IntegerAttr(size)));
      Value shapeTensor = tensor::FromElementsOp::create(
          rewriter, loc, shapeTensorType, shapeValues);
      return tensor::ReshapeOp::create(rewriter, loc, type, source,
                                       shapeTensor)
          .getResult();
    };

    SmallVector<int64_t> inputShape =
        makeShapeTorchCompatible(selfType.getShape());
    int64_t scanDimSize = inputShape[dim];

    // Collapse every dimension before `dim` into `outer` and every dimension
    // after it into `inner`; the scan then always runs over axis 1.
    int64_t outer = 1;
    for (int64_t i = 0; i < dim; ++i)
      outer *= inputShape[i];
    int64_t inner = 1;
    for (int64_t i = dim + 1, e = inputShape.size(); i < e; ++i)
      inner *= inputShape[i];

    SmallVector<int64_t> scanShape = {outer, scanDimSize, inner};
    auto scanType = RankedTensorType::get(makeShapeLLVMCompatible(scanShape),
                                          resultElementType);
    Value running = reshapeTo(self, scanType, scanShape);

    running = emitInclusiveScanByPowersOfTwo(running, rewriter, loc);

    Value result = reshapeTo(running, resultType, resultType.getShape());

    rewriter.replaceOp(op, result);
    return success();
  }
};
} // namespace

void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, bool allowNonFinites) {
Expand All @@ -837,5 +978,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality(
target.addIllegalOp<AtenNormScalarOp>();
target.addIllegalOp<AtenLinalgVectorNormOp>();
target.addIllegalOp<AtenFrobeniusNormDimOp>();
target.addIllegalOp<AtenCumsumOp>();
patterns.add<ConvertReductionOp>(typeConverter, context, allowNonFinites);
patterns.add<ConvertAtenCumsumOp>(typeConverter, context);
}
19 changes: 19 additions & 0 deletions test/Conversion/TorchToLinalg/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,25 @@ func.func @torch.ops.aten.anydim$basic(%arg0: tensor<1x16x26x26xi1>) -> !torch.v

// -----

// Checks the aten.cumsum lowering end-to-end when the result feeds a builtin
// tensor user: the pattern should reshape to rank-3, emit the pad / slice /
// elementwise-add scan stages, reshape back, and leave no torch.aten.cumsum
// behind after conversion.
// CHECK-LABEL: func.func @torch.aten.cumsum$to_builtin_user
// CHECK: torch_c.to_builtin_tensor
// CHECK: tensor.reshape
// CHECK: tensor.pad
// CHECK: tensor.extract_slice
// CHECK: linalg.generic
// CHECK: arith.addf
// CHECK: tensor.reshape
// CHECK-NOT: torch.aten.cumsum
func.func @torch.aten.cumsum$to_builtin_user(%arg0: !torch.vtensor<[2,3],f32>) -> tensor<2x3xf32> {
  // dim = 1 scans along the second (size-3) axis; dtype = none keeps f32.
  %dim = torch.constant.int 1
  %none = torch.constant.none
  %0 = torch.aten.cumsum %arg0, %dim, %none : !torch.vtensor<[2,3],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,3],f32>
  %1 = torch_c.to_builtin_tensor %0 : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32>
  return %1 : tensor<2x3xf32>
}

// -----

// Per PyTorch docs, torch.cat allows "a 1-D empty tensor with size (0,)"
// alongside operands of any rank. The linalg lowering must skip these.
// CHECK-LABEL: func.func @torch.aten.cat$rank1_empty
Expand Down
Loading