diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc
index 666c0cde71..ff79b0e7d5 100644
--- a/src/Dialect/ONNX/ONNXOps.td.inc
+++ b/src/Dialect/ONNX/ONNXOps.td.inc
@@ -1244,6 +1244,7 @@ def ONNXCompressOp:ONNX_Op<"Compress",
 def ONNXConcatOp:ONNX_Op<"Concat",
   [Pure, DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
   DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
+  let hasCanonicalizer = 1;
   let summary = "ONNX Concat operation";
   let description = [{
   Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on.
diff --git a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
index 62957db7af..f2cd70c534 100644
--- a/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
+++ b/src/Dialect/ONNX/ONNXOps/Canonicalize.cpp
@@ -1502,6 +1502,62 @@ class FuseTwoReshapesPattern : public OpRewritePattern<ONNXReshapeOp> {
   }
 };
 
+// =============================================================================
+// Rewrite pattern Concat
+// =============================================================================
+
+struct RecomposeConcatPattern : public OpRewritePattern<ONNXConcatOp> {
+  using OpRewritePattern<ONNXConcatOp>::OpRewritePattern;
+
+  // Helper function to check if an input is a mergeable Concat.
+  static bool isMergeableConcat(Value input, int64_t axis) {
+    ONNXConcatOp concatOp = input.getDefiningOp<ONNXConcatOp>();
+    if (!concatOp)
+      return false;
+    return (concatOp.getAxis() == axis) && (concatOp.getResult().hasOneUse());
+  }
+
+  LogicalResult matchAndRewrite(
+      ONNXConcatOp concatOp, PatternRewriter &rewriter) const final {
+    Location loc = concatOp.getLoc();
+    ValueRange inputs = concatOp.getOperands();
+    int64_t axis = concatOp.getAxis();
+
+    // If there is only a single input, replace the concat with that input.
+    if (inputs.size() == 1) {
+      rewriter.replaceOp(concatOp, inputs[0]);
+      return success();
+    }
+
+    SmallVector<Value, 4> newInputs;
+    bool merged = false;
+
+    // Flatten nested concat nodes.
+    for (Value input : inputs) {
+      if (isMergeableConcat(input, axis)) {
+        // Append the inputs of the nested concat directly.
+        ONNXConcatOp innerConcat = cast<ONNXConcatOp>(input.getDefiningOp());
+        newInputs.append(
+            innerConcat.getOperands().begin(), innerConcat.getOperands().end());
+        merged = true;
+      } else {
+        // Keep non-mergeable inputs as they are.
+        newInputs.push_back(input);
+      }
+    }
+
+    if (merged) {
+      // Create a new ONNXConcatOp with the flattened inputs.
+      auto newConcat = rewriter.create<ONNXConcatOp>(
+          loc, concatOp.getResult().getType(), newInputs, axis);
+      rewriter.replaceOp(concatOp, newConcat.getResult());
+      return success();
+    }
+
+    return failure();
+  }
+};
+
 // =============================================================================
 // Rewrite pattern LayerNormalization
 // =============================================================================
@@ -1722,6 +1778,12 @@ void ONNXCastOp::getCanonicalizationPatterns(
   // result.insert<FuseCastCastPattern>(context);
 }
 
+/// on the ONNXConcatOp.
+void ONNXConcatOp::getCanonicalizationPatterns(
+    RewritePatternSet &results, MLIRContext *context) {
+  results.insert<RecomposeConcatPattern>(context);
+}
+
 /// on the ONNXConstantOp.
 void ONNXConstantOp::getCanonicalizationPatterns(
     RewritePatternSet &results, MLIRContext *context) {}
diff --git a/test/mlir/onnx/onnx_canonicalization.mlir b/test/mlir/onnx/onnx_canonicalization.mlir
index 98f2ad5adb..7c9962d65b 100644
--- a/test/mlir/onnx/onnx_canonicalization.mlir
+++ b/test/mlir/onnx/onnx_canonicalization.mlir
@@ -1899,6 +1899,22 @@ func.func @test_remove_where_equal_4(%arg0: tensor) -> tensor<2xi64> {
 
 // -----
 
+func.func @test_recompose_concat(%arg0: tensor<1x3x4xf32>, %arg1: tensor<1x3x4xf32>) -> tensor<1x12x4xf32> {
+  %0 = "onnx.Concat"(%arg0, %arg1) {axis = 1 : si64, onnx_node_name = "onnx.Concat_0"} : (tensor<1x3x4xf32>, tensor<1x3x4xf32>) -> tensor<1x6x4xf32>
+  %1 = "onnx.Concat"(%0, %arg0) {axis = 1 : si64, onnx_node_name = "onnx.Concat_1"} : (tensor<1x6x4xf32>, tensor<1x3x4xf32>) -> tensor<1x9x4xf32>
+  %2 = "onnx.Concat"(%1, %arg1) {axis = 1 : si64, onnx_node_name = "onnx.Concat_2"} : (tensor<1x9x4xf32>, tensor<1x3x4xf32>) -> tensor<1x12x4xf32>
+  return %2 : tensor<1x12x4xf32>
+
+  // CHECK-LABEL: func @test_recompose_concat
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4xf32>, [[PARAM_1_:%.+]]: tensor<1x3x4xf32>) -> tensor<1x12x4xf32> {
+  // CHECK: [[FINAL_OUT:%.+]] = "onnx.Concat"([[PARAM_0_]], [[PARAM_1_]], [[PARAM_0_]], [[PARAM_1_]])
+  // CHECK-SAME: {axis = 1 : si64}
+  // CHECK-NEXT: return [[FINAL_OUT]] : tensor<1x12x4xf32>
+
+}
+
+// -----
+
 // Not rewriting since the operand in ConcatOp is neither DimOp nor ConstantOp.
 func.func @test_remove_where_equal_5(%arg0: tensor, %arg1: tensor<1xi64>, %arg2: tensor<1xi64>) -> tensor<2xi64> {
   %0 = onnx.Constant dense<-1> : tensor<2xi64>
diff --git a/utils/gen_onnx_mlir.py b/utils/gen_onnx_mlir.py
index 355569deab..b072d25b17 100755
--- a/utils/gen_onnx_mlir.py
+++ b/utils/gen_onnx_mlir.py
@@ -330,6 +330,7 @@
     "Add",
     "And",
     "Cast",
+    "Concat",
     "Constant",
     "DepthToSpace",
     "DequantizeLinear",
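
The single-input early exit in RecomposeConcatPattern (the inputs.size() == 1 branch) is not exercised by the new lit test above. A sketch of an additional FileCheck case is given below; the test name and CHECK lines are illustrative assumptions and are not part of this diff.

func.func @test_recompose_concat_single_input(%arg0: tensor<1x3x4xf32>) -> tensor<1x3x4xf32> {
  // A single-input Concat should be folded away, returning its operand unchanged.
  %0 = "onnx.Concat"(%arg0) {axis = 1 : si64} : (tensor<1x3x4xf32>) -> tensor<1x3x4xf32>
  return %0 : tensor<1x3x4xf32>

  // CHECK-LABEL: func @test_recompose_concat_single_input
  // CHECK-NOT: "onnx.Concat"
  // CHECK: return {{.*}} : tensor<1x3x4xf32>
}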