Skip to content

Commit d4088c9

Browse files
committed
1 parent 1225073 commit d4088c9

File tree

6 files changed

+148
-102
lines changed

6 files changed

+148
-102
lines changed

externals/llvm-project

Submodule llvm-project updated 4692 files

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

+4
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ bool isScale32(mlir::quant::UniformQuantizedType output_element_type);
4545
Value getTosaConstTensorSingleF32(PatternRewriter &rewriter, Operation *op,
4646
float val);
4747

48+
// Create a 32-bit int constant operator from a int
49+
Value getTosaConstTensorSingleI32(PatternRewriter &rewriter, Operation *op,
50+
int32_t val);
51+
4852
// Create a zero constant tensor of the desired type and shape.
4953
std::optional<Value> getZerosLikeTensor(PatternRewriter &rewriter,
5054
Operation *op, Type type);

lib/Conversion/TorchToTosa/TorchToTosa.cpp

+102-63
Original file line numberDiff line numberDiff line change
@@ -622,14 +622,16 @@ Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
622622
auto boolType =
623623
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1));
624624

625-
auto lhsMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type, lhs, rhs,
626-
/*shift=*/0);
625+
auto lhsMulRhs = rewriter.create<tosa::MulOp>(
626+
op->getLoc(), i32Type, lhs, rhs,
627+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
627628

628629
auto lhsRhsDifferentSign =
629630
rewriter.create<tosa::GreaterOp>(op->getLoc(), boolType, zero, lhsMulRhs);
630631

631-
auto truncMulRhs = rewriter.create<tosa::MulOp>(op->getLoc(), i32Type,
632-
intDivOp, rhs, /*shift=*/0);
632+
auto truncMulRhs = rewriter.create<tosa::MulOp>(
633+
op->getLoc(), i32Type, intDivOp, rhs,
634+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
633635

634636
auto truncMulRhsEqualLhs =
635637
rewriter.create<tosa::EqualOp>(op->getLoc(), boolType, truncMulRhs, lhs);
@@ -853,7 +855,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
853855
self, zero);
854856
auto mulTensor = rewriter.create<tosa::MulOp>(
855857
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
856-
alphaTensor, /*shift=*/0);
858+
alphaTensor, tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
857859

858860
rewriter.replaceOpWithNewOp<tosa::SelectOp>(
859861
op, getTypeConverter()->convertType(op.getType()), cond, self, mulTensor);
@@ -2151,8 +2153,9 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
21512153
/*checkForUnity=*/true)))
21522154
return failure();
21532155

2154-
auto multTensor = rewriter.create<tosa::MulOp>(op->getLoc(), resultTy, self,
2155-
alphaTensor, /*shift=*/0);
2156+
auto multTensor = rewriter.create<tosa::MulOp>(
2157+
op->getLoc(), resultTy, self, alphaTensor,
2158+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
21562159

21572160
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, resultTy, otherTensor,
21582161
multTensor);
@@ -2493,12 +2496,14 @@ Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter,
24932496
auto op3RsqrtOp2 = rewriter.create<tosa::RsqrtOp>(
24942497
op->getLoc(), variance.getType(), op2AddVarEpsilon.getResult());
24952498

