Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6623,6 +6623,37 @@ def Torch_AtenMmOp : Torch_Op<"aten.mm", [
}];
}

def Torch_Aten_ScaledMmOp : Torch_Op<"aten._scaled_mm", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::_scaled_mm : (Tensor, Tensor, Tensor, Tensor, Tensor?, Tensor?, int?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$mat2,
AnyTorchTensorType:$scale_a,
AnyTorchTensorType:$scale_b,
AnyTorchOptionalTensorType:$bias,
AnyTorchOptionalTensorType:$scale_result,
AnyTorchOptionalIntType:$out_dtype,
Torch_BoolType:$use_fast_accum
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten_ScaledMmOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 8, 1);
}
void Aten_ScaledMmOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 8, 1);
}
}];
let hasVerifier = 1;
}

def Torch_Aten_IntMmOp : Torch_Op<"aten._int_mm", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
319 changes: 319 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MathExtras.h"
#include <array>
Expand Down Expand Up @@ -106,6 +107,106 @@ static SmallVector<int64_t> permuteShape(ArrayRef<int64_t> originalShape,
return result;
}

static bool isSupportedScaledMmDataElementType(Type type) {
return isa<Float8E4M3FNType, Float8E5M2Type>(type);
}

static bool isSupportedStaticScaledMmScaleElementType(Type type) {
return type.isF32();
}

static constexpr int64_t kScaledMmMatmulNAlignment = 16;

static SmallVector<int64_t> getTensorShape(RankedTensorType tensorTy) {
return SmallVector<int64_t>(tensorTy.getShape().begin(),
tensorTy.getShape().end());
}

static Value reshapeTensor(Value input, ArrayRef<int64_t> shape, Type elementTy,
ConversionPatternRewriter &rewriter, Location loc) {
auto resultTy =
RankedTensorType::get(makeShapeLLVMCompatible(shape), elementTy);
return tosa::ReshapeOp::create(rewriter, loc, resultTy, input,
tosa::getTosaConstShape(rewriter, loc, shape))
.getResult();
}

static FailureOr<Value> padLastDimWithZeros(Value input, int64_t paddedLastDim,
ConversionPatternRewriter &rewriter,
Location loc) {
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy || !inputTy.hasStaticShape())
return failure();

int64_t rank = inputTy.getRank();
if (rank == 0)
return failure();

int64_t lastDim = inputTy.getDimSize(rank - 1);
if (lastDim == paddedLastDim)
return input;
if (lastDim > paddedLastDim)
return failure();

SmallVector<int64_t> paddedShape = getTensorShape(inputTy);
paddedShape[rank - 1] = paddedLastDim;

SmallVector<int64_t> pads(2 * rank, 0);
pads[2 * (rank - 1) + 1] = paddedLastDim - lastDim;
Value padsConst = tosa::getTosaConstShape(rewriter, loc, pads);

auto padValueOr = tosa::createZeroPointTensor(
rewriter, loc, inputTy.getElementType(), /*zeroPoint=*/0);
if (!padValueOr)
return failure();

auto paddedTy = RankedTensorType::get(makeShapeLLVMCompatible(paddedShape),
inputTy.getElementType());
return tosa::PadOp::create(rewriter, loc, paddedTy, input, padsConst,
*padValueOr)
.getResult();
}

struct StaticScaledMmScaleShapes {
SmallVector<int64_t> scaleA;
SmallVector<int64_t> scaleB;
};

