Skip to content

Commit ed7e08e

Browse files
gramalingambmehta001
authored and committed
Cleanup ort transformer fusions (microsoft#2115)
Cleanup ort transformer-fusions.
1 parent 9d860f8 commit ed7e08e

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

onnxscript/rewriter/ort_fusions/_core.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,43 @@
33
from __future__ import annotations
44

55
import onnxscript.ir as ir
6+
from onnxscript.ir.passes.common import shape_inference
67
from onnxscript.optimizer import optimize, remove_unused_nodes
78
from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
9+
from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu
810
from onnxscript.rewriter.ort_fusions.mha import fuse_mha
911
from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
10-
from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding
12+
from onnxscript.rewriter.ort_fusions.rotary_embedding import (
13+
fuse_partial_rotary_embedding,
14+
fuse_rotary_embedding,
15+
)
1116
from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
1217
from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
1318

1419

15-
def fuse_xformers(model: ir.Model) -> None:
20+
# Preliminary optimizations before applying the transformer fusions.
21+
# TODO: There are some potential redundancies below. Can be targeted for optimization
22+
# once we have robust fusion.
23+
def _pre_optimize(model: ir.Model) -> ir.Model:
    """Run the general-purpose cleanup that precedes the transformer fusions.

    The model is optimized, passed through shape inference, and then optimized
    once more. The object returned is the one produced by shape inference, so
    callers should continue working with the return value rather than with the
    argument they passed in.

    Args:
        model: The IR model to prepare for fusion.

    Returns:
        The model produced by shape inference, after the second optimization.
    """
    optimize(model)
    # TODO: Revisit whether relying on ONNX's partial-data-propagation is really needed.
    # ONNX has additional shape-propagation and partial-data-propagation rules that our
    # own optimizer does not incorporate yet.
    inferred = shape_inference.infer_shapes(model)
    optimize(inferred)
    return inferred
31+
32+
33+
def fuse_xformers(model: ir.Model) -> ir.Model:
    """Apply the ORT transformer fusions to a model and return the fused model.

    The model is first run through ``_pre_optimize``, whose return value
    replaces the local ``model``. Every fusion below is applied to that
    returned object, which is not guaranteed to be the object the caller
    passed in (NOTE(review): confirm whether shape inference mutates in place
    or builds a new model). The fused model is therefore returned explicitly;
    callers should use the return value instead of relying on in-place
    mutation of their argument.

    Args:
        model: The IR model to fuse.

    Returns:
        The model that the fusions were applied to.
    """
    model = _pre_optimize(model)
    # The fusion order is kept exactly as in the original implementation.
    fuse_rms_normalization(model)
    fuse_normalization(model)
    fuse_rotary_embedding(model)
    fuse_partial_rotary_embedding(model)
    fuse_cos_sin_cache(model)
    fuse_sdpa(model)
    fuse_mha(model)
    fuse_gelu(model)
    remove_unused_nodes(model)
    # Previously nothing was returned, so a caller could not reach the fused
    # model whenever pre-optimization produced a different model object.
    return model
2444

2545

0 commit comments

Comments
 (0)