-
Notifications
You must be signed in to change notification settings - Fork 337
Merge nested concat Ops optimization pass in ONNX dialect #3111
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch? |
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch? |
Can one of the admins verify this patch? |
Signed-off-by: Arkar-Hema <[email protected]>
Signed-off-by: Arkar-Hema <[email protected]>
Can one of the admins verify this patch? |
@@ -1385,6 +1385,7 @@ void DecomposeONNXToONNXPass::runOnOperation() { | |||
op.getValueStringAttr() || op.getValueStringsAttr()); | |||
}); | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please undo this change.
@@ -602,6 +602,57 @@ struct RecomposeQLinearMatMulFromQuantizeLinearPattern | |||
} | |||
}; | |||
|
|||
/// Merges nested ONNXConcatOps | |||
struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you move this pattern into Canonicalize.cpp? I would helpful to call this everytime we have concat. Putting it inside Recompose.cpp may only trigger at the beginning of the compilation.
|
||
// Helper function to check if an input is a mergeable Concat. | ||
static bool isMergeableConcat(Value input, int64_t axis) { | ||
auto innerConcat = input.getDefiningOp<ONNXConcatOp>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use concrete type here. I think it's ONNXConcatOp innerConcat ...
. We avoid using auto
as much as possible.
if (isMergeableConcat(input, concatOp.getAxis())) { | ||
merged = true; | ||
// Remove the nested concat and append its inputs. | ||
newInputs.pop_back(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
push and pop would be expensive. Why don't you push input
in the else
part of if(isMergeableConcat(input, concatOp.getAxis())) {
?
merged = true; | ||
// Remove the nested concat and append its inputs. | ||
newInputs.pop_back(); | ||
auto innerConcat = cast<ONNXConcatOp>(input.getDefiningOp()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do this instead: ONNXConcatOp innerConcat = input.getDefiningOp<ONNXConcatOp>()
} | ||
|
||
// If there is only a single input, replace the concat with that input. | ||
if (concatOp.getOperands().size() == 1) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a light check, please do it at the beginning before flattening.
|
||
// If there is only a single input, replace the concat with that input. | ||
if (concatOp.getOperands().size() == 1) { | ||
rewriter.replaceOp(concatOp, concatOp.getOperands()[0]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: concatOp.getOperands()[0]
-> inputs[0]
Concat merging
The RecomposeConcat pass is an ONNX-MLIR optimization pass that simplifies and merges ONNXConcatOp operations to improve model performance and reduce redundant operations.
The pass optimizes Concat operations by:
Input Representation
Original Flow (Nested Concat Operations)
Y=Concat(X1,Concat(X2,X3),X4)
Optimized Flow (Flattening Nested Concats)
Y=Concat(X1,X2,X3,X4)