
Emit tosa::erf during lowering gelu op #4151


Merged · 2 commits · Apr 28, 2025
73 changes: 1 addition & 72 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -3369,77 +3369,6 @@ LogicalResult ConvertAtenOp<AtenViewOp>::matchAndRewrite(
return success();
}

static std::optional<Value>
approximateErfOp(ConversionPatternRewriter &rewriter, Operation *op, Value x,
Type dtype) {
// Using:
// https://en.wikipedia.org/wiki/Error_function#Numerical_approximations with
// maximum error as 5 x 10^-4 where a1 = 0.278393, a2 = 0.230389, a3 =
// 0.000972, a4 = 0.078108.
//
// Erf = 1 - 1 / (1 + a1X + a2X^2 + a3X^3 + a4X^4)^4

auto outType = cast<TensorType>(x.getType());
auto loc = op->getLoc();
auto absX = rewriter.create<tosa::AbsOp>(loc, outType, x);
auto zero = tosa::getConstTensor<float>(rewriter, op, 0, {}, dtype).value();
auto one = tosa::getConstTensor<float>(rewriter, op, 1, {}, dtype).value();
auto a1 =
tosa::getConstTensor<float>(rewriter, op, 0.278393f, {}, dtype).value();
auto a2 =
tosa::getConstTensor<float>(rewriter, op, 0.230389f, {}, dtype).value();
auto a3 =
tosa::getConstTensor<float>(rewriter, op, 0.000972f, {}, dtype).value();
auto a4 =
tosa::getConstTensor<float>(rewriter, op, 0.078108f, {}, dtype).value();

if (mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, zero).failed() ||
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, one).failed() ||
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a1).failed() ||
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a2).failed() ||
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a3).failed() ||
mlir::tosa::EqualizeRanks(rewriter, op->getLoc(), x, a4).failed())
return std::nullopt;

auto a1X =
tosa::createMulOpAndCast(rewriter, op, outType, a1, absX, /*shift=*/0);
auto sum = rewriter.create<tosa::AddOp>(loc, outType, a1X, one);

auto x2 =
tosa::createMulOpAndCast(rewriter, op, outType, absX, absX, /*shift=*/0);
auto a2X =
tosa::createMulOpAndCast(rewriter, op, outType, a2, x2, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a2X);

auto x3 =
tosa::createMulOpAndCast(rewriter, op, outType, x2, absX, /*shift=*/0);
auto a3X =
tosa::createMulOpAndCast(rewriter, op, outType, a3, x3, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a3X);

auto x4 =
tosa::createMulOpAndCast(rewriter, op, outType, x3, absX, /*shift=*/0);
auto a4X =
tosa::createMulOpAndCast(rewriter, op, outType, a4, x4, /*shift=*/0);
sum = rewriter.create<tosa::AddOp>(loc, outType, sum, a4X);

auto rcprl = rewriter.create<tosa::ReciprocalOp>(loc, outType, sum);
auto rcprl2 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl, rcprl,
/*shift=*/0);
auto rcprl4 = tosa::createMulOpAndCast(rewriter, op, outType, rcprl2, rcprl2,
/*shift=*/0);
auto erf = rewriter.create<tosa::SubOp>(loc, outType, one, rcprl4);

// Deal with negative x.
auto cond = rewriter.create<tosa::GreaterEqualOp>(
loc,
RankedTensorType::get(outType.getShape(), rewriter.getIntegerType(1)), x,
zero);
auto negateErf = rewriter.create<tosa::NegateOp>(loc, outType, erf);

return rewriter.create<tosa::SelectOp>(loc, outType, cond, erf, negateErf);
}
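
For reference, the deleted helper implements the four-term rational approximation of erf cited in its comment. Below is a minimal standalone Python sketch of the same math (an illustration only, not code from this PR), checked against `math.erf`:

```python
import math

# Four-term rational approximation of erf (Abramowitz & Stegun style,
# the formula cited in the comment above; max error about 5e-4).
A1, A2, A3, A4 = 0.278393, 0.230389, 0.000972, 0.078108

def approx_erf(x: float) -> float:
    ax = abs(x)
    denom = (1.0 + A1 * ax + A2 * ax**2 + A3 * ax**3 + A4 * ax**4) ** 4
    y = 1.0 - 1.0 / denom
    # erf is odd, so mirror the result for negative inputs, just as the
    # deleted helper did with tosa.select over tosa.negate.
    return y if x >= 0.0 else -y

# The error stays near the documented ~5e-4 bound on a sample grid.
max_err = max(abs(approx_erf(i / 100) - math.erf(i / 100))
              for i in range(-500, 501))
print(f"max |approx_erf - erf| = {max_err:.1e}")  # ~5e-4
```

Emitting `tosa.erf` directly lets the backend pick its own (typically more accurate) erf implementation instead of baking this polynomial into the lowered IR.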

static std::optional<Value>
buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
Type dtype) {
@@ -3467,7 +3396,7 @@ buildUnitNormalCdf(ConversionPatternRewriter &rewriter, Operation *op, Value x,
Value erfArg =
tosa::createMulOpAndCast(rewriter, op, outType, xMinusMean, rsqrt2,
/*shift=*/0);
-  Value erf = approximateErfOp(rewriter, op, erfArg, dtype).value();
+  Value erf = rewriter.create<tosa::ErfOp>(loc, outType, erfArg);
Value erfPlus1 = rewriter.create<tosa::AddOp>(loc, outType, one, erf);

Value normalCdf = tosa::createMulOpAndCast(rewriter, op, outType, oneHalf,
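The one-line change above works because `tosa.erf` computes exactly the quantity the unit-normal CDF needs: Phi(x) = 0.5 * (1 + erf(x / sqrt(2))), and GELU with the "none" approximation is x * Phi(x). A small PyTorch sketch of that identity (an illustration only, not code from this PR):

```python
import torch

def gelu_via_erf(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the lowering: scale by 1/sqrt(2) (the 0.707106769 constant
    # in the FileCheck test below), apply erf, add 1, halve, multiply by x.
    normal_cdf = 0.5 * (1.0 + torch.erf(x * 0.7071067811865476))
    return x * normal_cdf

x = torch.randn(50, 30)
# Matches PyTorch's exact (approximate="none") GELU.
assert torch.allclose(gelu_via_erf(x), torch.nn.functional.gelu(x), atol=1e-6)
```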
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -2195,6 +2195,7 @@
"ElementwiseGeIntScalarModule_basic",
"ElementwiseGeMixedIntScalarModule_basic",
"ElementwiseGeluModule_basic",
"ElementwiseGeluTosaModule_basic",
"ElementwiseGtFloatScalarModule_basic",
"ElementwiseGtFloatTensorModule_basic",
"ElementwiseGtIntScalarModule_basic",
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py
@@ -1374,6 +1374,30 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils):
# ==============================================================================


class ElementwiseGeluTosaModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args(
        [
            None,
            ([-1, -1], torch.float32, True),
        ]
    )
    def forward(self, x):
        x = torch.ops.aten.gelu(x)
        return x


@register_test_case(module_factory=lambda: ElementwiseGeluTosaModule())
def ElementwiseGeluTosaModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(50, 30, low=-2.7, high=2.7))


# ==============================================================================


class ElementwiseGeluApproximateTanhModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
36 changes: 36 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -3071,6 +3071,42 @@ func.func @torch.aten.upsample_nearest2d.vec$basic(%arg0: !torch.vtensor<[1,1,4,

// -----

// CHECK-LABEL: func.func @torch.aten.gelu$none(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1500,1536],f32>) -> !torch.vtensor<[1,1500,1536],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1500,1536],f32> -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_2:.*]] = torch.constant.str "none"
// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.707106769> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_7]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_9]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_11]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_13:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_13]] : (tensor<f32>, !tosa.shape<3>) -> tensor<1x1x1xf32>
// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_1]], %[[VAL_8]] : (tensor<1x1500x1536xf32>, tensor<1x1x1xf32>) -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_15]], %[[VAL_14]], %[[VAL_16]] : (tensor<1x1500x1536xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_18:.*]] = tosa.erf %[[VAL_17]] : (tensor<1x1500x1536xf32>) -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_19:.*]] = tosa.add %[[VAL_10]], %[[VAL_18]] : (tensor<1x1x1xf32>, tensor<1x1500x1536xf32>) -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_12]], %[[VAL_19]], %[[VAL_20]] : (tensor<1x1x1xf32>, tensor<1x1500x1536xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_22:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_23:.*]] = tosa.mul %[[VAL_1]], %[[VAL_21]], %[[VAL_22]] : (tensor<1x1500x1536xf32>, tensor<1x1500x1536xf32>, tensor<1xi8>) -> tensor<1x1500x1536xf32>
// CHECK: %[[VAL_24:.*]] = torch_c.from_builtin_tensor %[[VAL_23]] : tensor<1x1500x1536xf32> -> !torch.vtensor<[1,1500,1536],f32>
// CHECK: return %[[VAL_24]] : !torch.vtensor<[1,1500,1536],f32>
// CHECK: }
func.func @torch.aten.gelu$none(%arg0: !torch.vtensor<[1,1500,1536],f32>) -> !torch.vtensor<[1,1500,1536],f32> {
%str = torch.constant.str "none"
%0 = torch.aten.gelu %arg0, %str : !torch.vtensor<[1,1500,1536],f32>, !torch.str -> !torch.vtensor<[1,1500,1536],f32>
return %0 : !torch.vtensor<[1,1500,1536],f32>
}
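The CHECK lines above are verified by the file's existing RUN line; in torch-mlir's basic.mlir that is typically `torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s` (assumed unchanged by this PR).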

// -----

// CHECK-LABEL: func.func @torch.aten.gelu$tanh(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[5,3],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,3],f32> -> tensor<5x3xf32>