@@ -622,14 +622,16 @@ Value floorIntDiv(PatternRewriter &rewriter, Operation *op, TensorType outType,
622
622
auto boolType =
623
623
RankedTensorType::get (outType.getShape (), rewriter.getIntegerType (1 ));
624
624
625
- auto lhsMulRhs = rewriter.create <tosa::MulOp>(op->getLoc (), i32Type, lhs, rhs,
626
- /* shift=*/ 0 );
625
+ auto lhsMulRhs = rewriter.create <tosa::MulOp>(
626
+ op->getLoc (), i32Type, lhs, rhs,
627
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
627
628
628
629
auto lhsRhsDifferentSign =
629
630
rewriter.create <tosa::GreaterOp>(op->getLoc (), boolType, zero, lhsMulRhs);
630
631
631
- auto truncMulRhs = rewriter.create <tosa::MulOp>(op->getLoc (), i32Type,
632
- intDivOp, rhs, /* shift=*/ 0 );
632
+ auto truncMulRhs = rewriter.create <tosa::MulOp>(
633
+ op->getLoc (), i32Type, intDivOp, rhs,
634
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
633
635
634
636
auto truncMulRhsEqualLhs =
635
637
rewriter.create <tosa::EqualOp>(op->getLoc (), boolType, truncMulRhs, lhs);
@@ -853,7 +855,7 @@ LogicalResult ConvertAtenOp<AtenLeakyReluOp>::matchAndRewrite(
853
855
self, zero);
854
856
auto mulTensor = rewriter.create <tosa::MulOp>(
855
857
op->getLoc (), getTypeConverter ()->convertType (op.getType ()), self,
856
- alphaTensor, /* shift= */ 0 );
858
+ alphaTensor, tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
857
859
858
860
rewriter.replaceOpWithNewOp <tosa::SelectOp>(
859
861
op, getTypeConverter ()->convertType (op.getType ()), cond, self, mulTensor);
@@ -2151,8 +2153,9 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
2151
2153
/* checkForUnity=*/ true )))
2152
2154
return failure ();
2153
2155
2154
- auto multTensor = rewriter.create <tosa::MulOp>(op->getLoc (), resultTy, self,
2155
- alphaTensor, /* shift=*/ 0 );
2156
+ auto multTensor = rewriter.create <tosa::MulOp>(
2157
+ op->getLoc (), resultTy, self, alphaTensor,
2158
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2156
2159
2157
2160
rewriter.replaceOpWithNewOp <tosa::SubOp>(op, resultTy, otherTensor,
2158
2161
multTensor);
@@ -2493,12 +2496,14 @@ Value computeBatchNorm(Operation *op, ConversionPatternRewriter &rewriter,
2493
2496
auto op3RsqrtOp2 = rewriter.create <tosa::RsqrtOp>(
2494
2497
op->getLoc (), variance.getType (), op2AddVarEpsilon.getResult ());
2495
2498
2496
- auto op4MulOp1Op3 = rewriter.create <tosa::MulOp>(op->getLoc (), outType,
2497
- op1SubInputMean.getResult (),
2498
- op3RsqrtOp2.getResult (), 0 );
2499
+ auto op4MulOp1Op3 = rewriter.create <tosa::MulOp>(
2500
+ op->getLoc (), outType, op1SubInputMean.getResult (),
2501
+ op3RsqrtOp2.getResult (),
2502
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2499
2503
2500
2504
auto op5MulOp4Scale = rewriter.create <tosa::MulOp>(
2501
- op->getLoc (), outType, op4MulOp1Op3.getResult (), weight, 0 );
2505
+ op->getLoc (), outType, op4MulOp1Op3.getResult (), weight,
2506
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2502
2507
2503
2508
return rewriter
2504
2509
.create <tosa::AddOp>(op->getLoc (), outType, op5MulOp4Scale.getResult (),
@@ -2710,19 +2715,22 @@ LogicalResult ConvertAtenOp<AtenNativeLayerNormOp>::matchAndRewrite(
2710
2715
// Compute mean.
2711
2716
Value sum = computeSumAndReshape (adaptor.getInput (), inputType, bcastOutType,
2712
2717
bcastOutShape);
2713
- Value meanVal = rewriter.create <tosa::MulOp>(op.getLoc (), bcastOutType, sum,
2714
- elemCntRcp, /* shift=*/ 0 );
2718
+ Value meanVal = rewriter.create <tosa::MulOp>(
2719
+ op.getLoc (), bcastOutType, sum, elemCntRcp,
2720
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2715
2721
2716
2722
// Compute variance.
2717
2723
Value squareSumSub = rewriter.create <tosa::SubOp>(
2718
2724
op.getLoc (), inputType, adaptor.getInput (), meanVal);
2719
- Value squareSum = rewriter.create <tosa::MulOp>(op.getLoc (), inputType,
2720
- squareSumSub, squareSumSub, 0 );
2725
+ Value squareSum = rewriter.create <tosa::MulOp>(
2726
+ op.getLoc (), inputType, squareSumSub, squareSumSub,
2727
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2721
2728
2722
2729
Value squareSumReduced =
2723
2730
computeSumAndReshape (squareSum, inputType, bcastOutType, bcastOutShape);
2724
2731
Value varianceVal = rewriter.create <tosa::MulOp>(
2725
- op.getLoc (), bcastOutType, squareSumReduced, elemCntRcp, /* shift=*/ 0 );
2732
+ op.getLoc (), bcastOutType, squareSumReduced, elemCntRcp,
2733
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2726
2734
2727
2735
// Reshape weight and bias.
2728
2736
SmallVector<int64_t > weightAndBiasBcastShape;
@@ -2978,8 +2986,9 @@ LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
2978
2986
rewriter.create <tosa::ReciprocalOp>(op.getLoc (), ln2Op.getType (), ln2Op);
2979
2987
2980
2988
auto logOp = rewriter.create <tosa::LogOp>(op.getLoc (), outType, self);
2981
- rewriter.replaceOpWithNewOp <tosa::MulOp>(op, outType, logOp, rcpOp,
2982
- /* shift=*/ 0 );
2989
+ rewriter.replaceOpWithNewOp <tosa::MulOp>(
2990
+ op, outType, logOp, rcpOp,
2991
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
2983
2992
2984
2993
return success ();
2985
2994
}
@@ -3195,32 +3204,48 @@ static Value approximateErfOp(ConversionPatternRewriter &rewriter,
3195
3204
3196
3205
auto a1 =
3197
3206
tosa::getConstTensor<float >(rewriter, op, 0 .278393f , {}, dtype).value ();
3198
- auto a1X = rewriter.create <tosa::MulOp>(loc, outType, a1, absX, /* shift=*/ 0 );
3207
+ auto a1X = rewriter.create <tosa::MulOp>(
3208
+ loc, outType, a1, absX,
3209
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3199
3210
auto sum = rewriter.create <tosa::AddOp>(loc, outType, a1X, one);
3200
3211
3201
3212
auto a2 =
3202
3213
tosa::getConstTensor<float >(rewriter, op, 0 .230389f , {}, dtype).value ();
3203
- auto x2 = rewriter.create <tosa::MulOp>(loc, outType, absX, absX, /* shift=*/ 0 );
3204
- auto a2X = rewriter.create <tosa::MulOp>(loc, outType, a2, x2, /* shift=*/ 0 );
3214
+ auto x2 = rewriter.create <tosa::MulOp>(
3215
+ loc, outType, absX, absX,
3216
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3217
+ auto a2X = rewriter.create <tosa::MulOp>(
3218
+ loc, outType, a2, x2,
3219
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3205
3220
sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a2X);
3206
3221
3207
3222
auto a3 =
3208
3223
tosa::getConstTensor<float >(rewriter, op, 0 .000972f , {}, dtype).value ();
3209
- auto x3 = rewriter.create <tosa::MulOp>(loc, outType, x2, absX, /* shift=*/ 0 );
3210
- auto a3X = rewriter.create <tosa::MulOp>(loc, outType, a3, x3, /* shift=*/ 0 );
3224
+ auto x3 = rewriter.create <tosa::MulOp>(
3225
+ loc, outType, x2, absX,
3226
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3227
+ auto a3X = rewriter.create <tosa::MulOp>(
3228
+ loc, outType, a3, x3,
3229
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3211
3230
sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a3X);
3212
3231
3213
3232
auto a4 =
3214
3233
tosa::getConstTensor<float >(rewriter, op, 0 .078108f , {}, dtype).value ();
3215
- auto x4 = rewriter.create <tosa::MulOp>(loc, outType, x3, absX, /* shift=*/ 0 );
3216
- auto a4X = rewriter.create <tosa::MulOp>(loc, outType, a4, x4, /* shift=*/ 0 );
3234
+ auto x4 = rewriter.create <tosa::MulOp>(
3235
+ loc, outType, x3, absX,
3236
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3237
+ auto a4X = rewriter.create <tosa::MulOp>(
3238
+ loc, outType, a4, x4,
3239
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3217
3240
sum = rewriter.create <tosa::AddOp>(loc, outType, sum, a4X);
3218
3241
3219
3242
auto rcprl = rewriter.create <tosa::ReciprocalOp>(loc, outType, sum);
3220
- auto rcprl2 =
3221
- rewriter.create <tosa::MulOp>(loc, outType, rcprl, rcprl, /* shift=*/ 0 );
3222
- auto rcprl4 =
3223
- rewriter.create <tosa::MulOp>(loc, outType, rcprl2, rcprl2, /* shift=*/ 0 );
3243
+ auto rcprl2 = rewriter.create <tosa::MulOp>(
3244
+ loc, outType, rcprl, rcprl,
3245
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3246
+ auto rcprl4 = rewriter.create <tosa::MulOp>(
3247
+ loc, outType, rcprl2, rcprl2,
3248
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3224
3249
auto erf = rewriter.create <tosa::SubOp>(loc, outType, one, rcprl4);
3225
3250
3226
3251
// Deal with negative x.
@@ -3248,15 +3273,17 @@ static Value buildUnitNormalCdf(ConversionPatternRewriter &rewriter,
3248
3273
Value rsqrt2 =
3249
3274
tosa::getConstTensor<float >(rewriter, op, 0 .70710678f , {}, dtype).value ();
3250
3275
3251
- Value erfArg = rewriter.create <tosa::MulOp>(loc, outType, xMinusMean, rsqrt2,
3252
- /* shift=*/ 0 );
3276
+ Value erfArg = rewriter.create <tosa::MulOp>(
3277
+ loc, outType, xMinusMean, rsqrt2,
3278
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3253
3279
Value erf = approximateErfOp (rewriter, op, erfArg, dtype);
3254
3280
Value erfPlus1 = rewriter.create <tosa::AddOp>(loc, outType, one, erf );
3255
3281
Value oneHalf =
3256
3282
tosa::getConstTensor<float >(rewriter, op, 0.5 , {}, dtype).value ();
3257
3283
3258
- Value normalCdf = rewriter.create <tosa::MulOp>(loc, outType, oneHalf,
3259
- erfPlus1, /* shift=*/ 0 );
3284
+ Value normalCdf = rewriter.create <tosa::MulOp>(
3285
+ loc, outType, oneHalf, erfPlus1,
3286
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3260
3287
return normalCdf;
3261
3288
}
3262
3289
@@ -3295,8 +3322,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
3295
3322
op->getLoc (),
3296
3323
cast<RankedTensorType>(cdf.getType ()).cloneWith ({}, selfElemTy), cdf);
3297
3324
3298
- rewriter.replaceOpWithNewOp <tosa::MulOp>(op, resultType, self, cdf,
3299
- /* shift=*/ 0 );
3325
+ rewriter.replaceOpWithNewOp <tosa::MulOp>(
3326
+ op, resultType, self, cdf,
3327
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3300
3328
} else if (approximate.compare (" tanh" ) == 0 ) {
3301
3329
// "tanh" approximate
3302
3330
// GELU(x) = 0.5 * x * (1 + Tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
@@ -3337,8 +3365,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
3337
3365
.value ();
3338
3366
3339
3367
// 0.5 * x
3340
- auto halfInput = rewriter.create <tosa::MulOp>(op->getLoc (), resultType,
3341
- half, self, /* shift=*/ 0 );
3368
+ auto halfInput = rewriter.create <tosa::MulOp>(
3369
+ op->getLoc (), resultType, half, self,
3370
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3342
3371
3343
3372
// sqrt(2/pi)
3344
3373
auto sqrtTwoOverPi =
@@ -3349,9 +3378,9 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
3349
3378
rewriter.create <tosa::PowOp>(op->getLoc (), resultType, self, three);
3350
3379
3351
3380
// 0.044715 * x^3
3352
- auto inputPowThreeMul =
3353
- rewriter. create <tosa::MulOp>( op->getLoc (), resultType, magicNumber,
3354
- inputPowThree. getResult (), /* shift=*/ 0 );
3381
+ auto inputPowThreeMul = rewriter. create <tosa::MulOp>(
3382
+ op->getLoc (), resultType, magicNumber, inputPowThree. getResult () ,
3383
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
3355
3384
3356
3385
// x + 0.044715 * x^3
3357
3386
auto inputPowThreeMulAdd = rewriter.create <tosa::AddOp>(
@@ -3360,7 +3389,8 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
3360
3389
// sqrt(2/pi) * (x + 0.044715 * x^3)
3361
3390
auto sqrtTwoOverPiMul = rewriter.create <tosa::MulOp>(
3362
3391
op->getLoc (), resultType, sqrtTwoOverPi.getResult (),
3363
- inputPowThreeMulAdd.getResult (), /* shift=*/ 0 );
3392
+ inputPowThreeMulAdd.getResult (),
3393
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3364
3394
3365
3395
// tanh(sqrt(2/pi) * (x + 0.044715 * x^3))
3366
3396
auto tanh = rewriter.create <tosa::TanhOp>(op->getLoc (), resultType,
@@ -3372,7 +3402,7 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
3372
3402
3373
3403
rewriter.replaceOpWithNewOp <tosa::MulOp>(
3374
3404
op, resultType, halfInput.getResult (), tanhAdd.getResult (),
3375
- /* shift=*/ 0 );
3405
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
3376
3406
} else {
3377
3407
return rewriter.notifyMatchFailure (op,
3378
3408
" Unsupported approximation algorithm" );
@@ -3419,22 +3449,26 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
3419
3449
Value negOneHalf =
3420
3450
tosa::getConstTensor<float >(rewriter, op, -0 .5f , {}, selfElemTy).value ();
3421
3451
Value inputSquared = rewriter.create <tosa::MulOp>(
3422
- loc, selfType, adaptor.getSelf (), adaptor.getSelf (), /* shift=*/ 0 );
3452
+ loc, selfType, adaptor.getSelf (), adaptor.getSelf (),
3453
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3423
3454
Value negHalfInputSquared = rewriter.create <tosa::MulOp>(
3424
- loc, selfType, inputSquared, negOneHalf, /* shift=*/ 0 );
3455
+ loc, selfType, inputSquared, negOneHalf,
3456
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3425
3457
Value dinput =
3426
3458
rewriter.create <tosa::ExpOp>(loc, selfType, negHalfInputSquared);
3427
3459
Value cdf = buildUnitNormalCdf (rewriter, op, adaptor.getSelf (), selfElemTy);
3428
3460
Value dinputInput = rewriter.create <tosa::MulOp>(
3429
- loc, selfType, dinput, adaptor.getSelf (), /* shift=*/ 0 );
3461
+ loc, selfType, dinput, adaptor.getSelf (),
3462
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3430
3463
Value dinputInputAlpha = rewriter.create <tosa::MulOp>(
3431
- loc, selfType, dinputInput, kAlphaHalf , /* shift=*/ 0 );
3464
+ loc, selfType, dinputInput, kAlphaHalf ,
3465
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
3432
3466
Value cdfExt =
3433
3467
rewriter.create <tosa::AddOp>(loc, selfType, dinputInputAlpha, cdf);
3434
3468
rewriter.replaceOpWithNewOp <tosa::MulOp>(
3435
3469
op, getTypeConverter ()->convertType (op.getType ()),
3436
3470
adaptor.getGradOutput (), cdfExt,
3437
- /* shift=*/ 0 );
3471
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
3438
3472
3439
3473
return success ();
3440
3474
}
@@ -4828,8 +4862,9 @@ LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
4828
4862
rewriter.create <tosa::AbsOp>(op->getLoc (), otherType, adaptor.getOther ());
4829
4863
auto rtolConstOp =
4830
4864
tosa::getTosaConstTensorSingleF32 (rewriter, op, static_cast <float >(rtol));
4831
- auto mulOp = rewriter.create <tosa::MulOp>(op->getLoc (), otherType,
4832
- rtolConstOp, lhsAbsOp, /* shift=*/ 0 );
4865
+ auto mulOp = rewriter.create <tosa::MulOp>(
4866
+ op->getLoc (), otherType, rtolConstOp, lhsAbsOp,
4867
+ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
4833
4868
auto atolConstOp =
4834
4869
tosa::getTosaConstTensorSingleF32 (rewriter, op, static_cast <float >(atol ));
4835
4870
auto addOp =
@@ -5354,7 +5389,8 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
5354
5389
auto otherTensorReciprocal = rewriter.create <tosa::ReciprocalOp>(
5355
5390
op.getLoc (), otherTensor.getType (), otherTensor);
5356
5391
divTensor = rewriter.create <tosa::MulOp>(
5357
- op.getLoc (), outType, self, otherTensorReciprocal, /* shift=*/ 0 );
5392
+ op.getLoc (), outType, self, otherTensorReciprocal,
5393
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
5358
5394
divTensor =
5359
5395
rewriter.create <tosa::FloorOp>(op.getLoc (), outType, divTensor);
5360
5396
} else {
@@ -5378,9 +5414,9 @@ class ConvertAtenRemainderFmodOp : public OpConversionPattern<AtenOpT> {
5378
5414
}
5379
5415
}
5380
5416
5381
- auto mulTensor = rewriter.create <tosa::MulOp>(op. getLoc (), outType,
5382
- otherTensor, divTensor,
5383
- /* shift=*/ 0 );
5417
+ auto mulTensor = rewriter.create <tosa::MulOp>(
5418
+ op. getLoc (), outType, otherTensor, divTensor,
5419
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
5384
5420
rewriter.replaceOpWithNewOp <tosa::SubOp>(op, outType, self, mulTensor);
5385
5421
5386
5422
return success ();
@@ -6572,8 +6608,9 @@ LogicalResult ConvertAtenOp<AtenTrilOp>::matchAndRewrite(
6572
6608
llvm_unreachable (" Invalid integer width" );
6573
6609
});
6574
6610
6575
- rewriter.replaceOpWithNewOp <tosa::MulOp>(op, resultType, self, trilMask,
6576
- /* shift=*/ 0 );
6611
+ rewriter.replaceOpWithNewOp <tosa::MulOp>(
6612
+ op, resultType, self, trilMask,
6613
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
6577
6614
6578
6615
return success ();
6579
6616
}
@@ -6663,14 +6700,16 @@ LogicalResult ConvertAtenOp<AtenRoundOp>::matchAndRewrite(
6663
6700
auto ceilInput = rewriter.create <tosa::CeilOp>(op->getLoc (), resultTy, self);
6664
6701
6665
6702
auto floorInputDivByTwo = rewriter.create <tosa::MulOp>(
6666
- op->getLoc (), resultTy, floorInput.getResult (), oneHalf, /* shift=*/ 0 );
6703
+ op->getLoc (), resultTy, floorInput.getResult (), oneHalf,
6704
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
6667
6705
6668
6706
auto floorDivResult = rewriter.create <tosa::FloorOp>(
6669
6707
op->getLoc (), resultTy, floorInputDivByTwo.getResult ());
6670
6708
6671
6709
// (floor(input) // 2) * 2
6672
6710
auto evenComparison = rewriter.create <tosa::MulOp>(
6673
- op->getLoc (), resultTy, floorDivResult.getResult (), two, /* shift=*/ 0 );
6711
+ op->getLoc (), resultTy, floorDivResult.getResult (), two,
6712
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ));
6674
6713
6675
6714
// floor(input) // 2) * 2 == input <=> floor(input) % 2 == 0
6676
6715
auto floorInputEven = rewriter.create <tosa::EqualOp>(
@@ -6849,7 +6888,7 @@ LogicalResult ConvertAtenOp<AtenDiagonalOp>::matchAndRewrite(
6849
6888
6850
6889
Value diagonalTensor = rewriter.create <tosa::MulOp>(
6851
6890
op->getLoc (), transposedInputType, selfTransposed, diagonalMask,
6852
- /* shift=*/ 0 );
6891
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
6853
6892
6854
6893
auto resultShape = makeShapeTorchCompatible (resultType.getShape ());
6855
6894
auto targetReduceDim = resultShape[resultType.getRank () - 1 ];
@@ -8127,9 +8166,9 @@ LogicalResult ConvertAtenOp<AtenLogitOp>::matchAndRewrite(
8127
8166
auto oneMinusZiReciprocal = rewriter.create <tosa::ReciprocalOp>(
8128
8167
op->getLoc (), resultType, oneMinusZi.getResult ());
8129
8168
8130
- auto mulOp = rewriter.create <tosa::MulOp>(op-> getLoc (), resultType, zi,
8131
- oneMinusZiReciprocal.getResult (),
8132
- /* shift=*/ 0 );
8169
+ auto mulOp = rewriter.create <tosa::MulOp>(
8170
+ op-> getLoc (), resultType, zi, oneMinusZiReciprocal.getResult (),
8171
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
8133
8172
8134
8173
auto result =
8135
8174
rewriter.create <tosa::LogOp>(op->getLoc (), resultType, mulOp.getResult ());
@@ -8220,7 +8259,7 @@ LogicalResult ConvertAtenOp<AtenLog10Op>::matchAndRewrite(
8220
8259
8221
8260
auto result = rewriter.create <tosa::MulOp>(
8222
8261
op->getLoc (), resultType, logOfSelf.getResult (), reciprocalOp.getResult (),
8223
- /* shift=*/ 0 );
8262
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
8224
8263
8225
8264
rewriter.replaceOp (op, {result.getResult ()});
8226
8265
@@ -8301,7 +8340,7 @@ LogicalResult ConvertAtenOp<AtenTanOp>::matchAndRewrite(
8301
8340
8302
8341
auto result = rewriter.create <tosa::MulOp>(
8303
8342
op->getLoc (), resultType, sinOp.getResult (), reciprocalOp.getResult (),
8304
- /* shift=*/ 0 );
8343
+ /* shift=*/ tosa::getTosaConstTensorSingleI32 (rewriter, op, 0 ) );
8305
8344
8306
8345
rewriter.replaceOp (op, {result.getResult ()});
8307
8346
0 commit comments