-
Notifications
You must be signed in to change notification settings - Fork 72
Add support for onnx fusions #2412
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 4 commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
97119bd
Add support for onnx fusions
gramalingam 865e9c0
Add test case (partial)
gramalingam c1e346d
Finish test case
gramalingam fcb91aa
Update torch framework api
gramalingam 3d341f2
Fix lint error
gramalingam 221a4f0
Merge branch 'main' into rama/onnx-fusion
gramalingam 13288d7
Minor fixes
gramalingam 5f67557
Merge branch 'rama/onnx-fusion' of https://github.com/microsoft/onnxs…
gramalingam 7d3420e
Merge branch 'main' into rama/onnx-fusion
gramalingam 6ab81a4
Merge branch 'main' into rama/onnx-fusion
gramalingam 14c7b30
Merge with main
gramalingam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
"""Stable APIs for PyTorch 2.9.""" | ||
|
||
from __future__ import annotations | ||
|
||
# Public API of this module. Mirrors the torch_2_6 stable API surface,
# with `optimize` overridden below to also apply ONNX-op fusions.
__all__ = [
    "check_model",
    "convert_version",
    "get_torchlib_ops",
    "optimize",
    "save_model_with_external_data",
]
|
||
from onnxscript import ir, optimizer
from onnxscript._framework_apis.torch_2_6 import (
    check_model,
    convert_version,
    get_torchlib_ops,
    save_model_with_external_data,
)
from onnxscript.rewriter import onnx_fusions
|
||
def optimize(model: ir.Model) -> ir.Model:
    """Optimize the model.

    Runs the generic IR-level optimizer in place, then applies fusions
    that replace subgraphs with native ONNX ops (see onnx_fusions.fuse).

    Args:
        model: The IR model to optimize; mutated in place.

    Returns:
        The same model instance, to allow call chaining.
    """
    # NOTE(review): `ir` is used in the annotations but not imported in the
    # visible import block — confirm `from onnxscript import ir` (or
    # equivalent) exists at the top of this file.
    optimizer.optimize_ir(model)
    onnx_fusions.fuse(model)
    return model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
from onnxscript.rewriter.onnx_fusions._onnx_fusions import fuse | ||
|
||
# Public API of the onnx_fusions package: `fuse` is the single entry point.
__all__ = [
    "fuse",
]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
|
||
import onnx_ir as ir | ||
|
||
from onnxscript.rewriter.onnx_fusions import _rms_normalization | ||
|
||
|
||
def _get_onnx_opset_version(model: ir.Model) -> int | None: | ||
"""Get the ONNX opset version imported by the model.""" | ||
model_version1 = model.opset_imports.get("") | ||
model_version2 = model.opset_imports.get("ai.onnx") | ||
if model_version1 is not None and model_version2 is not None: | ||
if model_version1 != model_version2: | ||
raise ValueError( | ||
f"Model imports multiple onnx opsets: {model_version1} and {model_version2}." | ||
) | ||
return model_version1 or model_version2 | ||
|
||
|
||
def _opset_23_fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
    """Apply fusions targeting ONNX opset 23.

    Args:
        model: The model to rewrite; mutated in place.
        debug: Forwarded to the fusion rules; when True they report
            diagnostic information about near-miss matches.

    Returns:
        A mapping from fused op name to the number of fusions applied.
    """
    counts: dict[str, int] = {}
    counts["RMSNormalization"] = _rms_normalization.fuse_rms_normalization(model, debug=debug)
    # Fix: the counts dict was computed but never returned, so callers
    # received None despite the declared dict[str, int] return type.
    return counts
|
||
|
||
def fuse(model: ir.Model, *, debug: bool = False) -> dict[str, int]:
    """Apply fusions targeting native ONNX ops.

    Dispatches on the model's imported ONNX opset version; currently only
    opset 23 has fusions defined.

    Args:
        model: The model to rewrite; mutated in place.
        debug: Forwarded to the fusion rules for diagnostic reporting.

    Returns:
        A mapping from fused op name to the number of fusions applied.
        Empty if no fusions are defined for the model's opset.
    """
    model_opset_version = _get_onnx_opset_version(model)
    if model_opset_version == 23:
        return _opset_23_fuse(model, debug=debug)
    # Fix: previously fell off the end and implicitly returned None for
    # other opsets, violating the declared dict[str, int] return type.
    return {}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
|
||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
|
||
import unittest | ||
|
||
import onnx_ir as ir | ||
|
||
import onnxscript | ||
import onnxscript.rewriter.onnx_fusions as onnx_fusions | ||
|
||
|
||
class OnnxFusionsTest(unittest.TestCase):
    """End-to-end test for the ONNX-op fusion entry point."""

    def test_rms_normalization_fusion(self):
        # Opset 23 of the core ONNX domain ("") — the version at which
        # RMSNormalization became a native op.
        opset23 = onnxscript.values.Opset("", 23)

        @onnxscript.script()
        def rms_norm_script(embedding, layernorm_weight):
            # Decomposed RMSNorm: weight * (x / sqrt(mean(x^2) + eps)),
            # built op by op as an exporter would emit it.
            two = opset23.Constant(value_float=2.0)
            pow_1 = opset23.Pow(embedding, two)
            mean = opset23.ReduceMean(pow_1, [-1], keepdims=1, noop_with_empty_axes=0)
            epsilon = opset23.Constant(value_float=1e-05)
            add_1 = opset23.Add(mean, epsilon)
            val_244 = opset23.Sqrt(add_1)
            rsqrt = opset23.Reciprocal(val_244)
            mul_3 = opset23.Mul(embedding, rsqrt)
            mul_4 = opset23.Mul(layernorm_weight, mul_3)
            return mul_4

        rms_norm_model_proto = rms_norm_script.to_model_proto(
            input_types=[onnxscript.FLOAT[128], onnxscript.FLOAT[128]],
            output_types=[onnxscript.FLOAT[128]],
        )
        model = ir.serde.deserialize_model(rms_norm_model_proto)
        onnx_fusions.fuse(model, debug=True)
        # After fusion, the subgraph's final Mul should have been replaced
        # so that the last node in the graph is the fused op.
        self.assertEqual(model.graph.node(-1).op_type, "RMSNormalization")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
