From 97119bd3c90f5b872de2d81241cc9f49eaba85d3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 23 Jun 2025 11:49:16 -0700 Subject: [PATCH 1/7] Add support for onnx fusions Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/onnx_fusions/_core.py | 37 +++++++++ .../onnx_fusions/rms_normalization.py | 80 +++++++++++++++++++ .../rewriter/ort_fusions/rms_normalization.py | 42 ++++------ 3 files changed, 131 insertions(+), 28 deletions(-) create mode 100644 onnxscript/rewriter/onnx_fusions/_core.py create mode 100644 onnxscript/rewriter/onnx_fusions/rms_normalization.py diff --git a/onnxscript/rewriter/onnx_fusions/_core.py b/onnxscript/rewriter/onnx_fusions/_core.py new file mode 100644 index 0000000000..7be2550e1c --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/_core.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnx_ir as ir +from onnxscript.rewriter.onnx_fusions import rms_normalization + + +def _get_onnx_opset_version(model: ir.Model) -> int | None: + """Get the ONNX opset version imported by the model.""" + model_version1 = model.opset_imports.get("") + model_version2 = model.opset_imports.get("ai.onnx") + if model_version1 is not None and model_version2 is not None: + if model_version1 != model_version2: + raise ValueError( + f"Model imports multiple onnx opsets: {model_version1} and {model_version2}." + ) + return model_version1 or model_version2 + + +def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: + """ + Apply fusions targetting ONNX opset 23. + """ + counts: dict[str, int] = {} + counts["RMSNormalization"] = rms_normalization.fuse( + model, debug=debug + ) + + +def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: + """ + Apply fusions targetting ONNX ops. 
+ """ + model_opset_version = _get_onnx_opset_version(model) + if model_opset_version == 23: + return _opset_23_fuse(model, debug=debug) \ No newline at end of file diff --git a/onnxscript/rewriter/onnx_fusions/rms_normalization.py b/onnxscript/rewriter/onnx_fusions/rms_normalization.py new file mode 100644 index 0000000000..60721f9100 --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/rms_normalization.py @@ -0,0 +1,80 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import onnxscript.ir as ir +from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern + +""" +RMS Normalization: ONNX Opset 23 op +See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization + + +Key points for the fusion optimization: +* Input and scale are allowed to be of different types. +* The normalization of the input can be done in a different precision than the input type, +indicated by stash_type. +* Input (x) must be: float or double or float16 or bfloat16 +* Scale must be: float or double or float16 or bfloat16 +""" + +float_types = [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, +] +fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE] + + +class RmsNormFusion(pattern.RewriteRuleClassBase): + def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): + x = pattern.OrValue([op.Cast(x, to=compute_dtype), x]) + x_square = op.Pow(x, 2.0) + mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) + mean_square_plus_epsilon = op.Add(mean_square, epsilon) + rms = op.Sqrt(mean_square_plus_epsilon) + reciprocal_rms = op.Reciprocal(rms) + normalized = op.Mul(x, reciprocal_rms) + normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) + return op.Mul(scale, normalized) + + def check(self, op, x, scale, epsilon, compute_dtype, target_dtype, **_) -> pattern.MatchResult: # type: 
ignore[name-defined] + """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" + check_result = pattern.MatchResult() + # epsilon must be a scalar + epsilon_value = _ir_utils.get_singleton_value(epsilon) + if not isinstance(epsilon_value, float): # TODO: support other types + return check_result.fail("Epsilon is not a float value.", epsilon) + if x.dtype not in float_types: + return check_result.fail("Input is not a supported float type.", x) + if scale.dtype not in float_types: + return check_result.fail("Scale is not a supported float type.", scale) + self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype + if self._stash_dtype not in fp_float_types: + # TODO: ONNX documentation does not specify restrictions on stash_type, though + # ORT's SimplifiedLayerNormalization requires it to be float or double. + return check_result.fail("Normalization precision is not a float or double type.") + # target_dtype is guaranteed to be the same as scale type in a well-typed input + # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input. + # TODO (rama): Consider adding checks to protect against incorrectly typed models: + return check_result + + def rewrite(self, op, x, scale, epsilon, **_): + # Note: RMSNormalization is a standard ONNX op (default domain, opset 23), + # so no custom (com.microsoft) domain is needed here. 
+ return op.RMSNormalization( + x, + scale, + axis=-1, + epsilon=_ir_utils.get_singleton_value(epsilon), + stash_type=self._stash_dtype, + ) + + +_rule = RmsNormFusion.rule() +rms_normalization_rules = [_rule] +rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) + + +fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset) diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index 916ce1be12..82bd022d9b 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -28,49 +28,39 @@ class RmsNormFusion(pattern.RewriteRuleClassBase): - def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool): - """ - Args: - name: Name of the rule. - cast_input: Whether to cast input to do the normalization in a different precision. - cast_normalized: Whether to cast the normalized output to the target dtype (same as scale). 
- """ - super().__init__(name=name) - self._cast_input = cast_input - self._cast_normalized = cast_normalized - def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): - if self._cast_input: - x = op.Cast(x, to=compute_dtype) + x = pattern.OrValue([op.Cast(x, to=compute_dtype), x]) x_square = op.Pow(x, 2.0) mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0) mean_square_plus_epsilon = op.Add(mean_square, epsilon) rms = op.Sqrt(mean_square_plus_epsilon) reciprocal_rms = op.Reciprocal(rms) normalized = op.Mul(x, reciprocal_rms) - if self._cast_normalized: - normalized = op.Cast(normalized, to=target_dtype) + normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) return op.Mul(scale, normalized) - def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult: # type: ignore[name-defined] + def check( + self, op, x, scale, epsilon, compute_dtype, target_dtype, **_ + ) -> pattern.MatchResult: # type: ignore[name-defined] """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" check_result = pattern.MatchResult() # epsilon must be a scalar epsilon_value = _ir_utils.get_singleton_value(epsilon) if not isinstance(epsilon_value, float): # TODO: support other types return check_result.fail("Epsilon is not a float value.", epsilon) - # input and output must be same dtype if x.dtype not in float_types: return check_result.fail("Input is not a float type.", x) if scale.dtype not in float_types: return check_result.fail("Scale is not a float type.", scale) - stash_dtype = compute_dtype.value if self._cast_input else x.dtype - if stash_dtype not in fp_float_types: + self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype + if self._stash_dtype not in fp_float_types: return check_result.fail("Normalization precision is not a float or double type.") + # target_dtype is guaranteed to be the same as scale type in a well-typed 
input + # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input. + # TODO (rama): Consider adding checks to protect against incorrectly typed models: return check_result - def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): - stash_dtype = compute_dtype.value if self._cast_input else x.dtype + def rewrite(self, op, x, scale, epsilon, **_): # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake. # No need to use com.microsoft domain here; but this is a custom op in ORT. return op.SimplifiedLayerNormalization( @@ -78,16 +68,12 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype): scale, axis=-1, epsilon=_ir_utils.get_singleton_value(epsilon), - stash_type=stash_dtype, + stash_type=self._stash_dtype, ) -_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True) -_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True) -_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False) -_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False) - -rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3] +_rule = RmsNormFusion.rule() +rms_normalization_rules = [_rule] rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules) From 865e9c071bc4aa0e726eb17902e82be3564ffb4a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 23 Jun 2025 22:36:05 -0700 Subject: [PATCH 2/7] Add test case (partial) Signed-off-by: Ganesan Ramalingam --- .../{_core.py => _onnx_fusions.py} | 0 .../onnx_fusions/_onnx_fusions_test.py | 26 +++++++++++++++++++ 2 files changed, 26 insertions(+) rename onnxscript/rewriter/onnx_fusions/{_core.py => _onnx_fusions.py} (100%) create mode 100644 onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py diff --git a/onnxscript/rewriter/onnx_fusions/_core.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py similarity index 100% rename from 
onnxscript/rewriter/onnx_fusions/_core.py rename to onnxscript/rewriter/onnx_fusions/_onnx_fusions.py diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py new file mode 100644 index 0000000000..55ea119c2c --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py @@ -0,0 +1,26 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +import onnxscript + + +class OnnxFusionsTest(unittest.TestCase): + def test_rms_normalization_fusion(self): + opset23 = onnxscript.values.Opset("", 23) + @onnxscript.script() + def rms_norm_script(embedding): + two = opset23.Constant(value_float=2.0) + pow_1 = opset23.Pow(embedding, two) + mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0) + epsilon = opset23.Constant(value_float=1e-05) + add_1 = opset23.Add(mean, epsilon) + val_244 = opset23.Sqrt(add_1) + rsqrt = opset23.Reciprocal(val_244) + mul_3 = opset23.Mul(embedding, rsqrt) + rms_norm_model = rms_norm_script.to_model_proto() + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file From c1e346d20080f7f6a41ea27ab36e420438f38158 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 24 Jun 2025 15:48:30 -0700 Subject: [PATCH 3/7] Finish test case Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/onnx_fusions/__init__.py | 10 ++++++++++ .../rewriter/onnx_fusions/_onnx_fusions.py | 15 +++++--------- .../onnx_fusions/_onnx_fusions_test.py | 20 ++++++++++++++++--- .../onnx_fusions/rms_normalization.py | 4 +++- 4 files changed, 35 insertions(+), 14 deletions(-) create mode 100644 onnxscript/rewriter/onnx_fusions/__init__.py diff --git a/onnxscript/rewriter/onnx_fusions/__init__.py b/onnxscript/rewriter/onnx_fusions/__init__.py new file mode 100644 index 0000000000..250631a7c0 --- /dev/null +++ b/onnxscript/rewriter/onnx_fusions/__init__.py @@ -0,0 +1,10 @@ +# 
Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse + +__all__ = [ + "fuse", +] + diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index 7be2550e1c..1d864d20dd 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -3,6 +3,7 @@ from __future__ import annotations import onnx_ir as ir + from onnxscript.rewriter.onnx_fusions import rms_normalization @@ -19,19 +20,13 @@ def _get_onnx_opset_version(model: ir.Model) -> int | None: def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: - """ - Apply fusions targetting ONNX opset 23. - """ + """Apply fusions targetting ONNX opset 23.""" counts: dict[str, int] = {} - counts["RMSNormalization"] = rms_normalization.fuse( - model, debug=debug - ) + counts["RMSNormalization"] = rms_normalization.fuse_rms_normalization(model, debug=debug) def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: - """ - Apply fusions targetting ONNX ops. 
- """ + """Apply fusions targetting ONNX ops.""" model_opset_version = _get_onnx_opset_version(model) if model_opset_version == 23: - return _opset_23_fuse(model, debug=debug) \ No newline at end of file + return _opset_23_fuse(model, debug=debug) diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py index 55ea119c2c..dfd9ca4296 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py @@ -3,14 +3,19 @@ from __future__ import annotations import unittest + +import onnx_ir as ir + import onnxscript +import onnxscript.rewriter.onnx_fusions as onnx_fusions class OnnxFusionsTest(unittest.TestCase): def test_rms_normalization_fusion(self): opset23 = onnxscript.values.Opset("", 23) + @onnxscript.script() - def rms_norm_script(embedding): + def rms_norm_script(embedding, layernorm_weight): two = opset23.Constant(value_float=2.0) pow_1 = opset23.Pow(embedding, two) mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0) @@ -19,8 +24,17 @@ def rms_norm_script(embedding): val_244 = opset23.Sqrt(add_1) rsqrt = opset23.Reciprocal(val_244) mul_3 = opset23.Mul(embedding, rsqrt) - rms_norm_model = rms_norm_script.to_model_proto() + mul_4 = opset23.Mul(layernorm_weight, mul_3) + return mul_4 + + rms_norm_model_proto = rms_norm_script.to_model_proto( + input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]], + output_types=[onnxscript.FLOAT[128]], + ) + model = ir.serde.deserialize_model(rms_norm_model_proto) + onnx_fusions.fuse(model, debug=True) + self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization") if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/onnxscript/rewriter/onnx_fusions/rms_normalization.py b/onnxscript/rewriter/onnx_fusions/rms_normalization.py index 60721f9100..2c9c78c1d5 100644 --- a/onnxscript/rewriter/onnx_fusions/rms_normalization.py 
+++ b/onnxscript/rewriter/onnx_fusions/rms_normalization.py @@ -39,7 +39,9 @@ def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype): normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized]) return op.Mul(scale, normalized) - def check(self, op, x, scale, epsilon, compute_dtype, target_dtype, **_) -> pattern.MatchResult: # type: ignore[name-defined] + def check( + self, op, x, scale, epsilon, compute_dtype, target_dtype, **_ + ) -> pattern.MatchResult: # type: ignore[name-defined] """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op.""" check_result = pattern.MatchResult() # epsilon must be a scalar From fcb91aa036ff2dc3c679cde375362f4c0079d95d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 24 Jun 2025 15:56:43 -0700 Subject: [PATCH 4/7] Update torch framework api Signed-off-by: Ganesan Ramalingam --- onnxscript/_framework_apis/torch_2_9.py | 28 +++++++++++++++++++ onnxscript/rewriter/onnx_fusions/__init__.py | 1 - .../rewriter/onnx_fusions/_onnx_fusions.py | 4 +-- ...normalization.py => _rms_normalization.py} | 0 4 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 onnxscript/_framework_apis/torch_2_9.py rename onnxscript/rewriter/onnx_fusions/{rms_normalization.py => _rms_normalization.py} (100%) diff --git a/onnxscript/_framework_apis/torch_2_9.py b/onnxscript/_framework_apis/torch_2_9.py new file mode 100644 index 0000000000..e26e4bf22a --- /dev/null +++ b/onnxscript/_framework_apis/torch_2_9.py @@ -0,0 +1,28 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+"""Stable APIs for PyTorch 2.9.""" + +from __future__ import annotations + +__all__ = [ + "check_model", + "convert_version", + "get_torchlib_ops", + "optimize", + "save_model_with_external_data", +] + +from onnxscript._framework_apis.torch_2_6 import ( + check_model, + convert_version, + get_torchlib_ops, + save_model_with_external_data, +) +from onnxscript import optimizer +from onnxscript.rewriter import onnx_fusions + +def optimize(model: ir.Model) -> ir.Model: + """Optimize the model.""" + optimizer.optimize_ir(model) + onnx_fusions.fuse(model) + return model diff --git a/onnxscript/rewriter/onnx_fusions/__init__.py b/onnxscript/rewriter/onnx_fusions/__init__.py index 250631a7c0..d2e8d885f0 100644 --- a/onnxscript/rewriter/onnx_fusions/__init__.py +++ b/onnxscript/rewriter/onnx_fusions/__init__.py @@ -7,4 +7,3 @@ __all__ = [ "fuse", ] - diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index 1d864d20dd..4b06e965c4 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -4,7 +4,7 @@ import onnx_ir as ir -from onnxscript.rewriter.onnx_fusions import rms_normalization +from onnxscript.rewriter.onnx_fusions import _rms_normalization def _get_onnx_opset_version(model: ir.Model) -> int | None: @@ -22,7 +22,7 @@ def _get_onnx_opset_version(model: ir.Model) -> int | None: def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: """Apply fusions targetting ONNX opset 23.""" counts: dict[str, int] = {} - counts["RMSNormalization"] = rms_normalization.fuse_rms_normalization(model, debug=debug) + counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug) def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: diff --git a/onnxscript/rewriter/onnx_fusions/rms_normalization.py b/onnxscript/rewriter/onnx_fusions/_rms_normalization.py similarity index 100% rename from 
onnxscript/rewriter/onnx_fusions/rms_normalization.py rename to onnxscript/rewriter/onnx_fusions/_rms_normalization.py From 3d341f214d6a93d679109220dc0b5c73ca438828 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 24 Jun 2025 16:02:55 -0700 Subject: [PATCH 5/7] Fix lint error Signed-off-by: Ganesan Ramalingam --- onnxscript/_framework_apis/torch_2_9.py | 5 ++++- .../rewriter/onnx_fusions/_onnx_fusions.py | 1 + .../rewriter/onnx_fusions/_rms_normalization.py | 16 +++++++++------- .../rewriter/ort_fusions/rms_normalization.py | 16 +++++++++------- 4 files changed, 23 insertions(+), 15 deletions(-) diff --git a/onnxscript/_framework_apis/torch_2_9.py b/onnxscript/_framework_apis/torch_2_9.py index e26e4bf22a..6c9ed6355b 100644 --- a/onnxscript/_framework_apis/torch_2_9.py +++ b/onnxscript/_framework_apis/torch_2_9.py @@ -12,15 +12,18 @@ "save_model_with_external_data", ] +import onnx_ir as ir + +from onnxscript import optimizer from onnxscript._framework_apis.torch_2_6 import ( check_model, convert_version, get_torchlib_ops, save_model_with_external_data, ) -from onnxscript import optimizer from onnxscript.rewriter import onnx_fusions + def optimize(model: ir.Model) -> ir.Model: """Optimize the model.""" optimizer.optimize_ir(model) diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index 4b06e965c4..ffe29957c5 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -30,3 +30,4 @@ def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: model_opset_version = _get_onnx_opset_version(model) if model_opset_version == 23: return _opset_23_fuse(model, debug=debug) + return {} diff --git a/onnxscript/rewriter/onnx_fusions/_rms_normalization.py b/onnxscript/rewriter/onnx_fusions/_rms_normalization.py index 2c9c78c1d5..dc7d1bc971 100644 --- a/onnxscript/rewriter/onnx_fusions/_rms_normalization.py +++ 
b/onnxscript/rewriter/onnx_fusions/_rms_normalization.py @@ -18,13 +18,15 @@ * Scale must be: float or double or float16 or bfloat16 """ -float_types = [ - ir.DataType.FLOAT, - ir.DataType.FLOAT16, - ir.DataType.BFLOAT16, - ir.DataType.DOUBLE, -] -fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE] +float_types = frozenset( + [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, + ] +) +fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE]) class RmsNormFusion(pattern.RewriteRuleClassBase): diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py index 82bd022d9b..9aa6210401 100644 --- a/onnxscript/rewriter/ort_fusions/rms_normalization.py +++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py @@ -18,13 +18,15 @@ * Normalization precision must be float or double """ -float_types = [ - ir.DataType.FLOAT, - ir.DataType.FLOAT16, - ir.DataType.BFLOAT16, - ir.DataType.DOUBLE, -] -fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE] +float_types = frozenset( + [ + ir.DataType.FLOAT, + ir.DataType.FLOAT16, + ir.DataType.BFLOAT16, + ir.DataType.DOUBLE, + ] +) +fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE]) class RmsNormFusion(pattern.RewriteRuleClassBase): From 13288d71269df3cc216c17a5f86798b61cf8443c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 25 Jun 2025 17:05:25 -0700 Subject: [PATCH 6/7] Minor fixes Signed-off-by: Ganesan Ramalingam --- onnxscript/rewriter/onnx_fusions/_onnx_fusions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py index ffe29957c5..96446e6fb4 100644 --- a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py +++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py @@ -20,13 +20,14 @@ def _get_onnx_opset_version(model: ir.Model) -> int | None: def _opset_23_fuse(model: 
ir.Model, *, debug: bool = False) -> dict[str, int]: - """Apply fusions targetting ONNX opset 23.""" + """Apply fusions targeting ONNX opset 23.""" counts: dict[str, int] = {} counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug) + return counts def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]: - """Apply fusions targetting ONNX ops.""" + """Apply fusions targeting ONNX ops.""" model_opset_version = _get_onnx_opset_version(model) if model_opset_version == 23: return _opset_23_fuse(model, debug=debug) From 14c7b3017cf89a39937d0d965dbe568c8b1fce3b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 26 Jun 2025 17:17:51 -0700 Subject: [PATCH 7/7] Merge with main Signed-off-by: Ganesan Ramalingam --- onnxscript/_framework_apis/torch_2_8.py | 12 +++++++++- onnxscript/_framework_apis/torch_2_9.py | 31 ------------------------- 2 files changed, 11 insertions(+), 32 deletions(-) delete mode 100644 onnxscript/_framework_apis/torch_2_9.py diff --git a/onnxscript/_framework_apis/torch_2_8.py b/onnxscript/_framework_apis/torch_2_8.py index ee5e6089e5..bbd1ffc786 100644 --- a/onnxscript/_framework_apis/torch_2_8.py +++ b/onnxscript/_framework_apis/torch_2_8.py @@ -12,10 +12,20 @@ "save_model_with_external_data", ] +import onnx_ir as ir + +import onnxscript.optimizer +import onnxscript.rewriter.onnx_fusions from onnxscript._framework_apis.torch_2_6 import ( check_model, convert_version, get_torchlib_ops, - optimize, save_model_with_external_data, ) + + +def optimize(model: ir.Model) -> ir.Model: + """Optimize the model.""" + onnxscript.optimizer.optimize_ir(model) + onnxscript.rewriter.onnx_fusions.fuse(model) + return model diff --git a/onnxscript/_framework_apis/torch_2_9.py b/onnxscript/_framework_apis/torch_2_9.py deleted file mode 100644 index 6c9ed6355b..0000000000 --- a/onnxscript/_framework_apis/torch_2_9.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright (c) Microsoft Corporation. 
-# Licensed under the MIT License. -"""Stable APIs for PyTorch 2.9.""" - -from __future__ import annotations - -__all__ = [ - "check_model", - "convert_version", - "get_torchlib_ops", - "optimize", - "save_model_with_external_data", -] - -import onnx_ir as ir - -from onnxscript import optimizer -from onnxscript._framework_apis.torch_2_6 import ( - check_model, - convert_version, - get_torchlib_ops, - save_model_with_external_data, -) -from onnxscript.rewriter import onnx_fusions - - -def optimize(model: ir.Model) -> ir.Model: - """Optimize the model.""" - optimizer.optimize_ir(model) - onnx_fusions.fuse(model) - return model