diff --git a/onnxscript/_framework_apis/torch_2_8.py b/onnxscript/_framework_apis/torch_2_8.py
index ee5e6089e5..bbd1ffc786 100644
--- a/onnxscript/_framework_apis/torch_2_8.py
+++ b/onnxscript/_framework_apis/torch_2_8.py
@@ -12,10 +12,20 @@
     "save_model_with_external_data",
 ]
 
+import onnx_ir as ir
+
+import onnxscript.optimizer
+import onnxscript.rewriter.onnx_fusions
 from onnxscript._framework_apis.torch_2_6 import (
     check_model,
     convert_version,
     get_torchlib_ops,
-    optimize,
     save_model_with_external_data,
 )
+
+
+def optimize(model: ir.Model) -> ir.Model:
+    """Optimize the model."""
+    onnxscript.optimizer.optimize_ir(model)
+    onnxscript.rewriter.onnx_fusions.fuse(model)
+    return model
diff --git a/onnxscript/rewriter/onnx_fusions/__init__.py b/onnxscript/rewriter/onnx_fusions/__init__.py
new file mode 100644
index 0000000000..d2e8d885f0
--- /dev/null
+++ b/onnxscript/rewriter/onnx_fusions/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse
+
+__all__ = [
+    "fuse",
+]
diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
new file mode 100644
index 0000000000..96446e6fb4
--- /dev/null
+++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions.py
@@ -0,0 +1,34 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import onnx_ir as ir
+
+from onnxscript.rewriter.onnx_fusions import _rms_normalization
+
+
+def _get_onnx_opset_version(model: ir.Model) -> int | None:
+    """Get the ONNX opset version imported by the model."""
+    model_version1 = model.opset_imports.get("")
+    model_version2 = model.opset_imports.get("ai.onnx")
+    if model_version1 is not None and model_version2 is not None:
+        if model_version1 != model_version2:
+            raise ValueError(
+                f"Model imports multiple onnx opsets: {model_version1} and {model_version2}."
+            )
+    return model_version1 or model_version2
+
+
+def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
+    """Apply fusions targeting ONNX opset 23."""
+    counts: dict[str, int] = {}
+    counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
+    return counts
+
+
+def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
+    """Apply fusions targeting ONNX ops."""
+    model_opset_version = _get_onnx_opset_version(model)
+    if model_opset_version == 23:
+        return _opset_23_fuse(model, debug=debug)
+    return {}
diff --git a/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py
new file mode 100644
index 0000000000..dfd9ca4296
--- /dev/null
+++ b/onnxscript/rewriter/onnx_fusions/_onnx_fusions_test.py
@@ -0,0 +1,40 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import unittest
+
+import onnx_ir as ir
+
+import onnxscript
+import onnxscript.rewriter.onnx_fusions as onnx_fusions
+
+
+class OnnxFusionsTest(unittest.TestCase):
+    def test_rms_normalization_fusion(self):
+        opset23 = onnxscript.values.Opset("", 23)
+
+        @onnxscript.script()
+        def rms_norm_script(embedding, layernorm_weight):
+            two = opset23.Constant(value_float=2.0)
+            pow_1 = opset23.Pow(embedding, two)
+            mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
+            epsilon = opset23.Constant(value_float=1e-05)
+            add_1 = opset23.Add(mean, epsilon)
+            val_244 = opset23.Sqrt(add_1)
+            rsqrt = opset23.Reciprocal(val_244)
+            mul_3 = opset23.Mul(embedding, rsqrt)
+            mul_4 = opset23.Mul(layernorm_weight, mul_3)
+            return mul_4
+
+        rms_norm_model_proto = rms_norm_script.to_model_proto(
+            input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]],
+            output_types=[onnxscript.FLOAT[128]],
+        )
+        model = ir.serde.deserialize_model(rms_norm_model_proto)
+        onnx_fusions.fuse(model, debug=True)
+        self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/onnxscript/rewriter/onnx_fusions/_rms_normalization.py b/onnxscript/rewriter/onnx_fusions/_rms_normalization.py
new file mode 100644
index 0000000000..dc7d1bc971
--- /dev/null
+++ b/onnxscript/rewriter/onnx_fusions/_rms_normalization.py
@@ -0,0 +1,84 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+from __future__ import annotations
+
+import onnxscript.ir as ir
+from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern
+
+"""
+RMS Normalization: ONNX Opset 23 op
+See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization
+
+
+Key points for the fusion optimization:
+* Input and scale are allowed to be of different types.
+* The normalization of the input can be done in a different precision than the input type,
+indicated by stash_type.
+* Input (x) must be: float or double or float16 or bfloat16
+* Scale must be: float or double or float16 or bfloat16
+"""
+
+float_types = frozenset(
+    [
+        ir.DataType.FLOAT,
+        ir.DataType.FLOAT16,
+        ir.DataType.BFLOAT16,
+        ir.DataType.DOUBLE,
+    ]
+)
+fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])
+
+
+class RmsNormFusion(pattern.RewriteRuleClassBase):
+    def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
+        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
+        x_square = op.Pow(x, 2.0)
+        mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
+        mean_square_plus_epsilon = op.Add(mean_square, epsilon)
+        rms = op.Sqrt(mean_square_plus_epsilon)
+        reciprocal_rms = op.Reciprocal(rms)
+        normalized = op.Mul(x, reciprocal_rms)
+        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
+        return op.Mul(scale, normalized)
+
+    def check(
+        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
+        """Check if the pattern matches conditions for use of the ONNX RMSNormalization op."""
+        check_result = pattern.MatchResult()
+        # epsilon must be a scalar
+        epsilon_value = _ir_utils.get_singleton_value(epsilon)
+        if not isinstance(epsilon_value, float):  # TODO: support other types
+            return check_result.fail("Epsilon is not a float value.", epsilon)
+        if x.dtype not in float_types:
+            return check_result.fail("Input is not a supported float type.", x)
+        if scale.dtype not in float_types:
+            return check_result.fail("Scale is not a supported float type.", scale)
+        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
+        if self._stash_dtype not in fp_float_types:
+            # TODO: ONNX documentation does not specify restrictions on stash_type, though
+            # ORT's SimplifiedLayerNormalization requires it to be float or double.
+            return check_result.fail("Normalization precision is not a float or double type.")
+        # target_dtype is guaranteed to be the same as scale type in a well-typed input
+        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
+        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
+        return check_result
+
+    def rewrite(self, op, x, scale, epsilon, **_):
+        # Note: RMSNormalization is a standard ONNX op as of opset 23, so it is emitted
+        # in the default onnx domain (no com.microsoft custom op is needed).
+        return op.RMSNormalization(
+            x,
+            scale,
+            axis=-1,
+            epsilon=_ir_utils.get_singleton_value(epsilon),
+            stash_type=self._stash_dtype,
+        )
+
+
+_rule = RmsNormFusion.rule()
+rms_normalization_rules = [_rule]
+rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)
+
+
+fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
diff --git a/onnxscript/rewriter/ort_fusions/rms_normalization.py b/onnxscript/rewriter/ort_fusions/rms_normalization.py
index 7bb631b0ea..b12da46e8b 100644
--- a/onnxscript/rewriter/ort_fusions/rms_normalization.py
+++ b/onnxscript/rewriter/ort_fusions/rms_normalization.py
@@ -19,59 +19,51 @@
 * Normalization precision must be float or double
 """
 
-float_types = [
-    ir.DataType.FLOAT,
-    ir.DataType.FLOAT16,
-    ir.DataType.BFLOAT16,
-    ir.DataType.DOUBLE,
-]
-fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE]
+float_types = frozenset(
+    [
+        ir.DataType.FLOAT,
+        ir.DataType.FLOAT16,
+        ir.DataType.BFLOAT16,
+        ir.DataType.DOUBLE,
+    ]
+)
+fp_float_types = frozenset([ir.DataType.FLOAT, ir.DataType.DOUBLE])
 
 
 class RmsNormFusion(pattern.RewriteRuleClassBase):
-    def __init__(self, name: str, *, cast_input: bool, cast_normalized: bool):
-        """
-        Args:
-            name: Name of the rule.
-            cast_input: Whether to cast input to do the normalization in a different precision.
-            cast_normalized: Whether to cast the normalized output to the target dtype (same as scale).
-        """
-        super().__init__(name=name)
-        self._cast_input = cast_input
-        self._cast_normalized = cast_normalized
-
     def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        if self._cast_input:
-            x = op.Cast(x, to=compute_dtype)
+        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
         x_square = op.Pow(x, 2.0)
         mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
         mean_square_plus_epsilon = op.Add(mean_square, epsilon)
         rms = op.Sqrt(mean_square_plus_epsilon)
         reciprocal_rms = op.Reciprocal(rms)
         normalized = op.Mul(x, reciprocal_rms)
-        if self._cast_normalized:
-            normalized = op.Cast(normalized, to=target_dtype)
+        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
         return op.Mul(scale, normalized)
 
-    def check(self, op, x, scale, epsilon, compute_dtype, target_dtype) -> pattern.MatchResult:  # type: ignore[name-defined]
+    def check(
+        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
+    ) -> pattern.MatchResult:  # type: ignore[name-defined]
         """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
         check_result = pattern.MatchResult()
         # epsilon must be a scalar
         epsilon_value = _ir_utils.get_singleton_value(epsilon)
         if not isinstance(epsilon_value, float):  # TODO: support other types
             return check_result.fail("Epsilon is not a float value.", epsilon)