from __future__ import annotations | ||
|
||
import onnxscript.ir as ir | ||
from onnxscript.rewriter import _fusion_utils, _ir_utils, pattern | ||
|
||
"""
RMS Normalization: ONNX Opset 23 op
See: https://onnx.ai/onnx/operators/onnx__RMSNormalization.html#l-onnx-doc-rmsnormalization

Key points for the fusion optimization:
* Input and scale are allowed to be of different types.
* The normalization of the input can be done in a different precision than the input type,
  indicated by stash_type.
* Input (x) must be: float or double or float16 or bfloat16
* Scale must be: float or double or float16 or bfloat16
"""

# Element types accepted for the input (x) and the scale.
float_types = [
    ir.DataType.FLOAT,
    ir.DataType.FLOAT16,
    ir.DataType.BFLOAT16,
    ir.DataType.DOUBLE,
]
# Precisions permitted for the internal computation (stash_type); narrower
# than float_types — see the comment in RmsNormFusion.check.
fp_float_types = [ir.DataType.FLOAT, ir.DataType.DOUBLE]
|
||
|
||
class RmsNormFusion(pattern.RewriteRuleClassBase):
    """Rewrite rule fusing a decomposed RMSNorm subgraph into one RMSNormalization op."""

    def pattern(self, op, x, scale, epsilon, compute_dtype, target_dtype):
        # Optionally up-cast the input to the computation precision; OrValue
        # lets the pattern match with or without the Cast present.
        x = pattern.OrValue([op.Cast(x, to=compute_dtype), x])
        x_square = op.Pow(x, 2.0)
        mean_square = op.ReduceMean(x_square, [-1], keepdims=1, noop_with_empty_axes=0)
        mean_square_plus_epsilon = op.Add(mean_square, epsilon)
        rms = op.Sqrt(mean_square_plus_epsilon)
        reciprocal_rms = op.Reciprocal(rms)
        normalized = op.Mul(x, reciprocal_rms)
        # Optionally down-cast the normalized value before scaling.
        normalized = pattern.OrValue([op.Cast(normalized, to=target_dtype), normalized])
        return op.Mul(scale, normalized)

    def check(
        self, op, x, scale, epsilon, compute_dtype, target_dtype, **_
    ) -> pattern.MatchResult:  # type: ignore[name-defined]
        """Check if the pattern matches conditions for use of SimplifiedLayerNormalization op."""
        check_result = pattern.MatchResult()
        # epsilon must be a scalar
        epsilon_value = _ir_utils.get_singleton_value(epsilon)
        if not isinstance(epsilon_value, float):  # TODO: support other types
            return check_result.fail("Epsilon is not a float value.", epsilon)
        if x.dtype not in float_types:
            return check_result.fail("Input is not a supported float type.", x)
        if scale.dtype not in float_types:
            return check_result.fail("Scale is not a supported float type.", scale)
        # Stash the computation dtype for rewrite(); when no Cast was matched
        # (compute_dtype is None) the input's own dtype is used.
        self._stash_dtype = compute_dtype.as_int() if compute_dtype is not None else x.dtype
        if self._stash_dtype not in fp_float_types:
            # TODO: ONNX documentation does not specify restrictions on stash_type, though
            # ORT's SimplifiedLayerNormalization requires it to be float or double.
            return check_result.fail("Normalization precision is not a float or double type.")
        # target_dtype is guaranteed to be the same as scale type in a well-typed input
        # for Mul(scale, normalized) to work. There is no need to check it here for a well-typed input.
        # TODO (rama): Consider adding checks to protect against incorrectly typed models:
        return check_result

    def rewrite(self, op, x, scale, epsilon, **_):
        # Note: ORT's SimplifiedLayerNormalization was placed in onnx domain by mistake.
        # No need to use com.microsoft domain here; but this is a custom op in ORT.
        return op.RMSNormalization(
            x,
            scale,
            axis=-1,
            epsilon=_ir_utils.get_singleton_value(epsilon),
            stash_type=self._stash_dtype,
        )
|
||
|
||
# A single rule instance; the OrValue branches in the pattern cover the
# cast/no-cast variants, so no per-variant rules are needed.
_rule = RmsNormFusion.rule()
rms_normalization_rules = [_rule]
rms_normalization_ruleset = pattern.RewriteRuleSet(rms_normalization_rules)

# Entry point used by onnx_fusions.fuse; applies the ruleset to a model.
fuse_rms_normalization = _fusion_utils.apply_fusion_rules(rms_normalization_ruleset)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.