From c785435a049a694a1814a7304ed0c34abe8b2580 Mon Sep 17 00:00:00 2001
From: Vinit Deodhar
Date: Tue, 22 Apr 2025 00:16:49 -0400
Subject: [PATCH 1/2] Emit tosa::erf when lowering the gelu op

TOSA provides a native erf operator, so the gelu lowering no longer needs
the hand-rolled polynomial approximation of erf. Drop approximateErfOp and
emit tosa::erf directly from buildUnitNormalCdf.

---
 lib/Conversion/TorchToTosa/TorchToTosa.cpp | 73 +---------------------
 test/Conversion/TorchToTosa/basic.mlir     | 36 +++++++++++
 2 files changed, 37 insertions(+), 72 deletions(-)

diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index d0316f26422e..cfda837e7391 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3369,77 +3369,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
-static std::optional<Value>
-approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
-                 Type dtype) {
-  // Using:
-  // https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
-  // maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
-  // 0.000972, a4 = 0.078108.
-  //
-  // Erf = 1 - 1 / (1 + a1X + a2X + a3X + a4X)^4
-
-  auto outType = cast<TensorType>(x.getType());
-  auto loc = op->getLoc();
-  auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
-  auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
-  auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
-  auto a1 =
-      tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
-  auto a2 =
-      tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
-  auto a3 =
-      tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
-  auto a4 =
-      tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();
-
-  if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a1).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a2).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a3).failed() ||
-      mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed())
-    return std::nullopt;
-
-  auto a1X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a1, absX, /*shift=*/0);
-  auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);
-
-  auto x2 =
-      tosa::createMulOpAndCast(rewriter, op, outType, absX, absX, /*shift=*/0);
-  auto a2X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a2, x2, /*shift=*/0);
-  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);
-
-  auto x3 =
-      tosa::createMulOpAndCast(rewriter, op, outType, x2, absX, /*shift=*/0);
-  auto a3X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a3, x3, /*shift=*/0);
-  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);
-
-  auto x4 =
-      tosa::createMulOpAndCast(rewriter, op, outType, x3, absX, /*shift=*/0);
-  auto a4X =
-      tosa::createMulOpAndCast(rewriter, op, outType, a4, x4, /*shift=*/0);
-  sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);
-
-  auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
-  auto rcprl2 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl, rcprl,
-                                         /*shift=*/0);
-  auto rcprl4 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl2, rcprl2,
-                                         /*shift=*/0);
-  auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);
-
-  // Deal with negative x.
-  auto cond = rewriter.create<tosa::GreaterEqualOp>(
-      loc,
-      RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
-      zero);
-  auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);
-
-  return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
-}
-
 static std::optional<Value>
 buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
                    Type dtype) {
@@ -3467,7 +3396,7 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
 
   Value erfArg = tosa::createMulOpAndCast(rewriter, op, outType, xMinusMean,
                                           rsqrt2, /*shift=*/0);
-  Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value();
+  Value erf = rewriter.create<tosa::ErfOp>(loc, outType, erfArg);
   Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);
 
   Value normalCdf = tosa::createMulOpAndCast(rewriter, op, outType, oneHalf,
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 5ab101bab3bc..3d48c49f1fef 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -3071,6 +3071,42 @@ func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,
 
 // -----
 
+// CHECK-LABEL:   func.func @torch.aten.gelu$none(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,1500,1536],f32>) -> !torch.vtensor<[1,1500,1536],f32> {
+// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1500,1536],f32> -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_2:.*]] = torch.constant.str "none"
+// CHECK:           %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_5:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.707106769> : tensor<f32>}> : () -> tensor<f32>
+// CHECK:           %[[VAL_7:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:           %[[VAL_8:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_7]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK:           %[[VAL_9:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:           %[[VAL_10:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_9]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK:           %[[VAL_11:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:           %[[VAL_12:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_11]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK:           %[[VAL_13:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK:           %[[VAL_14:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_13]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
+// CHECK:           %[[VAL_15:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<1x1500x1536xf32>, tensor<1x1x1xf32>) -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_14]], %[[VAL_16]] : (tensor<1x1500x1536xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_18:.*]] = tosa.erf %[[VAL_17]] : (tensor<1x1500x1536xf32>) -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_19:.*]] = tosa.add %[[VAL_10]], %[[VAL_18]] : (tensor<1x1x1xf32>, tensor<1x1500x1536xf32>) -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[VAL_21:.*]] = tosa.mul %[[VAL_12]], %[[VAL_19]], %[[VAL_20]] : (tensor<1x1x1xf32>, tensor<1x1500x1536xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_22:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[VAL_23:.*]] = tosa.mul %[[VAL_1]], %[[VAL_21]], %[[VAL_22]] : (tensor<1x1500x1536xf32>, tensor<1x1500x1536xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
+// CHECK:           %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<1x1500x1536xf32> -> !torch.vtensor<[1,1500,1536],f32>
+// CHECK:           return %[[VAL_24]] : !torch.vtensor<[1,1500,1536],f32>
+// CHECK:         }
+func.func @torch.aten.gelu$none(%arg0: !torch.vtensor<[1,1500,1536],f32>) -> !torch.vtensor<[1,1500,1536],f32> {
+  %str = torch.constant.str "none"
+  %0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[1,1500,1536],f32>, !torch.str -> !torch.vtensor<[1,1500,1536],f32>
+  return %0 : !torch.vtensor<[1,1500,1536],f32>
+}
+
+// -----
+
 // CHECK-LABEL:   func.func @torch.aten.gelu$tanh(
 // CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> {
 // CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,3],f32> -> tensor<5x3xf32>

From 34f5bf2302641d403801e1d8d1fe94e76aaaa71b Mon Sep 17 00:00:00 2001
From: Vinit Deodhar
Date: Thu, 24 Apr 2025 22:24:48 -0400
Subject: [PATCH 2/2] e2e test for gelu

---
 projects/pt1/e2e_testing/xfail_sets.py |  1 +
 .../test_suite/elementwise.py           | 24 +++++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 202378d1f9ac..3d47d61ca26a 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2195,6 +2195,7 @@
     "ElementwiseGeIntScalarModule_basic",
     "ElementwiseGeMixedIntScalarModule_basic",
     "ElementwiseGeluModule_basic",
+    "ElementwiseGeluTosaModule_basic",
     "ElementwiseGtFloatScalarModule_basic",
     "ElementwiseGtFloatTensorModule_basic",
     "ElementwiseGtIntScalarModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
index 3ee851611ac0..a6b1db7f4a2b 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1374,6 +1374,30 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class ElementwiseGeluTosaModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([-1, -1], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        x = torch.ops.aten.gelu(x)
+        return x
+
+
+@register_test_case(module_factory=lambda: ElementwiseGeluTosaModule())
+def ElementwiseGeluTosaModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(50, 30, low=-2.7, high=2.7))
+
+
+# ==============================================================================
+
+
 class ElementwiseGeluApproximateTanhModule(torch.nn.Module):
     def __init__(self):
         super().__init__()