From 77ca9efc938924c1e13bb14788fd174a1500a890 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 25 Mar 2025 22:53:54 +0000 Subject: [PATCH 1/2] respect ops order in torchscript graph --- onnxscript/function_libs/torch_lib/ops/nn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index cfab834d6e..4c32f975d5 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -487,8 +487,8 @@ def _aten_gelu_approximate_none(self: TReal) -> TReal: inner = op.Div(self, 1.4142135623730951) erf = op.Erf(inner) inner = op.Add(erf, 1) - inner = op.Mul(self, inner) - result = op.Mul(0.5, inner) + inner = op.Mul(0.5, inner) + result = op.Mul(self, inner) return result @@ -505,8 +505,8 @@ def _aten_gelu_approximate_tanh(self: TReal) -> TReal: inner = op.Mul(op.Sqrt(two_over_pi), inner) inner = op.Tanh(inner) inner = op.Add(inner, 1) - inner = op.Mul(self, inner) - result = op.Mul(0.5, inner) + inner = op.Mul(0.5, inner) + result = op.Mul(self, inner) return result From f36b6f6fe1cf55cd9d095471f525ee52aa18d91e Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 25 Mar 2025 22:58:45 +0000 Subject: [PATCH 2/2] fix pattern --- onnxscript/rewriter/ort_fusions/gelu.py | 4 ++-- onnxscript/rewriter/ort_fusions/gelu_test.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/gelu.py b/onnxscript/rewriter/ort_fusions/gelu.py index f1c47e91f6..20bfdcb7de 100644 --- a/onnxscript/rewriter/ort_fusions/gelu.py +++ b/onnxscript/rewriter/ort_fusions/gelu.py @@ -20,8 +20,8 @@ def pattern(self, op, x): t4 = op.Mul(_sqrt_two_over_pi, t3) t5 = op.Tanh(t4) t6 = op.Add(t5, 1) - t7 = op.Mul(x, t6) - result = op.Mul(0.5, t7) + t7 = op.Mul(0.5, t6) + result = op.Mul(x, t7) return result def rewrite(self, op, x): diff --git a/onnxscript/rewriter/ort_fusions/gelu_test.py b/onnxscript/rewriter/ort_fusions/gelu_test.py index 193bf7e3c2..e509ce1454 100644 --- a/onnxscript/rewriter/ort_fusions/gelu_test.py +++ b/onnxscript/rewriter/ort_fusions/gelu_test.py @@ -28,8 +28,8 @@ def gelu_model(x): t4 = op.Mul(_sqrt_two_over_pi, t3) t5 = op.Tanh(t4) t6 = op.Add(t5, 1) - t7 = op.Mul(x, t6) - result = op.Mul(0.5, t7) + t7 = op.Mul(0.5, t6) + result = op.Mul(x, t7) return result model_proto = gelu_model.to_model_proto(