
Fuse Initializers Graph Transform #24175

Open
wants to merge 2 commits into main
Conversation

sunnyshu-intel
Contributor

@sunnyshu-intel commented Mar 25, 2025

Description

Added a graph transform for mixed-precision graphs when FP16 compute is unavailable. At session creation, this transform converts FP16 initializers (which would otherwise feed FP16-to-FP32 Cast nodes) to FP32 initializers and fuses them with the FP32 nodes that consume them.

  • Behavior before this change:
    "fp16 initializers -> cast_from_fp16_to_fp32 -> fp32 node/s"

  • Behavior after this change:
    "fp16 initializers converted to fp32 initializers then fused with fp32 node/s"

Motivation and Context

This change allows FP16 models to run without repeatedly casting FP16 initializers to FP32 at inference time: when FP16 compute is not available, the FP32 initializers are fused directly with the nodes that consume them.

Two separate commits for this PR

This PR consists of two separate commits.

  • The first commit adds the Fuse Initializers Graph Transform ("FIGT") as the last stage of the graph transforms during session creation.
  • The second commit adds a new graph transform level, "Level 4", intended for transforms that work around unsupported-datatype compute. We placed "FIGT" in Level 4 so that it can be turned off whenever required while the Level 1 to Level 3 optimizations remain unaffected (see the sketch after this list for how optimization levels are selected).
  • We have submitted this PR as two commits so that the changes for each can be tracked separately. Once approved, we can squash them before merging.
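For context, graph optimization levels are already selected through SessionOptions; the proposed Level 4 would extend that surface. A minimal sketch of how the existing levels are chosen today ("model.onnx" is a placeholder path, and the Level 4 knob mentioned in the comment above is hypothetical until this PR lands):

```python
import onnxruntime as ort

so = ort.SessionOptions()
# Existing levels: ORT_DISABLE_ALL, ORT_ENABLE_BASIC (Level 1),
# ORT_ENABLE_EXTENDED (Level 2), ORT_ENABLE_ALL (Level 3 and layout transforms).
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# A separate Level 4 switch for FIGT is what this PR proposes; the exact
# knob (enum value or session config entry) is not finalized here.
sess = ort.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])
```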

Re-Running Level 1 to Level 3 optimizations after Level 4 / FIGT

The idea behind re-running the Level 1, Partitioning, Level 2, and Level 3 graph transforms is that, after the initializers are fused with their respective nodes, those nodes are in a form that other graph transforms (previously skipped) may now handle. Transformations that could not be applied before become valid and can produce a more optimal graph for execution.

Documentation

We have not yet added any details about Level 4 to the Graph Optimizations in ONNX Runtime documentation. We may need some guidance on updating that documentation once the PR is accepted.

Working

Currently, the Fuse Initializers Graph Transform fuses Cast nodes that cast from FP16 to FP32 back into their
next/output nodes. Below is an explanation of how this transform works. It depends on InsertCastTransforms
to produce the intermediate representation, from which it fuses the initializers (held by Cast nodes with
zero inputs, one initializer, and one output) back into the next/output node. After fusion, the edge between such a
Cast node and the next/output node is removed.

        "Input Graph"                       "Intermediate Representation"                 "FIGT Transforms"

          --------                   --------        --------        --------                 --------
         | X_Fp16 |                 | X_Fp16 |      | W_Fp16 |      | B_Fp16 |               | X_Fp16 |
          --------                   --------        --------        --------                 --------
             |                          |               |               |                        |
             |                          |               |               |                        |
             |                          V               V               V                        V
             |                       | Cast |        | Cast |        | Cast |                 | Cast |
             |                       | Fp16 |        | Fp16 |        | Fp16 |                 | Fp16 |
             |                       |  To  |        |  To  |        |  To  |                 |  To  |
             |                       | Fp32 |        | Fp32 |        | Fp32 |                 | Fp32 |
             |                          |               |               |                        |
             |                          |               |               |                        |
             V                          V               V               V                        V
 ----------------------------       -----------------------------------------       ----------------------------
|        Conv_Fp16           |     |                                         |     |         Conv_Fp32          |
|        --W_Fp16--          | ==> |                Conv_Fp32                | ==> |         --W_Fp32--         |
|        --B_Fp16--          |     |                                         |     |         --B_Fp32--         |
 ----------------------------       -----------------------------------------       ----------------------------
             |                                          |                                        |
             |                                          |                                        |
             |                                          V                                        V
             |                                       | Cast |                                 | Cast |
             |                                       | Fp32 |                                 | Fp32 |
             |                                       |  To  |                                 |  To  |
             |                                       | Fp16 |                                 | Fp16 |
             |                                          |                                        |
             |                                          |                                        |
             V                                          V                                        V
          --------                                   --------                                 --------
         | Y_Fp16 |                                 | Y_Fp16 |                               | Y_Fp16 |
          --------                                   --------                                 --------

The newly added graph transform performs the following actions (a Python sketch of these steps follows the note below).

- Detect Cast node/s with a single FP16 initializer
  casting to FP32.
- Convert all such FP16 initializer/s to FP32 initializer/s.
- Fuse the newly created FP32 initializer/s into the corresponding FP32 node/s.
- Remove the FP16-to-FP32 Cast node/s.

Note: For naming purposes, the newly added graph transform
is called "Fuse Initializers Graph Transform" in long form,
and "FIGT" in short form.
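A minimal Python sketch of the four steps listed above, operating directly on an ONNX GraphProto (illustrative only: the actual FIGT pass is implemented in C++ inside ONNX Runtime on its internal Graph, and a real implementation would also check that the initializer is not a graph input and has a single consumer):

```python
import numpy as np
from onnx import TensorProto, numpy_helper

def fuse_fp16_initializer_casts(graph):
    """Fold Cast(FP16->FP32) nodes whose only input is an FP16 initializer."""
    inits = {t.name: t for t in graph.initializer}
    dead_casts = []
    for node in graph.node:
        # Step 1: detect Cast nodes fed by a single FP16 initializer, casting to FP32.
        if node.op_type != "Cast" or len(node.input) != 1:
            continue
        src = inits.get(node.input[0])
        to_fp32 = any(a.name == "to" and a.i == TensorProto.FLOAT for a in node.attribute)
        if src is None or src.data_type != TensorProto.FLOAT16 or not to_fp32:
            continue
        # Steps 2-3: convert the FP16 initializer to FP32 and register it under the
        # Cast output name, so the consuming FP32 node now reads it directly.
        fp32_init = numpy_helper.from_array(
            numpy_helper.to_array(src).astype(np.float32), name=node.output[0])
        graph.initializer.append(fp32_init)
        graph.initializer.remove(src)
        dead_casts.append(node)
    # Step 4: remove the now-dead FP16-to-FP32 Cast nodes.
    for node in dead_casts:
        graph.node.remove(node)
```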

Signed-off-by: Sunny Shukla <[email protected]>
This change helps with the following requirements

- Ability to turn off the FIGT optimization.
- Ability to re-run Level-1 to Level-3 optimizations, only if FIGT
  optimization is applied.
- Keep the current flow of graph optimizations untouched.

Signed-off-by: Sunny Shukla <[email protected]>
@sunnyshu-intel requested a review from a team as a code owner on March 25, 2025 21:43
@tianleiwu
Contributor

Regarding to "fp16 initializers -> cast_from_fp16_to_fp32 -> fp32 node/s", I think it is possible to do constant folding for fp16 initializers -> cast_from_fp16_to_fp32 to be fp32 initializer as long as the initializer is not in graph input. Then there is no need to add a new level.

@yuslepukhin
Member

Transformers actually run in a loop, until no more graph modifications are made.

@sunnyshu-intel
Contributor Author

Regarding to "fp16 initializers -> cast_from_fp16_to_fp32 -> fp32 node/s", I think it is possible to do constant folding for fp16 initializers -> cast_from_fp16_to_fp32 to be fp32 initializer as long as the initializer is not in graph input. Then there is no need to add a new level.

I have updated the description of this PR with a new "Working" section. This particular transform depends on "Insert Cast Transforms" to detect the unsupported nodes and produce the intermediate representation.

@sunnyshu-intel
Contributor Author

sunnyshu-intel commented Apr 9, 2025

Transformers actually run in a loop, until no more graph modifications are made.

To the best of my knowledge, the transformers of a particular level run in a loop until no more graph modifications are required. This deduction is based on the GraphTransformerManager::ApplyTransformers functionality in the onnxruntime/core/optimizer/graph_transformer_mgr.cc file.
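For readers not familiar with that code path, the behaviour being described is essentially a bounded fixed-point loop per level. A rough sketch (hypothetical names, not the actual C++ in graph_transformer_mgr.cc; each transformer is assumed to expose an apply(graph) method returning True when it changed the graph):

```python
def apply_transformers(graph, transformers, max_steps=10):
    # Keep applying every transformer registered for this level until a full
    # pass makes no modification, or the step budget is exhausted.
    for _ in range(max_steps):
        modified = False
        for transformer in transformers:
            modified |= transformer.apply(graph)
        if not modified:
            break
    return graph
```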

As of now, the graph optimization flow in this PR is as follows.

[image: current graph optimization flow in this PR]

If required, I can try running all graph optimizations in a loop in onnxruntime/core/session/inference_session.cc, as follows.

[image: proposed flow with all graph optimizations run in a loop]

@tianleiwu
Contributor

@sunnyshu-intel,

I have a concern about the design change: "Insert Copy Nodes" assumes that the partitioning of nodes to EPs is finalized. We should not re-run partitioning later.

I think there are two options:
(1) Run it as a post-processing step of "Insert Cast Nodes". We might not need to rerun optimizations in this option.
(2) Add it to either the level 1 or level 2 optimizations; then we can rerun optimizations once.

[image: suggested workflow]

@sunnyshu-intel
Contributor Author

@sunnyshu-intel,

I have a concern about the design change: "Insert Copy Nodes" assumes that the partitioning of nodes to EPs is finalized. We should not re-run partitioning later.

I think there are two options: (1) Run it as a post-processing step of "Insert Cast Nodes". We might not need to rerun optimizations in this option. (2) Add it to either the level 1 or level 2 optimizations; then we can rerun optimizations once.

[image: suggested workflow]

OK, I was not aware that "Insert Copy Nodes" assumes that the partitioning of nodes to EPs is finalized. In that case, yes, I agree: we shouldn't run partitioning after "Insert Copy Nodes".

Also, keeping this type of optimization under a separate level (Level 4 in this case) has some benefits:

  • Ability to turn off the fusion (FIGT) optimization. This is useful when upconverting initializers from FP16 to FP32 causes a considerable (and undesirable) increase in the model's memory footprint.
  • Ability to rerun the Level 1 to Level 3 optimizations only if the fusion optimization is applied, keeping the current graph manipulations untouched (including Insert Cast Transforms).

I can move the Level 4 fusion optimization to execute before "Insert Copy Nodes" and re-run Level 1 -> Partitioning -> Level 2 -> Level 3 -> Insert Cast Nodes -> Level 4 in a loop until no more graph optimizations are applied, as follows.

[image: proposed flow: Level 1 -> Partitioning -> Level 2 -> Level 3 -> Insert Cast Nodes -> Level 4, looped]

In addition, we saw considerable performance gains when we re-ran the Level 1, 2, and 3 optimizations after this fusion optimization was applied. The reason for re-running the Level 1, Partitioning, Level 2, and Level 3 graph transforms is that, after the fusion, the nodes are in a form that other graph transforms (skipped before) may now handle, so transforms that were previously not applicable become valid and can produce a more optimal graph for execution.

@tianleiwu
Contributor

tianleiwu commented Apr 11, 2025

@sunnyshu-intel, could you merge the latest main and resolve the conflicts?

Level >= 2 assumes that partitioning is done, since those optimizations are provider-specific, so we cannot run partitioning twice.

There is probably no need to add a new level, because the optimizer reduces memory usage: previously, we need the fp16 initializer and also a temporary fp32 buffer after the Cast; after the fusion, we only need memory for the fp32 version of the initializer. I do not think users need to exclude it explicitly.
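A quick back-of-the-envelope check of that memory point (illustrative numbers only): for a weight tensor with N elements, the pre-fusion graph holds the FP16 initializer plus an FP32 buffer for the Cast output, while the post-fusion graph holds only the FP32 copy.

```python
n = 10_000_000                        # elements in one weight tensor (example value)
before = n * 2 + n * 4                # fp16 initializer + fp32 Cast output buffer
after = n * 4                         # fp32 initializer only
print(before / 2**20, after / 2**20)  # ~57.2 MiB vs ~38.1 MiB
```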

I suggest adding the new optimizer to level 2; then the workflow is like:

Level 1 -> Partitioning -> loop{ Level 2 -> Level 3 -> Insert Cast Nodes } -> Insert Copy Nodes.

OR

loop{ Level 1 -> Partitioning (only once) -> Level 2 -> Level 3 -> Insert Cast Nodes } -> Insert Copy Nodes.
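To make the loop structure concrete, here is a minimal sketch of the first workflow above, using hypothetical stage stubs (these are not ONNX Runtime functions; each stub would mutate the graph and report whether it changed anything):

```python
# Hypothetical stage stubs standing in for ORT's internal passes; each returns
# True when it modified the graph (always False in this skeleton).
def run_level1(g): return False
def run_level2(g): return False
def run_level3(g): return False
def insert_cast_nodes(g): return False
def insert_copy_nodes(g): return False
def partition(g): return False

def optimize(graph):
    run_level1(graph)
    partition(graph)                      # partitioning runs exactly once
    while True:                           # loop{ Level 2 -> Level 3 -> Insert Cast Nodes }
        modified = run_level2(graph)
        modified |= run_level3(graph)
        modified |= insert_cast_nodes(graph)
        if not modified:
            break
    insert_copy_nodes(graph)              # only after the loop reaches a fixed point
    return graph
```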

If you want to enable/disable it in testing, ORT has an internal option to disable specific optimizers when creating a session, like this.
