|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 | 5 | import onnxscript.ir as ir
|
| 6 | +from onnxscript.ir.passes.common import shape_inference |
6 | 7 | from onnxscript.optimizer import optimize, remove_unused_nodes
|
7 | 8 | from onnxscript.rewriter.ort_fusions.cos_sin_cache import fuse_cos_sin_cache
|
| 9 | +from onnxscript.rewriter.ort_fusions.gelu import fuse_gelu |
8 | 10 | from onnxscript.rewriter.ort_fusions.mha import fuse_mha
|
9 | 11 | from onnxscript.rewriter.ort_fusions.rms_normalization import fuse_rms_normalization
|
10 |
| -from onnxscript.rewriter.ort_fusions.rotary_embedding import fuse_rotary_embedding |
| 12 | +from onnxscript.rewriter.ort_fusions.rotary_embedding import ( |
| 13 | + fuse_partial_rotary_embedding, |
| 14 | + fuse_rotary_embedding, |
| 15 | +) |
11 | 16 | from onnxscript.rewriter.ort_fusions.sdpa import fuse_sdpa
|
12 | 17 | from onnxscript.rewriter.ort_fusions.skip_normalization import fuse_normalization
|
13 | 18 |
|
14 | 19 |
|
15 |
| -def fuse_xformers(model: ir.Model) -> None: |
def _pre_optimize(model: ir.Model) -> ir.Model:
    """Run the preliminary optimizations that precede the transformer fusions.

    The sequence is: optimize, infer shapes, optimize again.

    Args:
        model: The model to prepare for fusion.

    Returns:
        The model produced by shape inference, after the second optimization
        pass. Callers should continue with the returned object rather than
        the argument.
    """
    # TODO: There are some potential redundancies below. Can be targeted for
    # optimization once we have robust fusion.
    optimize(model)

    # TODO: Do we need this dependence on ONNX's partial-data-propagation? There
    # are some extra shape-propagation and partial-data-propagation rules in ONNX
    # that are not yet incorporated in our optimizer.
    inferred = shape_inference.infer_shapes(model)

    optimize(inferred)
    return inferred
| 31 | + |
| 32 | + |
def fuse_xformers(model: ir.Model) -> ir.Model:
    """Apply the transformer fusions to a model and return the fused model.

    The model is first run through ``_pre_optimize``, and the fusions are then
    applied, in order, to the model that ``_pre_optimize`` returns. Unused
    nodes are removed at the end.

    Args:
        model: The model to fuse.

    Returns:
        The fused model. ``_pre_optimize`` hands back the result of shape
        inference, which may not be the same object as ``model``; callers
        should therefore use the returned model instead of relying on the
        argument having been updated in place.
    """
    # Rebind: every fusion below operates on the pre-optimized model.
    model = _pre_optimize(model)
    fuse_rms_normalization(model)
    fuse_normalization(model)
    fuse_rotary_embedding(model)
    fuse_partial_rotary_embedding(model)
    fuse_cos_sin_cache(model)
    fuse_sdpa(model)
    fuse_mha(model)
    fuse_gelu(model)
    remove_unused_nodes(model)
    # Return the object the fusions were actually applied to; without this the
    # result would be unreachable whenever pre-optimization produced a new model.
    return model
|
24 | 44 |
|
25 | 45 |
|
|
0 commit comments