
Commit f4534ee

Add support for onnx fusions (#2412)
* Add basic infrastructure support for fusions targeting ONNX opset 23, with RMSNormalization as one target op.
* Clean up the existing RMSNormalization fusion targeting ORT's contrib op (using pattern-disjunction to simplify the rules).

Signed-off-by: Ganesan Ramalingam <[email protected]>
1 parent 34dc350 commit f4534ee

File tree

6 files changed: +201 −36 lines changed


onnxscript/_framework_apis/torch_2_8.py

Lines changed: 11 additions & 1 deletion
@@ -12,10 +12,20 @@
     "save_model_with_external_data",
 ]
 
+import onnx_ir as ir
+
+import onnxscript.optimizer
+import onnxscript.rewriter.onnx_fusions
 from onnxscript._framework_apis.torch_2_6 import (
     check_model,
     convert_version,
     get_torchlib_ops,
-    optimize,
     save_model_with_external_data,
 )
+
+
+def optimize(model: ir.Model) -> ir.Model:
+    """Optimize the model."""
+    onnxscript.optimizer.optimize_ir(model)
+    onnxscript.rewriter.onnx_fusions.fuse(model)
+    return model
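
With this change, optimize() in torch_2_8 chains the IR-level optimizer with the new onnx fusions. A minimal usage sketch of that entry point follows; the file names and the use of onnx_ir's load/save helpers are illustrative assumptions, not part of the commit.

import onnx_ir as ir

from onnxscript._framework_apis.torch_2_8 import optimize

# Load an exported ONNX model into the IR (placeholder file name).
model = ir.load("exported_model.onnx")
# optimize() runs onnxscript.optimizer.optimize_ir() and then onnx_fusions.fuse().
model = optimize(model)
ir.save(model, "optimized_model.onnx")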
Lines changed: 9 additions & 0 deletions

@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse
+
+__all__ = [
+    "fuse",
+]
Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import onnx_ir as ir
+
+from onnxscript.rewriter.onnx_fusions import _rms_normalization
+
+
+def _get_onnx_opset_version(model: ir.Model) -> int | None:
+    """Get the ONNX opset version imported by the model."""
+    model_version1 = model.opset_imports.get("")
+    model_version2 = model.opset_imports.get("ai.onnx")
+    if model_version1 is not None and model_version2 is not None:
+        if model_version1 != model_version2:
+            raise ValueError(
+                f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
+            )
+    return model_version1 or model_version2
+
+
+def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
+    """Apply fusions targeting ONNX opset 23."""
+    counts: dict[str, int] = {}
+    counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
+    return counts
+
+
+def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
+    """Apply fusions targeting ONNX ops."""
+    model_opset_version = _get_onnx_opset_version(model)
+    if model_opset_version == 23:
+        return _opset_23_fuse(model, debug=debug)
+    return {}
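
The fuse() entry point dispatches on the ONNX opset imported by the model: models importing opset 23 get the opset-23 fusion rules, anything else is a no-op that returns an empty dict of counts. A small sketch of that behavior, with placeholder model files, is:

import onnx_ir as ir

from onnxscript.rewriter.onnx_fusions import fuse

model_23 = ir.load("model_opset23.onnx")   # placeholder: a model importing ONNX opset 23
print(fuse(model_23))                      # e.g. {"RMSNormalization": 1} if the pattern matched

model_old = ir.load("model_opset18.onnx")  # placeholder: a model importing an older opset
print(fuse(model_old))                     # {} -- no fusions are applied for other opsets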
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import onnx_ir as ir
+
+import onnxscript
+import onnxscript.rewriter.onnx_fusions as onnx_fusions
+
+
+class OnnxFusionsTest(unittest.TestCase):
+    def test_rms_normalization_fusion(self):
+        opset23 = onnxscript.values.Opset("", 23)
+
+        @onnxscript.script()
+        def rms_norm_script(embedding, layernorm_weight):
+            two = opset23.Constant(value_float=2.0)
+            pow_1 = opset23.Pow(embedding, two)
+            mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
+            epsilon = opset23.Constant(value_float=1e-05)
+            add_1 = opset23.Add(mean, epsilon)
+            val_244 = opset23.Sqrt(add_1)
+            rsqrt = opset23.Reciprocal(val_244)
+            mul_3 = opset23.Mul(embedding, rsqrt)
+            mul_4 = opset23.Mul(layernorm_weight, mul_3)
+            return mul_4
+
+        rms_norm_model_proto = rms_norm_script.to_model_proto(
+            input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]],
+            output_types=[onnxscript.FLOAT[128]],
+        )
+        model = ir.serde.deserialize_model(rms_norm_model_proto)
+        onnx_fusions.fuse(model, debug=True)
+        self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")
+
+
+if __name__ == "__main__":
+    unittest.main()
Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import onnxscript.ir as ir
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
+
+"""
+RMS Normalization: ONNX Opset 23 op
+See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization
+
+
+Key points for the fusion optimization:
+* Input and scale are allowed to be of different types.
+* The normalization of the input can be done in a different precision than the input type,
+  indicated by stash_type.
+* Input (x) must be: float or double or float16 or bfloat16
+* Scale must be: float or double or float16 or bfloat16
+"""
+
+float_types = frozenset(
+    [
+        ir.DataType.FLOAT,
+        ir.DataType.FLOAT16,
+        ir.DataType.BFLOAT16,
+        ir.DataType.DOUBLE,
+    ]
+)
+fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])
+
+
+class RmsNormFusion(pattern.RewriteRuleClassBase):
+    def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
+        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
+        x_square = op.Pow(x, 2.0)
+        mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
+        mean_square_plus_epsilon = op.Add(mean_square, epsilon)
+        rms = op.Sqrt(mean_square_plus_epsilon)
+        reciprocal_rms = op.Reciprocal(rms)
+        normalized = op.Mul(x, reciprocal_rms)
+        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
+        return op.Mul(scale, normalized)
+
+    def check(
+        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
+        """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
+        check_result = pattern.MatchResult()
+        # epsilon must be a scalar
+        epsilon_value = _ir_utils.get_singleton_value(epsilon)
+        if not isinstance(epsilon_value, float):  # TODO: support other types
+            return check_result.fail("Epsilon is not a float value.", epsilon)
+        if x.dtype not in float_types:
+            return check_result.fail("Input is not a supported float type.", x)
+        if scale.dtype not in float_types:
+            return check_result.fail("Scale is not a supported float type.", scale)
+        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
+        if self._stash_dtype not in fp_float_types:
+            # TODO: ONNX documentation does not specify restrictions on stash_type, though
+            # ORT's SimplifiedLayerNormalization requires it to be float or double.
+            return check_result.fail("Normalization precision is not a float or double type.")
+        # target_dtype is guaranteed to be the same as scale type in a well-typed input
+        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
+        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
+        return check_result
+
+    def rewrite(self, op, x, scale, epsilon, **_):
+        # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
+        # No need to use com.microsoft domain here; but this is a custom op in ORT.
+        return op.RMSNormalization(
+            x,
+            scale,
+            axis=-1,
+            epsilon=_ir_utils.get_singleton_value(epsilon),
+            stash_type=self._stash_dtype,
+        )
+
+
+_rule = RmsNormFusion.rule()
+rms_normalization_rules = [_rule]
+rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
+
+
+fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
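
As a reference for the "key points" in the docstring above, the subgraph this rule matches computes x / sqrt(mean(x**2) + epsilon) over the last axis, scaled by a weight, with optional casts into and out of the stash precision. A NumPy sketch of that computation (float32 is an assumed stash_type here) is:

import numpy as np

def rms_norm_reference(x: np.ndarray, scale: np.ndarray, epsilon: float = 1e-5) -> np.ndarray:
    # Optional Cast(x, to=compute_dtype): normalize in a (possibly higher) stash precision.
    x_stash = x.astype(np.float32)
    # Pow -> ReduceMean -> Add -> Sqrt -> Reciprocal -> Mul over the last axis.
    mean_square = np.mean(np.power(x_stash, 2.0), axis=-1, keepdims=True)
    normalized = x_stash * (1.0 / np.sqrt(mean_square + epsilon))
    # Optional Cast back to the target dtype, then Mul(scale, normalized).
    return scale * normalized.astype(scale.dtype)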

onnxscript/rewriter/ort_fusions/rms_normalization.py

Lines changed: 23 additions & 35 deletions
@@ -19,76 +19,64 @@
 * Normalization precision must be float or double
 """
 
-float_types = [
-    ir.DataType.FLOAT,
-    ir.DataType.FLOAT16,
-    ir.DataType.BFLOAT16,
-    ir.DataType.DOUBLE,
-]
-fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE]
+float_types = frozenset(
+    [
+        ir.DataType.FLOAT,
+        ir.DataType.FLOAT16,
+        ir.DataType.BFLOAT16,
+        ir.DataType.DOUBLE,
+    ]
+)
+fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])
 
 
 class RmsNormFusion(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
-        """
-        Args:
-            name: Name of the rule.
-            cast_input: Whether to cast input to do the normalization in a different precision.
-            cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
-        """
-        super().__init__(name=name)
-        self._cast_input = cast_input
-        self._cast_normalized = cast_normalized
-
     def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        if self._cast_input:
-            x = op.Cast(x, to=compute_dtype)
+        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
         x_square = op.Pow(x, 2.0)
         mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
         mean_square_plus_epsilon = op.Add(mean_square, epsilon)
         rms = op.Sqrt(mean_square_plus_epsilon)
         reciprocal_rms = op.Reciprocal(rms)
         normalized = op.Mul(x, reciprocal_rms)
-        if self._cast_normalized:
-            normalized = op.Cast(normalized, to=target_dtype)
+        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
         return op.Mul(scale, normalized)
 
-    def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult:  # type: ignore[name-defined]
+    def check(
+        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
         """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
         check_result = pattern.MatchResult()
         # epsilon must be a scalar
         epsilon_value = _ir_utils.get_singleton_value(epsilon)
         if not isinstance(epsilon_value, float):  # TODO: support other types
             return check_result.fail("Epsilon is not a float value.", epsilon)
-        # input and output must be same dtype
         if x.dtype not in float_types:
             return check_result.fail("Input is not a float type.", x)
         if scale.dtype not in float_types:
             return check_result.fail("Scale is not a float type.", scale)
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
-        if stash_dtype not in fp_float_types:
+        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
+        if self._stash_dtype not in fp_float_types:
             return check_result.fail("Normalization precision is not a float or double type.")
+        # target_dtype is guaranteed to be the same as scale type in a well-typed input
+        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
+        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
         return check_result
 
-    def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
+    def rewrite(self, op, x, scale, epsilon, **_):
         # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
         # No need to use com.microsoft domain here; but this is a custom op in ORT.
         return op.SimplifiedLayerNormalization(
             x,
             scale,
             axis=-1,
             epsilon=_ir_utils.get_singleton_value(epsilon),
-            stash_type=stash_dtype,
+            stash_type=self._stash_dtype,
         )
 
 
-_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True)
-_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True)
-_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False)
-_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False)
-
-rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3]
+_rule = RmsNormFusion.rule()
+rms_normalization_rules = [_rule]
 rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)