static std::optional<StaticScaledMmScaleShapes>
getStaticScaledMmBatchedScaleShapes(RankedTensorType scaleATy,
RankedTensorType scaleBTy, int64_t m,
int64_t n) {
if (!scaleATy.hasStaticShape() || !scaleBTy.hasStaticShape() ||
!isSupportedStaticScaledMmScaleElementType(scaleATy.getElementType()) ||
!isSupportedStaticScaledMmScaleElementType(scaleBTy.getElementType()))
return std::nullopt;

if (scaleATy.getNumElements() == 1 && scaleBTy.getNumElements() == 1)
return StaticScaledMmScaleShapes{{1, 1, 1}, {1, 1, 1}};

if (scaleATy.getRank() != 2 || scaleBTy.getRank() != 2)
return std::nullopt;

int64_t scaleARows = scaleATy.getDimSize(0);
int64_t scaleACols = scaleATy.getDimSize(1);
int64_t scaleBRows = scaleBTy.getDimSize(0);
int64_t scaleBCols = scaleBTy.getDimSize(1);

if (scaleARows == m && scaleACols == 1 && scaleBRows == 1 &&
scaleBCols == n)
return StaticScaledMmScaleShapes{{1, m, 1}, {1, 1, n}};

if (scaleARows == m && scaleACols == 1 && scaleBRows == 1 &&
scaleBCols == 1)
return StaticScaledMmScaleShapes{{1, m, 1}, {1, 1, 1}};

if (scaleARows == 1 && scaleACols == 1 && scaleBRows == 1 &&
scaleBCols == n)
return StaticScaledMmScaleShapes{{1, 1, 1}, {1, 1, n}};

return std::nullopt;
}