2496-
auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(op->getLoc(), outType,
2497-
op1SubInputMean.getResult(),
2498-
op3RsqrtOp2.getResult(), 0);
2499+
auto op4MulOp1Op3 = rewriter.create<tosa::MulOp>(
2500+
op->getLoc(), outType, op1SubInputMean.getResult(),
2501+
op3RsqrtOp2.getResult(),
2502+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
24992503

25002504
auto op5MulOp4Scale = rewriter.create<tosa::MulOp>(
2501-
op->getLoc(), outType, op4MulOp1Op3.getResult(), weight, 0);
2505+
op->getLoc(), outType, op4MulOp1Op3.getResult(), weight,
2506+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
25022507

25032508
return rewriter
25042509
.create<tosa::AddOp>(op->getLoc(), outType, op5MulOp4Scale.getResult(),
@@ -2710,19 +2715,22 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
27102715
// Compute mean.
27112716
Value sum = computeSumAndReshape(adaptor.getInput(), inputType, bcastOutType,
27122717
bcastOutShape);
2713-
Value meanVal = rewriter.create<tosa::MulOp>(op.getLoc(), bcastOutType, sum,
2714-
elemCntRcp, /*shift=*/0);
2718+
Value meanVal = rewriter.create<tosa::MulOp>(
2719+
op.getLoc(), bcastOutType, sum, elemCntRcp,
2720+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
27152721

27162722
// Compute variance.
27172723
Value squareSumSub = rewriter.create<tosa::SubOp>(
27182724
op.getLoc(), inputType, adaptor.getInput(), meanVal);
2719-
Value squareSum = rewriter.create<tosa::MulOp>(op.getLoc(), inputType,
2720-
squareSumSub, squareSumSub, 0);
2725+
Value squareSum = rewriter.create<tosa::MulOp>(
2726+
op.getLoc(), inputType, squareSumSub, squareSumSub,
2727+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
27212728

27222729
Value squareSumReduced =
27232730
computeSumAndReshape(squareSum, inputType, bcastOutType, bcastOutShape);
27242731
Value varianceVal = rewriter.create<tosa::MulOp>(
2725-
op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp, /*shift=*/0);
2732+
op.getLoc(), bcastOutType, squareSumReduced, elemCntRcp,
2733+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
27262734

27272735
// Reshape weight and bias.
27282736
SmallVector<int64_t> weightAndBiasBcastShape;
@@ -2978,8 +2986,9 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
29782986
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
29792987

29802988
auto logOp = rewriter.create<tosa::LogOp>(op.getLoc(), outType, self);
2981-
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
2982-
/*shift=*/0);
2989+
rewriter.replaceOpWithNewOp<tosa::MulOp>(
2990+
op, outType, logOp, rcpOp,
2991+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
29832992

29842993
return success();
29852994
}
@@ -3195,32 +3204,48 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
31953204

31963205
auto a1 =
31973206
tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
3198-
auto a1X = rewriter.create<tosa::MulOp>(loc, outType, a1, absX, /*shift=*/0);
3207+
auto a1X = rewriter.create<tosa::MulOp>(
3208+
loc, outType, a1, absX,
3209+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
31993210
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
32003211

32013212
auto a2 =
32023213
tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
3203-
auto x2 = rewriter.create<tosa::MulOp>(loc, outType, absX, absX, /*shift=*/0);
3204-
auto a2X = rewriter.create<tosa::MulOp>(loc, outType, a2, x2, /*shift=*/0);
3214+
auto x2 = rewriter.create<tosa::MulOp>(
3215+
loc, outType, absX, absX,
3216+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
3217+
auto a2X = rewriter.create<tosa::MulOp>(
3218+
loc, outType, a2, x2,
3219+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
32053220
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
32063221

32073222
auto a3 =
32083223
tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
3209-
auto x3 = rewriter.create<tosa::MulOp>(loc, outType, x2, absX, /*shift=*/0);
3210-
auto a3X = rewriter.create<tosa::MulOp>(loc, outType, a3, x3, /*shift=*/0);
3224+
auto x3 = rewriter.create<tosa::MulOp>(
3225+
loc, outType, x2, absX,
3226+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
3227+
auto a3X = rewriter.create<tosa::MulOp>(
3228+
loc, outType, a3, x3,
3229+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
32113230
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
32123231

32133232
auto a4 =
32143233
tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();
3215-
auto x4 = rewriter.create<tosa::MulOp>(loc, outType, x3, absX, /*shift=*/0);
3216-
auto a4X = rewriter.create<tosa::MulOp>(loc, outType, a4, x4, /*shift=*/0);
3234+
auto x4 = rewriter.create<tosa::MulOp>(
3235+
loc, outType, x3, absX,
3236+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
3237+
auto a4X = rewriter.create<tosa::MulOp>(
3238+
loc, outType, a4, x4,
3239+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
32173240
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
32183241

32193242
auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
3220-
auto rcprl2 =
3221-
rewriter.create<tosa::MulOp>(loc, outType, rcprl, rcprl, /*shift=*/0);
3222-
auto rcprl4 =
3223-
rewriter.create<tosa::MulOp>(loc, outType, rcprl2, rcprl2, /*shift=*/0);
3243+
auto rcprl2 = rewriter.create<tosa::MulOp>(
3244+
loc, outType, rcprl, rcprl,
3245+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
3246+
auto rcprl4 = rewriter.create<tosa::MulOp>(
3247+
loc, outType, rcprl2, rcprl2,
3248+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
32243249
auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);
32253250

32263251
// Deal with negative x.
@@ -3248,15 +3273,17 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
32483273
Value rsqrt2 =
32493274
tosa::getConstTensor<float>(rewriter, op, 0.70710678f, {}, dtype).value();
32503275

3251-
Value erfArg = rewriter.create<tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
3252-
/*shift=*/0);
3276+
Value erfArg = rewriter.create<tosa::MulOp>(
3277+
loc, outType, xMinusMean, rsqrt2,
3278+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
32533279
Value erf = approximateErfOp(rewriter, op, erfArg, dtype);
32543280
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
32553281
Value oneHalf =
32563282
tosa::getConstTensor<float>(rewriter, op, 0.5, {}, dtype).value();
32573283

3258-
Value normalCdf = rewriter.create<tosa::MulOp>(loc, outType, oneHalf,
3259-
erfPlus1, /*shift=*/0);
3284+
Value normalCdf = rewriter.create<tosa::MulOp>(
3285+
loc, outType, oneHalf, erfPlus1,
3286+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
32603287
return normalCdf;
32613288
}
32623289

@@ -3295,8 +3322,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
32953322
op->getLoc(),
32963323
cast<RankedTensorType>(cdf.getType()).cloneWith({}, selfElemTy), cdf);
32973324

3298-
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, cdf,
3299-
/*shift=*/0);
3325+
rewriter.replaceOpWithNewOp<tosa::MulOp>(
3326+
op, resultType, self, cdf,
3327+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
33003328
} else if (approximate.compare("tanh") == 0) {
33013329
// "tanh" approximate
33023330
// GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
@@ -3337,8 +3365,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
33373365
.value();
33383366

33393367
// 0.5 * x
3340-
auto halfInput = rewriter.create<tosa::MulOp>(op->getLoc(), resultType,
3341-
half, self, /*shift=*/0);
3368+
auto halfInput = rewriter.create<tosa::MulOp>(
3369+
op->getLoc(), resultType, half, self,
3370+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
33423371

33433372
// sqrt(2/pi)
33443373
auto sqrtTwoOverPi =
@@ -3349,9 +3378,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
33493378
rewriter.create<tosa::PowOp>(op->getLoc(), resultType, self, three);
33503379

33513380
// 0.044715 * x^3
3352-
auto inputPowThreeMul =
3353-
rewriter.create<tosa::MulOp>(op->getLoc(), resultType, magicNumber,
3354-
inputPowThree.getResult(), /*shift=*/0);
3381+
auto inputPowThreeMul = rewriter.create<tosa::MulOp>(
3382+
op->getLoc(), resultType, magicNumber, inputPowThree.getResult(),
3383+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
33553384

33563385
// x + 0.044715 * x^3
33573386
auto inputPowThreeMulAdd = rewriter.create<tosa::AddOp>(
@@ -3360,7 +3389,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
33603389
// sqrt(2/pi) * (x + 0.044715 * x^3)
33613390
auto sqrtTwoOverPiMul = rewriter.create<tosa::MulOp>(
33623391
op->getLoc(), resultType, sqrtTwoOverPi.getResult(),
3363-
inputPowThreeMulAdd.getResult(), /*shift=*/0);
3392+
inputPowThreeMulAdd.getResult(),
3393+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
33643394

33653395
// tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
33663396
auto tanh = rewriter.create<tosa::TanhOp>(op->getLoc(), resultType,
@@ -3372,7 +3402,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
33723402

33733403
rewriter.replaceOpWithNewOp<tosa::MulOp>(
33743404
op, resultType, halfInput.getResult(), tanhAdd.getResult(),
3375-
/*shift=*/0);
3405+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
33763406
} else {
33773407
return rewriter.notifyMatchFailure(op,
33783408
"Unsupported approximation algorithm");
@@ -3419,22 +3449,26 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
34193449
Value negOneHalf =
34203450
tosa::getConstTensor<float>(rewriter, op, -0.5f, {}, selfElemTy).value();
34213451
Value inputSquared = rewriter.create<tosa::MulOp>(
3422-
loc, selfType, adaptor.getSelf(), adaptor.getSelf(), /*shift=*/0);
3452+
loc, selfType, adaptor.getSelf(), adaptor.getSelf(),
3453+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
34233454
Value negHalfInputSquared = rewriter.create<tosa::MulOp>(
3424-
loc, selfType, inputSquared, negOneHalf, /*shift=*/0);
3455+
loc, selfType, inputSquared, negOneHalf,
3456+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
34253457
Value dinput =
34263458
rewriter.create<tosa::ExpOp>(loc, selfType, negHalfInputSquared);
34273459
Value cdf = buildUnitNormalCdf(rewriter, op, adaptor.getSelf(), selfElemTy);
34283460
Value dinputInput = rewriter.create<tosa::MulOp>(
3429-
loc, selfType, dinput, adaptor.getSelf(), /*shift=*/0);
3461+
loc, selfType, dinput, adaptor.getSelf(),
3462+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
34303463
Value dinputInputAlpha = rewriter.create<tosa::MulOp>(
3431-
loc, selfType, dinputInput, kAlphaHalf, /*shift=*/0);
3464+
loc, selfType, dinputInput, kAlphaHalf,
3465+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
34323466
Value cdfExt =
34333467
rewriter.create<tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
34343468
rewriter.replaceOpWithNewOp<tosa::MulOp>(
34353469
op, getTypeConverter()->convertType(op.getType()),
34363470
adaptor.getGradOutput(), cdfExt,
3437-
/*shift=*/0);
3471+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
34383472

34393473
return success();
34403474
}
@@ -4828,8 +4862,9 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
48284862
rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, adaptor.getOther());
48294863
auto rtolConstOp =
48304864
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(rtol));
4831-
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
4832-
rtolConstOp, lhsAbsOp, /*shift=*/0);
4865+
auto mulOp = rewriter.create<tosa::MulOp>(
4866+
op->getLoc(), otherType, rtolConstOp, lhsAbsOp,
4867+
tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
48334868
auto atolConstOp =
48344869
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(atol));
48354870
auto addOp =
@@ -5354,7 +5389,8 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
53545389
auto otherTensorReciprocal = rewriter.create<tosa::ReciprocalOp>(
53555390
op.getLoc(), otherTensor.getType(), otherTensor);
53565391
divTensor = rewriter.create<tosa::MulOp>(
5357-
op.getLoc(), outType, self, otherTensorReciprocal, /*shift=*/0);
5392+
op.getLoc(), outType, self, otherTensorReciprocal,
5393+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
53585394
divTensor =
53595395
rewriter.create<tosa::FloorOp>(op.getLoc(), outType, divTensor);
53605396
} else {
@@ -5378,9 +5414,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
53785414
}
53795415
}
53805416

5381-
auto mulTensor = rewriter.create<tosa::MulOp>(op.getLoc(), outType,
5382-
otherTensor, divTensor,
5383-
/*shift=*/0);
5417+
auto mulTensor = rewriter.create<tosa::MulOp>(
5418+
op.getLoc(), outType, otherTensor, divTensor,
5419+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
53845420
rewriter.replaceOpWithNewOp<tosa::SubOp>(op, outType, self, mulTensor);
53855421

53865422
return success();
@@ -6572,8 +6608,9 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
65726608
llvm_unreachable("Invalid integer width");
65736609
});
65746610

6575-
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, resultType, self, trilMask,
6576-
/*shift=*/0);
6611+
rewriter.replaceOpWithNewOp<tosa::MulOp>(
6612+
op, resultType, self, trilMask,
6613+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
65776614

65786615
return success();
65796616
}
@@ -6663,14 +6700,16 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
66636700
auto ceilInput = rewriter.create<tosa::CeilOp>(op->getLoc(), resultTy, self);
66646701

66656702
auto floorInputDivByTwo = rewriter.create<tosa::MulOp>(
6666-
op->getLoc(), resultTy, floorInput.getResult(), oneHalf, /*shift=*/0);
6703+
op->getLoc(), resultTy, floorInput.getResult(), oneHalf,
6704+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
66676705

66686706
auto floorDivResult = rewriter.create<tosa::FloorOp>(
66696707
op->getLoc(), resultTy, floorInputDivByTwo.getResult());
66706708

66716709
// (floor(input) // 2) * 2
66726710
auto evenComparison = rewriter.create<tosa::MulOp>(
6673-
op->getLoc(), resultTy, floorDivResult.getResult(), two, /*shift=*/0);
6711+
op->getLoc(), resultTy, floorDivResult.getResult(), two,
6712+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
66746713

66756714
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
66766715
auto floorInputEven = rewriter.create<tosa::EqualOp>(
@@ -6849,7 +6888,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
68496888

68506889
Value diagonalTensor = rewriter.create<tosa::MulOp>(
68516890
op->getLoc(), transposedInputType, selfTransposed, diagonalMask,
6852-
/*shift=*/0);
6891+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
68536892

68546893
auto resultShape = makeShapeTorchCompatible(resultType.getShape());
68556894
auto targetReduceDim = resultShape[resultType.getRank() - 1];
@@ -8127,9 +8166,9 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
81278166
auto oneMinusZiReciprocal = rewriter.create<tosa::ReciprocalOp>(
81288167
op->getLoc(), resultType, oneMinusZi.getResult());
81298168

8130-
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), resultType, zi,
8131-
oneMinusZiReciprocal.getResult(),
8132-
/*shift=*/0);
8169+
auto mulOp = rewriter.create<tosa::MulOp>(
8170+
op->getLoc(), resultType, zi, oneMinusZiReciprocal.getResult(),
8171+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
81338172

81348173
auto result =
81358174
rewriter.create<tosa::LogOp>(op->getLoc(), resultType, mulOp.getResult());
@@ -8220,7 +8259,7 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
82208259

82218260
auto result = rewriter.create<tosa::MulOp>(
82228261
op->getLoc(), resultType, logOfSelf.getResult(), reciprocalOp.getResult(),
8223-
/*shift=*/0);
8262+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
82248263

82258264
rewriter.replaceOp(op, {result.getResult()});
82268265

@@ -8301,7 +8340,7 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
83018340

83028341
auto result = rewriter.create<tosa::MulOp>(
83038342
op->getLoc(), resultType, sinOp.getResult(), reciprocalOp.getResult(),
8304-
/*shift=*/0);
8343+
/*shift=*/tosa::getTosaConstTensorSingleI32(rewriter, op, 0));
83058344

83068345
rewriter.replaceOp(op, {result.getResult()});
83078346

0 commit comments

Comments
 (0)