-        # input and output must be same dtype
         if x.dtype not in float_types:
             return check_result.fail("Input is not a float type.", x)
         if scale.dtype not in float_types:
             return check_result.fail("Scale is not a float type.", scale)
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
-        if stash_dtype not in fp_float_types:
+        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
+        if self._stash_dtype not in fp_float_types:
             return check_result.fail("Normalization precision is not a float or double type.")
+        # target_dtype is guaranteed to be the same as scale type in a well-typed input
+        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
+        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
         return check_result
 
-    def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
-        stash_dtype = compute_dtype.value if self._cast_input else x.dtype
+    def rewrite(self, op, x, scale, epsilon, **_):
         # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
         # No need to use com.microsoft domain here; but this is a custom op in ORT.
         return op.SimplifiedLayerNormalization(
@@ -79,16 +71,12 @@ def rewrite(self, op, x, scale, epsilon, compute_dtype, target_dtype):
             scale,
             axis=-1,
             epsilon=_ir_utils.get_singleton_value(epsilon),
-            stash_type=stash_dtype,
+            stash_type=self._stash_dtype,
         )
 
 
-_rule_0 = RmsNormFusion.rule("RmsNorm-0", cast_input=True, cast_normalized=True)
-_rule_1 = RmsNormFusion.rule("RmsNorm-1", cast_input=False, cast_normalized=True)
-_rule_2 = RmsNormFusion.rule("RmsNorm-2", cast_input=True, cast_normalized=False)
-_rule_3 = RmsNormFusion.rule("RmsNorm-3", cast_input=False, cast_normalized=False)
-
-rms_normalization_rules = [_rule_0, _rule_1, _rule_2, _rule_3]
+_rule = RmsNormFusion.rule()
+rms_normalization_rules = [_rule]
 rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)