struct ZeroInsertionResult {
Value value;
bool trimmedTail;
Expand Down Expand Up @@ -210,6 +311,139 @@ insertZerosAlongAxis(Value input, int axis, int64_t stride,
return ZeroInsertionResult{result, trimmedTail};
}

static Value castScaledMmResultToType(Value result, RankedTensorType resultTy,
ConversionPatternRewriter &rewriter) {
if (result.getType() == resultTy)
return result;
return tosa::tosaCastTensorToType(rewriter, result, resultTy).value();
}

static FailureOr<Value>
addBiasToScaledMmAccumulator(Value accumulator, Value bias,
ConversionPatternRewriter &rewriter,
Location loc) {
if (isa<Torch::NoneType>(bias.getType()))
return accumulator;
auto biasTy = dyn_cast<RankedTensorType>(bias.getType());
if (!biasTy)
return failure();
auto accumulatorTy = cast<RankedTensorType>(accumulator.getType());
auto biasAccumulatorTy =
RankedTensorType::get(biasTy.getShape(), accumulatorTy.getElementType());
bias = tosa::tosaCastTensorToType(rewriter, bias, biasAccumulatorTy).value();
if (mlir::tosa::EqualizeRanks(rewriter, loc, accumulator, bias).failed())
return failure();
return tosa::AddOp::create(rewriter, loc, accumulator.getType(), accumulator,
bias)
.getResult();
}

static LogicalResult rewriteScaledMmToMatMulOp(
Operation *op, Value lhs, Value rhs, Value scaleA, Value scaleB, Value bias,
RankedTensorType lhsTy, RankedTensorType rhsTy, RankedTensorType scaleATy,
RankedTensorType scaleBTy, RankedTensorType resultTy, int64_t m, int64_t k,
int64_t n, ConversionPatternRewriter &rewriter, Location loc) {
auto f32Ty = rewriter.getF32Type();
int64_t paddedN = llvm::alignTo(n, kScaledMmMatmulNAlignment);
bool needsOutputPadding = paddedN != n;

// TOSA FP8 matmul requires the output channel dimension to be aligned. Pad
// RHS columns and slice the f32 accumulator back to the user-visible N after
// matmul/scaling.
if (needsOutputPadding) {
FailureOr<Value> paddedRhsOr =
padLastDimWithZeros(rhs, paddedN, rewriter, loc);
if (failed(paddedRhsOr))
return rewriter.notifyMatchFailure(
op, "failed to pad aten._scaled_mm rhs output dimension");
rhs = *paddedRhsOr;
}

lhs = reshapeTensor(lhs, {1, m, k}, lhsTy.getElementType(), rewriter, loc);
rhs = reshapeTensor(rhs, {1, k, paddedN}, rhsTy.getElementType(), rewriter,
loc);

auto scaleAF32Ty = RankedTensorType::get(
makeShapeLLVMCompatible(scaleATy.getShape()), f32Ty);
auto scaleBF32Ty = RankedTensorType::get(
makeShapeLLVMCompatible(scaleBTy.getShape()), f32Ty);
scaleA = tosa::tosaCastTensorToType(rewriter, scaleA, scaleAF32Ty).value();
scaleB = tosa::tosaCastTensorToType(rewriter, scaleB, scaleBF32Ty).value();

std::optional<StaticScaledMmScaleShapes> batchedScaleShapes =
getStaticScaledMmBatchedScaleShapes(scaleATy, scaleBTy, m, n);
if (!batchedScaleShapes)
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm expects static FP8 scales to be fp32 "
"tensorwise scales or paired 2D PyTorch scale layouts");

SmallVector<int64_t> batchedScaleAShapeVec = batchedScaleShapes->scaleA;
SmallVector<int64_t> batchedScaleBShapeVec = batchedScaleShapes->scaleB;

scaleA = reshapeTensor(scaleA, batchedScaleAShapeVec, f32Ty, rewriter, loc);
scaleB = reshapeTensor(scaleB, batchedScaleBShapeVec, f32Ty, rewriter, loc);

// If the RHS scale is per output channel, pad it alongside the RHS data so
// the scale multiply broadcasts over the padded matmul result.
if (needsOutputPadding && batchedScaleBShapeVec[2] == n) {
FailureOr<Value> paddedScaleBOr =
padLastDimWithZeros(scaleB, paddedN, rewriter, loc);
if (failed(paddedScaleBOr))
return rewriter.notifyMatchFailure(
op, "failed to pad aten._scaled_mm rhs channel scale");
scaleB = *paddedScaleBOr;
batchedScaleBShapeVec[2] = paddedN;
}

auto zeroPointAOr =
tosa::createZeroPointTensor(rewriter, loc, lhsTy.getElementType(), 0);
auto zeroPointBOr =
tosa::createZeroPointTensor(rewriter, loc, rhsTy.getElementType(), 0);
if (!zeroPointAOr || !zeroPointBOr)
return rewriter.notifyMatchFailure(
op, "failed to materialize FP8 zero point for matmul");

auto matmulTy = RankedTensorType::get({1, m, paddedN}, f32Ty);
Value matmul = tosa::MatMulOp::create(rewriter, loc, matmulTy, lhs, rhs,
*zeroPointAOr, *zeroPointBOr)
.getResult();

auto combinedScaleTy = RankedTensorType::get(
{1, std::max(batchedScaleAShapeVec[1], batchedScaleBShapeVec[1]),
std::max(batchedScaleAShapeVec[2], batchedScaleBShapeVec[2])},
f32Ty);
Value combinedScale =
tosa::createMulOpAndCast(rewriter, op, combinedScaleTy, scaleA, scaleB,
/*shift=*/0);
Value scaledMatmul =
tosa::createMulOpAndCast(rewriter, op, matmulTy, matmul, combinedScale,
/*shift=*/0);

if (needsOutputPadding) {
auto slicedTy = RankedTensorType::get({1, m, n}, f32Ty);
scaledMatmul =
tosa::SliceOp::create(rewriter, loc, slicedTy, scaledMatmul,
tosa::getTosaConstShape(rewriter, loc, {0, 0, 0}),
tosa::getTosaConstShape(rewriter, loc, {1, m, n}))
.getResult();
}

auto reshapedTy = RankedTensorType::get(resultTy.getShape(), f32Ty);
Value result =
tosa::ReshapeOp::create(
rewriter, loc, reshapedTy, scaledMatmul,
tosa::getTosaConstShape(rewriter, loc, getTensorShape(resultTy)))
.getResult();
auto resultWithBiasOr =
addBiasToScaledMmAccumulator(result, bias, rewriter, loc);
if (failed(resultWithBiasOr))
return rewriter.notifyMatchFailure(
op, "Failed to add bias to aten._scaled_mm accumulator");
result = castScaledMmResultToType(*resultWithBiasOr, resultTy, rewriter);
rewriter.replaceOp(op, {result});
return success();
}

static LogicalResult
getTorchToTosaPermutations(Location loc, int64_t rank,
SmallVectorImpl<int32_t> &torchToTosa,
Expand Down Expand Up @@ -1074,6 +1308,7 @@ class ConvertAtenDivOp : public TorchToTosaOpConversionPattern<AtenOpT> {
// types can only be floating point for tosa::ReciprocalOp.
rhsTensor =
tosa::tosaCastTensorToType(rewriter, rhsTensor, outType).value();
lhs = tosa::tosaCastTensorToType(rewriter, lhs, outType).value();
auto rhsRcp = tosa::ReciprocalOp::create(rewriter, op->getLoc(),
rhsTensor.getType(), rhsTensor);

Expand Down Expand Up @@ -2592,6 +2827,86 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
}
};

