Add support for onnx fusions #2412

Merged (11 commits) on Jul 1, 2025
28 changes: 28 additions & 0 deletions onnxscript/_framework_apis/torch_2_9.py
@@ -0,0 +1,28 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.9."""

from __future__ import annotations

__all__ = [
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
]

from onnxscript import ir, optimizer
from onnxscript._framework_apis.torch_2_6 import (
    check_model,
    convert_version,
    get_torchlib_ops,
    save_model_with_external_data,
)
from onnxscript.rewriter import onnx_fusions

def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""
optimizer.optimize_ir(model)
onnx_fusions.fuse(model)
return model
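
A minimal usage sketch of the new entry point (our illustration, not part of the diff; it assumes a local "model.onnx" and uses onnx_ir's load/save helpers — in practice these stable APIs are invoked by PyTorch's ONNX exporter):

import onnx_ir as ir

from onnxscript._framework_apis.torch_2_9 import optimize

model = ir.load("model.onnx")  # deserialize the proto into the IR
model = optimize(model)  # generic optimizations, then opset-23 fusions
ir.save(model, "model_optimized.onnx")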
9 changes: 9 additions & 0 deletions onnxscript/rewriter/onnx_fusions/__init__.py
@@ -0,0 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse

__all__ = [
"fuse",
]
32 changes: 32 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -0,0 +1,32 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnx_ir as ir

from onnxscript.rewriter.onnx_fusions import _rms_normalization


def _get_onnx_opset_version(model: ir.Model) -> int | None:
"""Get the ONNX opset version imported by the model."""
model_version1 = model.opset_imports.get("")
model_version2 = model.opset_imports.get("ai.onnx")
if model_version1 is not None and model_version2 is not None:
if model_version1 != model_version2:
raise ValueError(
f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
)
return model_version1 or model_version2


def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
"""Apply fusions targetting ONNX opset 23."""

    counts: dict[str, int] = {}
    counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
    return counts


def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
"""Apply fusions targetting ONNX ops."""

    model_opset_version = _get_onnx_opset_version(model)
    if model_opset_version == 23:
        return _opset_23_fuse(model, debug=debug)
    return {}
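
For direct use outside the exporter, a short sketch (assuming `model` is an `ir.Model` that imports opset 23, as in the test below):

from onnxscript.rewriter import onnx_fusions

counts = onnx_fusions.fuse(model)  # returns {} for models below opset 23
print(counts.get("RMSNormalization", 0))  # number of subgraphs fused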
40 changes: 40 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py
@@ -0,0 +1,40 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import onnx_ir as ir

import onnxscript
import onnxscript.rewriter.onnx_fusions as onnx_fusions


class OnnxFusionsTest(unittest.TestCase):
def test_rms_normalization_fusion(self):
opset23 = onnxscript.values.Opset("", 23)

@onnxscript.script()
def rms_norm_script(embedding, layernorm_weight):
two = opset23.Constant(value_float=2.0)
pow_1 = opset23.Pow(embedding, two)
mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
epsilon = opset23.Constant(value_float=1e-05)
add_1 = opset23.Add(mean, epsilon)
val_244 = opset23.Sqrt(add_1)
rsqrt = opset23.Reciprocal(val_244)
mul_3 = opset23.Mul(embedding, rsqrt)
mul_4 = opset23.Mul(layernorm_weight, mul_3)
return mul_4

rms_norm_model_proto = rms_norm_script.to_model_proto(
input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]],
output_types=[onnxscript.FLOAT[128]],
)
model = ir.serde.deserialize_model(rms_norm_model_proto)
onnx_fusions.fuse(model, debug=True)
self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")


if __name__ == "__main__":
unittest.main()
82 changes: 82 additions & 0 deletions onnxscript/rewriter/onnx_fusions/_rms_normalization.py
@@ -0,0 +1,82 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import onnxscript.ir as ir
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern

"""
RMS Normalization: ONNX Opset 23 op
See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization


Key points for the fusion optimization:
* Input and scale are allowed to be of different types.
* The normalization of the input can be done in a different precision than the input type,
indicated by stash_type.
* Input (x) must be: float or double or float16 or bfloat16
* Scale must be: float or double or float16 or bfloat16
"""

float_types = [
ir.DataType.FLOAT,
ir.DataType.FLOAT16,
ir.DataType.BFLOAT16,
ir.DataType.DOUBLE,
]
fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE]


class RmsNormFusion(pattern.RewriteRuleClassBase):
def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
x_square = op.Pow(x, 2.0)
mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
mean_square_plus_epsilon = op.Add(mean_square, epsilon)
rms = op.Sqrt(mean_square_plus_epsilon)
reciprocal_rms = op.Reciprocal(rms)
normalized = op.Mul(x, reciprocal_rms)
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
return op.Mul(scale, normalized)

def check(
self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
) -> pattern.MatchResult: # type: ignore[name-defined]
"""Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
check_result = pattern.MatchResult()
# epsilon must be a scalar
epsilon_value = _ir_utils.get_singleton_value(epsilon)
if not isinstance(epsilon_value, float): # TODO: support other types
return check_result.fail("Epsilon is not a float value.", epsilon)
if x.dtype not in float_types:
return check_result.fail("Input is not a supported float type.", x)
if scale.dtype not in float_types:
return check_result.fail("Scale is not a supported float type.", scale)
self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
if self._stash_dtype not in fp_float_types:
# TODO: ONNX documentation does not specify restrictions on stash_type, though
# ORT's SimplifiedLayerNormalization requires it to be float or double.
return check_result.fail("Normalization precision is not a float or double type.")
# target_dtype is guaranteed to be the same as scale type in a well-typed input
# for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
        # TODO (rama): Consider adding checks to protect against incorrectly typed models.
return check_result

def rewrite(self, op, x, scale, epsilon, **_):
# Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
# No need to use com.microsoft domain here; but this is a custom op in ORT.
return op.RMSNormalization(
x,
scale,
axis=-1,
epsilon=_ir_utils.get_singleton_value(epsilon),
stash_type=self._stash_dtype,
)


_rule = RmsNormFusion.rule()
rms_normalization_rules = [_rule]
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)


fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
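
For reference, the subgraph matched above computes standard RMS normalization; a NumPy sketch of the same arithmetic (our illustration, not part of the PR):

import numpy as np

def rms_norm(x: np.ndarray, scale: np.ndarray, epsilon: float = 1e-5) -> np.ndarray:
    # Pow -> ReduceMean -> Add -> Sqrt -> Reciprocal -> Mul -> Mul
    mean_square = np.mean(np.power(x, 2.0), axis=-1, keepdims=True)
    inv_rms = np.reciprocal(np.sqrt(mean_square + epsilon))
    return scale * (x * inv_rms)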
42 changes: 14 additions & 28 deletions onnxscript/rewriter/ort_fusions/rms_normalization.py
@@ -28,66 +28,52 @@


class RmsNormFusion(pattern.RewriteRuleClassBase):
def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
"""
Args:
name: Name of the rule.
cast_input: Whether to cast input to do the normalization in a different precision.
cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
"""
super().__init__(name=name)
self._cast_input = cast_input
self._cast_normalized = cast_normalized

def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
if self._cast_input:
x = op.Cast(x, to=compute_dtype)
x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
x_square = op.Pow(x, 2.0)
mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
mean_square_plus_epsilon = op.Add(mean_square, epsilon)
rms = op.Sqrt(mean_square_plus_epsilon)
reciprocal_rms = op.Reciprocal(rms)
normalized = op.Mul(x, reciprocal_rms)
if self._cast_normalized:
normalized = op.Cast(normalized, to=target_dtype)
normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
return op.Mul(scale, normalized)

def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult: # type: ignore[name-defined]
def check(
self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
) -> pattern.MatchResult: # type: ignore[name-defined]
"""Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
check_result = pattern.MatchResult()
# epsilon must be a scalar
epsilon_value = _ir_utils.get_singleton_value(epsilon)
if not isinstance(epsilon_value, float): # TODO: support other types
return check_result.fail("Epsilon is not a float value.", epsilon)
# input and output must be same dtype
if x.dtype not in float_types:
return check_result.fail("Input is not a float type.", x)
if scale.dtype not in float_types:
return check_result.fail("Scale is not a float type.", scale)
stash_dtype = compute_dtype.value if self._cast_input else x.dtype
if stash_dtype not in fp_float_types:
self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
if self._stash_dtype not in fp_float_types:
return check_result.fail("Normalization precision is not a float or double type.")
# target_dtype is guaranteed to be the same as scale type in a well-typed input
# for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
        # TODO (rama): Consider adding checks to protect against incorrectly typed models.
return check_result

def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
stash_dtype = compute_dtype.value if self._cast_input else x.dtype
def rewrite(self, op, x, scale, epsilon, **_):
# Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
# No need to use com.microsoft domain here; but this is a custom op in ORT.
return op.SimplifiedLayerNormalization(
x,
scale,
axis=-1,
epsilon=_ir_utils.get_singleton_value(epsilon),
stash_type=stash_dtype,
stash_type=self._stash_dtype,
)


_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True)
_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True)
_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False)
_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False)

rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3]
_rule = RmsNormFusion.rule()
rms_normalization_rules = [_rule]
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)

