Skip to content

Commit 6024d7c

Browse files
titaiwangmsbmehta001
authored andcommitted
[rewriter | torchlib] respect ops order in torchscript graph (microsoft#2134)
This helps us to match the optimization pattern in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/fusion_fastgelu.py ref: microsoft#2132 (comment)
1 parent fe27b31 commit 6024d7c

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -487,8 +487,8 @@ def _aten_gelu_approximate_none(self: TReal) -> TReal:
487487
inner = op.Div(self, 1.4142135623730951)
488488
erf = op.Erf(inner)
489489
inner = op.Add(erf, 1)
490-
inner = op.Mul(self, inner)
491-
result = op.Mul(0.5, inner)
490+
inner = op.Mul(0.5, inner)
491+
result = op.Mul(self, inner)
492492
return result
493493

494494

@@ -505,8 +505,8 @@ def _aten_gelu_approximate_tanh(self: TReal) -> TReal:
505505
inner = op.Mul(op.Sqrt(two_over_pi), inner)
506506
inner = op.Tanh(inner)
507507
inner = op.Add(inner, 1)
508-
inner = op.Mul(self, inner)
509-
result = op.Mul(0.5, inner)
508+
inner = op.Mul(0.5, inner)
509+
result = op.Mul(self, inner)
510510
return result
511511

512512

onnxscript/rewriter/ort_fusions/gelu.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ def pattern(self, op, x):
2020
t4 = op.Mul(_sqrt_two_over_pi, t3)
2121
t5 = op.Tanh(t4)
2222
t6 = op.Add(t5, 1)
23-
t7 = op.Mul(x, t6)
24-
result = op.Mul(0.5, t7)
23+
t7 = op.Mul(0.5, t6)
24+
result = op.Mul(x, t7)
2525
return result
2626

2727
def rewrite(self, op, x):

onnxscript/rewriter/ort_fusions/gelu_test.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def gelu_model(x):
2828
t4 = op.Mul(_sqrt_two_over_pi, t3)
2929
t5 = op.Tanh(t4)
3030
t6 = op.Add(t5, 1)
31-
t7 = op.Mul(x, t6)
32-
result = op.Mul(0.5, t7)
31+
t7 = op.Mul(0.5, t6)
32+
result = op.Mul(x, t7)
3333
return result
3434

3535
model_proto = gelu_model.to_model_proto(

0 commit comments

Comments
 (0)