Add support for onnx fusions #2412

Merged
merged 11 commits on Jul 1, 2025
37 changes: 37 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_core.py
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx_ir as ir
from onnxscript.rewriter.onnx_fusions import rms_normalization


def _get_onnx_opset_version(model: ir.Model) -> int | None:
"""Get the ONNX opset version imported by the model."""
model_version1 = model.opset_imports.get("")
model_version2 = model.opset_imports.get("ai.onnx")

Check warning on line 12 in onnxscript/rewriter/onnx_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnx_fusions/_core.py#L11-L12

Added lines #L11 - L12 were not covered by tests
if model_version1 is not None and model_version2 is not None:
if model_version1 != model_version2:
raise ValueError(

Check warning on line 15 in onnxscript/rewriter/onnx_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnx_fusions/_core.py#L15

Added line #L15 was not covered by tests
f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
)
return model_version1 or model_version2

Check warning on line 18 in onnxscript/rewriter/onnx_fusions/_core.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/onnx_fusions/_core.py#L18

Added line #L18 was not covered by tests


def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
    """
    Apply fusions targeting ONNX opset 23.
    """
    counts: dict[str, int] = {}
    counts["RMSNormalization"] = rms_normalization.fuse(model, debug=debug)
    return counts


def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
    """
    Apply fusions targeting ONNX ops.
    """
    model_opset_version = _get_onnx_opset_version(model)
    if model_opset_version == 23:
        return _opset_23_fuse(model, debug=debug)
    return {}
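
A minimal usage sketch for the new entry point (editorial addition, not part of the diff; the model path is hypothetical, and ir.load is onnx_ir's loader):

import onnx_ir as ir
from onnxscript.rewriter.onnx_fusions import _core

model = ir.load("model_opset23.onnx")  # hypothetical opset-23 model
counts = _core.fuse(model)
print(counts)  # e.g. {"RMSNormalization": 2} if two patterns were fused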
80 changes: 80 additions & 0 deletions onnxscript/rewriter/onnx_fusions/rms_normalization.py
@@ -0,0 +1,80 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

"""
RMS Normalization: ONNX Opset 23 op
See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization


Key points for the fusion optimization:
* Input and scale are allowed to be of different types.
* The normalization of the input can be done in a different precision than the input type,
indicated by stash_type.
* Input (x) must be: float or double or float16 or bfloat16
* Scale must be: float or double or float16 or bfloat16
"""

float_types = [
    ir.DataType.FLOAT,
    ir.DataType.FLOAT16,
    ir.DataType.BFLOAT16,
    ir.DataType.DOUBLE,
]
fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE]


class RmsNormFusion(pattern.RewriteRuleClassBase):
    def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
        x_square = op.Pow(x, 2.0)
        mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
        mean_square_plus_epsilon = op.Add(mean_square, epsilon)
        rms = op.Sqrt(mean_square_plus_epsilon)
        reciprocal_rms = op.Reciprocal(rms)
        normalized = op.Mul(x, reciprocal_rms)
        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
        return op.Mul(scale, normalized)
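    # Editorial note (not part of the diff): pattern.OrValue makes this one rule match
    # both the cast and no-cast variants of the graph, which is what lets the four
    # cast_input/cast_normalized rule instances in ort_fusions (see the diff below)
    # collapse into a single rule.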

    def check(self, op, x, scale, epsilon, compute_dtype, target_dtype, **_) -> pattern.MatchResult:  # type: ignore[name-defined]
        """Check if the pattern matches conditions for use of the RMSNormalization op."""
        check_result = pattern.MatchResult()
        # epsilon must be a scalar
        epsilon_value = _ir_utils.get_singleton_value(epsilon)
        if not isinstance(epsilon_value, float):  # TODO: support other types
            return check_result.fail("Epsilon is not a float value.", epsilon)
        if x.dtype not in float_types:
            return check_result.fail("Input is not a supported float type.", x)
        if scale.dtype not in float_types:
            return check_result.fail("Scale is not a supported float type.", scale)
        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
        if self._stash_dtype not in fp_float_types:
            # TODO: ONNX documentation does not specify restrictions on stash_type, though
            # ORT's SimplifiedLayerNormalization requires it to be float or double.
            return check_result.fail("Normalization precision is not a float or double type.")
        # target_dtype is guaranteed to be the same as scale type in a well-typed input
        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
        return check_result

    def rewrite(self, op, x, scale, epsilon, **_):
        # Unlike ORT's SimplifiedLayerNormalization (a custom op that was placed in the
        # onnx domain by mistake), RMSNormalization is a standard op in ONNX opset 23.
        return op.RMSNormalization(
            x,
            scale,
            axis=-1,
            epsilon=_ir_utils.get_singleton_value(epsilon),
            stash_type=self._stash_dtype,
        )


_rule = RmsNormFusion.rule()
rms_normalization_rules = [_rule]
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)


fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
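
For reference, a NumPy sketch of the computation this rule matches and replaces (editorial addition, not part of the diff; it mirrors the axis=-1 reduction in the pattern above):

import numpy as np

def rms_norm_reference(x, scale, epsilon=1e-5, stash_dtype=np.float32):
    xc = x.astype(stash_dtype)  # optional Cast to compute_dtype (stash precision)
    mean_square = np.mean(xc * xc, axis=-1, keepdims=True)  # Pow + ReduceMean
    rms = np.sqrt(mean_square + epsilon)  # Add + Sqrt
    normalized = xc * (1.0 / rms)  # Reciprocal + Mul
    return scale * normalized.astype(x.dtype)  # optional Cast back, then Mul by scale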
42 changes: 14 additions & 28 deletions onnxscript/rewriter/ort_fusions/rms_normalization.py
@@ -28,66 +28,52 @@


 class RmsNormFusion(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
-        """
-        Args:
-            name: Name of the rule.
-            cast_input: Whether to cast input to do the normalization in a different precision.
-            cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
-        """
-        super().__init__(name=name)
-        self._cast_input = cast_input
-        self._cast_normalized = cast_normalized
-
     def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        if self._cast_input:
-            x = op.Cast(x, to=compute_dtype)
+        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
         x_square = op.Pow(x, 2.0)
         mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
         mean_square_plus_epsilon = op.Add(mean_square, epsilon)
         rms = op.Sqrt(mean_square_plus_epsilon)
         reciprocal_rms = op.Reciprocal(rms)
         normalized = op.Mul(x, reciprocal_rms)
-        if self._cast_normalized:
-            normalized = op.Cast(normalized, to=target_dtype)
+        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
         return op.Mul(scale, normalized)

-    def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult:  # type: ignore[name-defined]
+    def check(
+        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
         """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
         check_result = pattern.MatchResult()
         # epsilon must be a scalar
         epsilon_value = _ir_utils.get_singleton_value(epsilon)
         if not isinstance(epsilon_value, float):  # TODO: support other types
             return check_result.fail("Epsilon is not a float value.", epsilon)
         # input and output must be same dtype
         if x.dtype not in float_types:
             return check_result.fail("Input is not a float type.", x)
         if scale.dtype not in float_types:
             return check_result.fail("Scale is not a float type.", scale)
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
-        if stash_dtype not in fp_float_types:
+        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
+        if self._stash_dtype not in fp_float_types:
             return check_result.fail("Normalization precision is not a float or double type.")
         # target_dtype is guaranteed to be the same as scale type in a well-typed input
         # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
         # TODO (rama): Consider adding checks to protect against incorrectly typed models:
         return check_result

-    def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
+    def rewrite(self, op, x, scale, epsilon, **_):
         # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
         # No need to use com.microsoft domain here; but this is a custom op in ORT.
         return op.SimplifiedLayerNormalization(
             x,
             scale,
             axis=-1,
             epsilon=_ir_utils.get_singleton_value(epsilon),
-            stash_type=stash_dtype,
+            stash_type=self._stash_dtype,
         )


-_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True)
-_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True)
-_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False)
-_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False)
-
-rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3]
+_rule = RmsNormFusion.rule()
+rms_normalization_rules = [_rule]
 rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
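
Net effect of this hunk: pattern.OrValue folds the optional Cast branches into the pattern itself, so the four cast_input/cast_normalized rule variants collapse into one. A hedged usage sketch of the consolidated rule set (assumes model is an already-loaded ir.Model; RewriteRuleSet.apply_to_model returns the number of rewrites applied):

from onnxscript.rewriter.ort_fusions.rms_normalization import rms_normalization_ruleset

rewrite_count = rms_normalization_ruleset.apply_to_model(model)
print(f"RmsNormFusion applied {rewrite_count} time(s)")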

