Add support for onnx fusions #2412

Merged · 11 commits · Jul 1, 2025
31 changes: 31 additions & 0 deletions onnxscript/_framework_apis/torch_2_9.py
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.9."""

from __future__ import annotations

__all__ = [
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
]

import onnx_ir as ir

from onnxscript import optimizer
from onnxscript._framework_apis.torch_2_6 import (
check_model,
convert_version,
get_torchlib_ops,
save_model_with_external_data,
)
from onnxscript.rewriter import onnx_fusions


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""
optimizer.optimize_ir(model)
onnx_fusions.fuse(model)
return model
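
For context, a minimal usage sketch of this API; the model path is hypothetical, and the top-level `load`/`save` helpers are assumed from the `onnx_ir` package:

import onnx_ir as ir

from onnxscript._framework_apis.torch_2_9 import optimize

model = ir.load("exported_model.onnx")  # hypothetical path
model = optimize(model)  # generic optimizer passes, then ONNX-op fusions
ir.save(model, "exported_model_optimized.onnx")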

9 changes: 9 additions & 0 deletions onnxscript/rewriter/onnx_fusions/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse

__all__ = [
"fuse",
]
34 changes: 34 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -0,0 +1,34 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx_ir as ir

from onnxscript.rewriter.onnx_fusions import _rms_normalization


def _get_onnx_opset_version(model: ir.Model) -> int | None:
"""Get the ONNX opset version imported by the model."""
model_version1 = model.opset_imports.get("")
model_version2 = model.opset_imports.get("ai.onnx")
if model_version1 is not None and model_version2 is not None:
if model_version1 != model_version2:
            raise ValueError(
                f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
            )
return model_version1 or model_version2
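
For illustration, a small sketch of this lookup; the `ir.Graph`/`ir.Model` constructor arguments shown here are assumptions for the example:

# Sketch: a model importing the default ONNX domain ("") at opset 23.
graph = ir.Graph(inputs=[], outputs=[], nodes=[], opset_imports={"": 23})
model = ir.Model(graph, ir_version=10)
assert _get_onnx_opset_version(model) == 23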


def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
"""Apply fusions targeting ONNX opset 23."""
counts: dict[str, int] = {}
counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
return counts


def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
"""Apply fusions targeting ONNX ops."""
model_opset_version = _get_onnx_opset_version(model)
if model_opset_version == 23:
return _opset_23_fuse(model, debug=debug)
return {}
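
A hedged usage note: for an opset-23 model the returned mapping reports how often each fusion fired, and for other opsets it is empty.

counts = fuse(model)  # e.g. {"RMSNormalization": 1} if the pattern matched once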

40 changes: 40 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnx_ir as ir

import onnxscript
import onnxscript.rewriter.onnx_fusions as onnx_fusions


class OnnxFusionsTest(unittest.TestCase):
def test_rms_normalization_fusion(self):
opset23 = onnxscript.values.Opset("", 23)

@onnxscript.script()
def rms_norm_script(embedding, layernorm_weight):
two = opset23.Constant(value_float=2.0)
pow_1 = opset23.Pow(embedding, two)
mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
epsilon = opset23.Constant(value_float=1e-05)
add_1 = opset23.Add(mean, epsilon)
val_244 = opset23.Sqrt(add_1)
rsqrt = opset23.Reciprocal(val_244)
mul_3 = opset23.Mul(embedding, rsqrt)
mul_4 = opset23.Mul(layernorm_weight, mul_3)
            return mul_4

rms_norm_model_proto = rms_norm_script.to_model_proto(
input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]],
output_types=[onnxscript.FLOAT[128]],
)
model = ir.serde.deserialize_model(rms_norm_model_proto)
onnx_fusions.fuse(model, debug=True)
self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")


if __name__ == "__main__":
unittest.main()

84 changes: 84 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_rms_normalization.py
@@ -0,0 +1,84 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

"""
RMS Normalization: ONNX Opset 23 op
See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization


Key points for the fusion optimization:
* Input and scale are allowed to be of different types.
* The normalization of the input can be done in a different precision than the input type,
indicated by stash_type.
* Input (x) must be: float or double or float16 or bfloat16
* Scale must be: float or double or float16 or bfloat16
"""

float_types = frozenset(
[
ir.DataType.FLOAT,
ir.DataType.FLOAT16,
ir.DataType.BFLOAT16,
ir.DataType.DOUBLE,
]
)
fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])


class RmsNormFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
x_square = op.Pow(x, 2.0)
mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
mean_square_plus_epsilon = op.Add(mean_square, epsilon)
rms = op.Sqrt(mean_square_plus_epsilon)
reciprocal_rms = op.Reciprocal(rms)
normalized = op.Mul(x, reciprocal_rms)
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
return op.Mul(scale, normalized)

def check(
self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
) -> pattern.MatchResult: # type: ignore[name-defined]
"""Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
check_result = pattern.MatchResult()
# epsilon must be a scalar
epsilon_value = _ir_utils.get_singleton_value(epsilon)
if not isinstance(epsilon_value, float): # TODO: support other types
return check_result.fail("Epsilon is not a float value.", epsilon)

if x.dtype not in float_types:
return check_result.fail("Input is not a supported float type.", x)

if scale.dtype not in float_types:
return check_result.fail("Scale is not a supported float type.", scale)

self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
if self._stash_dtype not in fp_float_types:
# TODO: ONNX documentation does not specify restrictions on stash_type, though
# ORT's SimplifiedLayerNormalization requires it to be float or double.
return check_result.fail("Normalization precision is not a float or double type.")

# target_dtype is guaranteed to be the same as scale type in a well-typed input
# for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
# TODO (rama): Consider adding checks to protect against incorrectly typed models:
return check_result

def rewrite(self, op, x, scale, epsilon, **_):
        # RMSNormalization is a standard ONNX op as of opset 23, so it is
        # emitted in the default onnx domain (no custom domain needed).
return op.RMSNormalization(
x,
scale,
axis=-1,
epsilon=_ir_utils.get_singleton_value(epsilon),
stash_type=self._stash_dtype,
)


_rule = RmsNormFusion.rule()
rms_normalization_rules = [_rule]
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)


fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
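
Based on its use in `_onnx_fusions.py`, a usage sketch for the resulting callable:

# Apply only the RMSNormalization fusion; the return value is the number of
# times the rule fired.
count = fuse_rms_normalization(model, debug=False)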
58 changes: 23 additions & 35 deletions onnxscript/rewriter/ort_fusions/rms_normalization.py
@@ -18,76 +18,64 @@
 * Normalization precision must be float or double
 """

-float_types = [
-    ir.DataType.FLOAT,
-    ir.DataType.FLOAT16,
-    ir.DataType.BFLOAT16,
-    ir.DataType.DOUBLE,
-]
-fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE]
+float_types = frozenset(
+    [
+        ir.DataType.FLOAT,
+        ir.DataType.FLOAT16,
+        ir.DataType.BFLOAT16,
+        ir.DataType.DOUBLE,
+    ]
+)
+fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])


 class RmsNormFusion(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
-        """
-        Args:
-            name: Name of the rule.
-            cast_input: Whether to cast input to do the normalization in a different precision.
-            cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
-        """
-        super().__init__(name=name)
-        self._cast_input = cast_input
-        self._cast_normalized = cast_normalized
-
     def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        if self._cast_input:
-            x = op.Cast(x, to=compute_dtype)
+        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
         x_square = op.Pow(x, 2.0)
         mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
         mean_square_plus_epsilon = op.Add(mean_square, epsilon)
         rms = op.Sqrt(mean_square_plus_epsilon)
         reciprocal_rms = op.Reciprocal(rms)
         normalized = op.Mul(x, reciprocal_rms)
-        if self._cast_normalized:
-            normalized = op.Cast(normalized, to=target_dtype)
+        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
         return op.Mul(scale, normalized)

-    def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult:  # type: ignore[name-defined]
+    def check(
+        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
         """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
         check_result = pattern.MatchResult()
         # epsilon must be a scalar
         epsilon_value = _ir_utils.get_singleton_value(epsilon)
         if not isinstance(epsilon_value, float):  # TODO: support other types
             return check_result.fail("Epsilon is not a float value.", epsilon)
         # input and output must be same dtype
         if x.dtype not in float_types:
             return check_result.fail("Input is not a float type.", x)
         if scale.dtype not in float_types:
             return check_result.fail("Scale is not a float type.", scale)
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
-        if stash_dtype not in fp_float_types:
+        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
+        if self._stash_dtype not in fp_float_types:
             return check_result.fail("Normalization precision is not a float or double type.")
         # target_dtype is guaranteed to be the same as scale type in a well-typed input
         # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
         # TODO (rama): Consider adding checks to protect against incorrectly typed models:
         return check_result

-    def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
+    def rewrite(self, op, x, scale, epsilon, **_):
         # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
         # No need to use com.microsoft domain here; but this is a custom op in ORT.
         return op.SimplifiedLayerNormalization(
             x,
             scale,
             axis=-1,
             epsilon=_ir_utils.get_singleton_value(epsilon),
-            stash_type=stash_dtype,
+            stash_type=self._stash_dtype,
         )


-_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True)
-_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True)
-_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False)
-_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False)
-
-rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3]
+_rule = RmsNormFusion.rule()
+rms_normalization_rules = [_rule]
 rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)

