Merge nested concat Ops optimization pass in ONNX dialect #3111

Open
wants to merge 9 commits into
base: main
Arkar-Hema
Contributor

Concat merging

The RecomposeConcat pass is an ONNX-MLIR optimization pass that simplifies and merges ONNXConcatOp operations to improve model performance and reduce redundant operations.

The pass optimizes Concat operations by:

  • Identifying nested Concat ops that share the same axis and combining their inputs into a single Concat node.
  • Replacing a Concat that has only one input directly with that input.

Input Representation

  • Input tensors: {X1,X2,...,Xn}
  • Concat axis: A
  • Shape of each tensor: (N, C, H, W)

Original Flow (Nested Concat Operations)

  • A Concat operation has an input that is itself a Concat:
    Y=Concat(X1,Concat(X2,X3),X4)
  • Computational Cost: Each Concat operation involves memory allocation and data movement, so the nested form pays that cost once for the inner Concat and again for the outer one.

Optimized Flow (Flattening Nested Concats)

  • Instead of a nested structure, the pass flattens the Concat operations:
    Y=Concat(X1,X2,X3,X4)
  • Computational Savings: Because Concat now runs in a single step, the intermediate result of the inner Concat is never allocated or copied.
  • The reduction factor in computation can be expressed as:
    [image: reduction-factor formula, not recoverable from the page]
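
To make the flattening concrete, here is a minimal, self-contained C++ sketch on a toy expression tree. It does not use the ONNX-MLIR API: `Node`, `leaf`, `concat`, `flatten`, `copiedElements`, and `nestedExample` are hypothetical names introduced for illustration. The cost model (each Concat copies every element of its inputs once into a fresh buffer) is an assumption made for this sketch and is not the formula from the PR description.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy IR (not ONNX-MLIR): a leaf tensor or a Concat of inputs.
struct Node {
  std::string name;  // leaf name; empty for a Concat
  int64_t numElems;  // element count of a leaf; unused for a Concat
  int64_t axis;      // concat axis; unused for a leaf
  std::vector<std::shared_ptr<Node>> inputs; // non-empty only for a Concat
  bool isConcat() const { return !inputs.empty(); }
};
using NodePtr = std::shared_ptr<Node>;

NodePtr leaf(const std::string &name, int64_t numElems) {
  return std::make_shared<Node>(Node{name, numElems, 0, {}});
}
NodePtr concat(int64_t axis, std::vector<NodePtr> inputs) {
  return std::make_shared<Node>(Node{"", 0, axis, std::move(inputs)});
}

// Total number of elements in the tensor a node produces.
int64_t numElements(const NodePtr &n) {
  if (!n->isConcat())
    return n->numElems;
  int64_t total = 0;
  for (const NodePtr &in : n->inputs)
    total += numElements(in);
  return total;
}

// Assumed cost model: every Concat copies each of its input elements once.
int64_t copiedElements(const NodePtr &n) {
  if (!n->isConcat())
    return 0;
  int64_t cost = numElements(n);
  for (const NodePtr &in : n->inputs)
    cost += copiedElements(in);
  return cost;
}

// Inline every nested Concat that shares its parent's axis.
NodePtr flatten(const NodePtr &n) {
  if (!n->isConcat())
    return n;
  std::vector<NodePtr> flat;
  for (const NodePtr &in : n->inputs) {
    NodePtr f = flatten(in);
    if (f->isConcat() && f->axis == n->axis)
      flat.insert(flat.end(), f->inputs.begin(), f->inputs.end());
    else
      flat.push_back(f);
  }
  return concat(n->axis, std::move(flat));
}

// Y = Concat(X1, Concat(X2, X3), X4), all on axis 1, 10 elements per leaf.
NodePtr nestedExample() {
  NodePtr inner = concat(1, {leaf("X2", 10), leaf("X3", 10)});
  return concat(1, {leaf("X1", 10), inner, leaf("X4", 10)});
}
```

Under the assumed cost model, `copiedElements(nestedExample())` is 60 (20 for the inner Concat plus 40 for the outer one), while `copiedElements(flatten(nestedExample()))` is 40. A nested Concat on a different axis is left in place, since inlining it would change the result.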

@jenkins-droid
Collaborator

Can one of the admins verify this patch?

@Arkar-Hema Arkar-Hema closed this Apr 8, 2025
@Arkar-Hema Arkar-Hema reopened this Apr 8, 2025
Signed-off-by: Arkar-Hema <[email protected]>
@@ -1385,6 +1385,7 @@ void DecomposeONNXToONNXPass::runOnOperation() {
op.getValueStringAttr() || op.getValueStringsAttr());
});


Collaborator

Please undo this change.

@@ -602,6 +602,57 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern
}
};

/// Merges nested ONNXConcatOps
struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> {
Collaborator

Could you move this pattern into Canonicalize.cpp? It would be helpful to run this every time we have a Concat. If it lives in Recompose.cpp, it may only trigger at the beginning of compilation.


// Helper function to check if an input is a mergeable Concat.
static bool isMergeableConcat(Value input, int64_t axis) {
auto innerConcat = input.getDefiningOp<ONNXConcatOp>();
Collaborator

Please use a concrete type here. I think it's ONNXConcatOp innerConcat .... We avoid using auto as much as possible.

if (isMergeableConcat(input, concatOp.getAxis())) {
merged = true;
// Remove the nested concat and append its inputs.
newInputs.pop_back();
Collaborator

A push followed by a pop would be expensive. Why not push input in the else branch of if (isMergeableConcat(input, concatOp.getAxis())) { instead?
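
The suggestion above can be sketched on plain containers. This is a toy model and not the PR's code: `Operand`, `collectInputs`, and the field names are hypothetical stand-ins for an MLIR `Value`, the pattern body, and the `ONNXConcatOp` accessors.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for an operand: either a plain value, or the result
// of a Concat that carries its own axis and operand names.
struct Operand {
  std::string name;
  bool fromConcat;
  int64_t axis;
  std::vector<std::string> concatInputs;
};

// Build the flattened input list. Each operand is appended exactly once, so
// no pop_back is needed. Returns true if any nested Concat was merged.
bool collectInputs(const std::vector<Operand> &inputs, int64_t axis,
                   std::vector<std::string> &newInputs) {
  bool merged = false;
  for (const Operand &input : inputs) {
    if (input.fromConcat && input.axis == axis) {
      // Mergeable: append the nested Concat's inputs instead of the Concat.
      merged = true;
      newInputs.insert(newInputs.end(), input.concatInputs.begin(),
                       input.concatInputs.end());
    } else {
      // Not mergeable: keep the operand as it is.
      newInputs.push_back(input.name);
    }
  }
  return merged;
}
```

With the push moved into the else branch, the list never holds an entry that must be removed later, and the `merged` flag still tells the caller whether a rewrite is needed.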

merged = true;
// Remove the nested concat and append its inputs.
newInputs.pop_back();
auto innerConcat = cast<ONNXConcatOp>(input.getDefiningOp());
Collaborator

Do this instead: ONNXConcatOp innerConcat = input.getDefiningOp<ONNXConcatOp>()

}

// If there is only a single input, replace the concat with that input.
if (concatOp.getOperands().size() == 1) {
Collaborator

This is a cheap check. Please do it at the beginning, before flattening.
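
One way to read the ordering requested above, as a toy decision function. `Rewrite` and `chooseRewrite` are hypothetical names, and the `isMergeable` flags stand in for the result of calling `isMergeableConcat` on each operand; none of this is ONNX-MLIR code.

```cpp
#include <cassert>
#include <vector>

// Hypothetical outcome of the pattern for one Concat.
enum class Rewrite { ForwardInput, Flatten, NoChange };

// isMergeable[i] says whether operand i is a nested Concat on the same axis.
Rewrite chooseRewrite(const std::vector<bool> &isMergeable) {
  // Cheap check first: a single-input Concat is just its input, so the
  // operand scan below can be skipped entirely.
  if (isMergeable.size() == 1)
    return Rewrite::ForwardInput;
  // Only now look through the operands for something to flatten.
  for (bool mergeable : isMergeable)
    if (mergeable)
      return Rewrite::Flatten;
  return Rewrite::NoChange;
}
```

A single-input Concat whose only operand is itself a Concat is also handled by the early return: forwarding the operand removes the outer op without building a new input list.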


// If there is only a single input, replace the concat with that input.
if (concatOp.getOperands().size() == 1) {
rewriter.replaceOp(concatOp, concatOp.getOperands()[0]);
Collaborator

nit: concatOp.getOperands()[0] -> inputs[0]

3 participants