template <typename AtenOpT>
class ConvertAtenScaledMmOp : public TorchToTosaOpConversionPattern<AtenOpT> {
public:
using TorchToTosaOpConversionPattern<AtenOpT>::TorchToTosaOpConversionPattern;
using OpAdaptor = typename AtenOpT::Adaptor;

LogicalResult
matchAndRewriteImpl(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useFastAccum = false;
if (!matchPattern(op.getUseFastAccum(), m_TorchConstantBool(&useFastAccum)))
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm requires a constant bool use_fast_accum");
// TOSA does not expose an equivalent fast-accumulation mode. Lower both
// PyTorch modes to the same f32-accumulating TOSA sequence.

if (!isa<Torch::NoneType>(op.getScaleResult().getType()))
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm with scale_result is not supported");

Value lhs = adaptor.getSelf();
Value rhs = adaptor.getMat2();
Value scaleA = adaptor.getScaleA();
Value scaleB = adaptor.getScaleB();
Value bias = adaptor.getBias();

auto lhsTy = dyn_cast<RankedTensorType>(lhs.getType());
auto rhsTy = dyn_cast<RankedTensorType>(rhs.getType());
auto scaleATy = dyn_cast<RankedTensorType>(scaleA.getType());
auto scaleBTy = dyn_cast<RankedTensorType>(scaleB.getType());
auto resultTy = dyn_cast<RankedTensorType>(
this->getTypeConverter()->convertType(op.getType()));

if (!lhsTy || !rhsTy || !scaleATy || !scaleBTy || !resultTy)
return rewriter.notifyMatchFailure(
op,
"aten._scaled_mm requires tensor operands with ranked result type");

if (!lhsTy.hasStaticShape() || !rhsTy.hasStaticShape() ||
!resultTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm requires static input/result shapes");

auto lhsElemTy = lhsTy.getElementType();
auto rhsElemTy = rhsTy.getElementType();

if (!isSupportedScaledMmDataElementType(lhsElemTy) ||
!isSupportedScaledMmDataElementType(rhsElemTy))
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm only supports FP8 input types");

if (!isSupportedStaticScaledMmScaleElementType(scaleATy.getElementType()) ||
!isSupportedStaticScaledMmScaleElementType(scaleBTy.getElementType()))
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm expects fp32 static FP8 scales");

Location loc = op.getLoc();

if (lhsTy.getRank() != 2 || rhsTy.getRank() != 2 || resultTy.getRank() != 2)
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm expects rank-2 input and result tensors for "
"static FP8 scales");

int64_t m = lhsTy.getShape()[0];
int64_t k = lhsTy.getShape()[1];
int64_t rhsK = rhsTy.getShape()[0];
int64_t n = rhsTy.getShape()[1];
if (k != rhsK)
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm requires inner dimensions of lhs/rhs to match");
if (resultTy.getShape()[0] != m || resultTy.getShape()[1] != n)
return rewriter.notifyMatchFailure(
op, "aten._scaled_mm expects static FP8 result shape [M, N]");

return rewriteScaledMmToMatMulOp(op, lhs, rhs, scaleA, scaleB, bias, lhsTy,
rhsTy, scaleATy, scaleBTy, resultTy, m, k,
n, rewriter, loc);
}
};

template <>
LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewriteImpl(
AtenRsubScalarOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -10810,6 +11125,10 @@ std::set<StringRef> populateTorchToTosaConversionPatternsAndIllegalOps(
INSERT_LINEAR_ATENOP_PATTERN(AtenLinearOp);
#undef INSERT_LINEAR_ATENOP_PATTERN

illegalOps.insert(Aten_ScaledMmOp::getOperationName());
patterns.addWithLabel<ConvertAtenScaledMmOp<Aten_ScaledMmOp>>(
Aten_ScaledMmOp::getOperationName(), typeConverter, context);

#define INSERT_ADAPTIVE_POOLING_ATENOP_PATTERN(AtenOp, TosaOpT) \
illegalOps.insert(AtenOp::getOperationName()); \
patterns.addWithLabel<ConvertAtenAdaptivePoolingOp<AtenOp, TosaOpT>>( \
Expand Down
Loading
Loading