[mlir][linalg] Decompose conv2d to series of conv1d #169082
base: main
Conversation
You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp -- mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp --diff_from_common_commit
View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
index e02755dd9..64ad41fd2 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeConv2DToConv1D.cpp
@@ -33,7 +33,8 @@ namespace {
/// Constraints:
/// - Height stride and dilation must be 1 (to allow contiguous reshaping).
/// - Width stride and dilation are preserved in the 1D convolution.
-struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwcHwcfOp> {
+struct DecomposeConv2DToConv1DPattern final
+ : public OpRewritePattern<Conv2DNhwcHwcfOp> {
using OpRewritePattern<Conv2DNhwcHwcfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(Conv2DNhwcHwcfOp convOp,
@@ -52,11 +53,14 @@ struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwc
auto stridesAttr = convOp.getStrides();
auto dilationsAttr = convOp.getDilations();
- SmallVector<int64_t> strides = llvm::to_vector(stridesAttr.getValues<int64_t>());
- SmallVector<int64_t> dilations = llvm::to_vector(dilationsAttr.getValues<int64_t>());
+ SmallVector<int64_t> strides =
+ llvm::to_vector(stridesAttr.getValues<int64_t>());
+ SmallVector<int64_t> dilations =
+ llvm::to_vector(dilationsAttr.getValues<int64_t>());
if (strides[0] != 1 || dilations[0] != 1) {
- return rewriter.notifyMatchFailure(convOp, "requires stride_h=1 and dilation_h=1");
+ return rewriter.notifyMatchFailure(
+ convOp, "requires stride_h=1 and dilation_h=1");
}
// 2. Get Dimensions
@@ -84,71 +88,83 @@ struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwc
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value one = arith::ConstantIndexOp::create(rewriter, loc, 1);
- auto scfLoop = scf::ForOp::create(rewriter,
- loc, zero, Kh, one, ValueRange{output},
+ auto scfLoop = scf::ForOp::create(
+ rewriter, loc, zero, Kh, one, ValueRange{output},
[&](OpBuilder &b, Location loc, Value r, ValueRange args) {
Value currentAccumulator = args[0];
// --- A. Extract Filter Slice ---
// Filter shape: [Kh, Kw, Cin, Cout] -> Slice at r: [1, Kw, Cin, Cout]
// We need to rank-reduce this to [Kw, Cin, Cout] for conv_1d.
- SmallVector<OpFoldResult> filterOffsets = {r, b.getIndexAttr(0), b.getIndexAttr(0), b.getIndexAttr(0)};
- SmallVector<OpFoldResult> filterSizes = {b.getIndexAttr(1), Kw, C_in, C_out};
- SmallVector<OpFoldResult> filterStrides = {b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1)};
+ SmallVector<OpFoldResult> filterOffsets = {
+ r, b.getIndexAttr(0), b.getIndexAttr(0), b.getIndexAttr(0)};
+ SmallVector<OpFoldResult> filterSizes = {b.getIndexAttr(1), Kw, C_in,
+ C_out};
+ SmallVector<OpFoldResult> filterStrides = {
+ b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1),
+ b.getIndexAttr(1)};
// Explicitly specify the desired result type (Rank 3)
- auto filterSliceType = RankedTensorType::get(
- {ShapedType::kDynamic, ShapedType::kDynamic, ShapedType::kDynamic},
- filterType.getElementType());
+ auto filterSliceType =
+ RankedTensorType::get({ShapedType::kDynamic, ShapedType::kDynamic,
+ ShapedType::kDynamic},
+ filterType.getElementType());
- Value filterSlice = tensor::ExtractSliceOp::create(b,
- loc, filterSliceType, filter, filterOffsets, filterSizes, filterStrides);
+ Value filterSlice = tensor::ExtractSliceOp::create(
+ b, loc, filterSliceType, filter, filterOffsets, filterSizes,
+ filterStrides);
// --- B. Extract Input Slice ---
// We need a view of the input shifted by 'r' along Height.
// Input: [N, H, W, C]. Slice starts at [0, r, 0, 0].
// Size: [N, H_out, W, C].
- // (Recall H_in = H_out + Kh - 1 generally, so H_out fits starting at r).
- SmallVector<OpFoldResult> inputOffsets = {b.getIndexAttr(0), r, b.getIndexAttr(0), b.getIndexAttr(0)};
+ // (Recall H_in = H_out + Kh - 1 generally, so H_out fits starting at
+ // r).
+ SmallVector<OpFoldResult> inputOffsets = {
+ b.getIndexAttr(0), r, b.getIndexAttr(0), b.getIndexAttr(0)};
SmallVector<OpFoldResult> inputSizes = {N, H_out, W_in, C_in};
- SmallVector<OpFoldResult> inputStrides = {b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1)};
+ SmallVector<OpFoldResult> inputStrides = {
+ b.getIndexAttr(1), b.getIndexAttr(1), b.getIndexAttr(1),
+ b.getIndexAttr(1)};
- Value inputSlice = tensor::ExtractSliceOp::create(b,
- loc, input, inputOffsets, inputSizes, inputStrides);
+ Value inputSlice = tensor::ExtractSliceOp::create(
+ b, loc, input, inputOffsets, inputSizes, inputStrides);
// --- C. Reshape Input for Conv1D ---
// Conv1D expects [Batch, Width, Channels].
// We have [N, H_out, W_in, C_in].
// We collapse N and H_out into a single Batch dimension.
- SmallVector<ReassociationIndices> collapseIndicesInput = {{0, 1}, {2}, {3}};
- Value reshapedInput = tensor::CollapseShapeOp::create(b,
- loc, inputSlice, collapseIndicesInput);
+ SmallVector<ReassociationIndices> collapseIndicesInput = {
+ {0, 1}, {2}, {3}};
+ Value reshapedInput = tensor::CollapseShapeOp::create(
+ b, loc, inputSlice, collapseIndicesInput);
// --- D. Reshape Accumulator for Conv1D ---
// Current Accumulator: [N, H_out, W_out, C_out].
// Target: [N * H_out, W_out, C_out].
- Value reshapedAcc = tensor::CollapseShapeOp::create(b,
- loc, currentAccumulator, collapseIndicesInput);
+ Value reshapedAcc = tensor::CollapseShapeOp::create(
+ b, loc, currentAccumulator, collapseIndicesInput);
// --- E. Perform Conv1D ---
// Op: linalg.conv_1d_nwc_wcf
- // Strides and Dilations for W are passed through from the original Op.
- // Original Strides: [Stride_H, Stride_W]. We take Stride_W.
+ // Strides and Dilations for W are passed through from the original
+ // Op. Original Strides: [Stride_H, Stride_W]. We take Stride_W.
auto strideW = strides[1];
auto dilationW = dilations[1];
- auto conv1d = Conv1DNwcWcfOp::create(b, loc,
- TypeRange{reshapedAcc.getType()},
- ValueRange{reshapedInput, filterSlice},
- ValueRange{reshapedAcc},
+ auto conv1d = Conv1DNwcWcfOp::create(
+ b, loc, TypeRange{reshapedAcc.getType()},
+ ValueRange{reshapedInput, filterSlice}, ValueRange{reshapedAcc},
b.getDenseI64ArrayAttr({strideW}),
b.getDenseI64ArrayAttr({dilationW}));
// --- F. Expand Result back to 4D ---
// Result: [N * H_out, W_out, C_out] -> [N, H_out, W_out, C_out]
- // We use the Type of the currentAccumulator to ensure correct dynamic dim reconstruction.
- Value expandedResult = tensor::ExpandShapeOp::create(b,
- loc, currentAccumulator.getType(), conv1d.getResult(0), collapseIndicesInput);
+ // We use the Type of the currentAccumulator to ensure correct dynamic
+ // dim reconstruction.
+ Value expandedResult = tensor::ExpandShapeOp::create(
+ b, loc, currentAccumulator.getType(), conv1d.getResult(0),
+ collapseIndicesInput);
scf::YieldOp::create(b, loc, expandedResult);
});
@@ -160,7 +176,9 @@ struct DecomposeConv2DToConv1DPattern final : public OpRewritePattern<Conv2DNhwc
} // namespace
-struct LinalgDecomposeConv2DtoConv1D final : public impl::LinalgDecomposeConv2DToConv1DBase<LinalgDecomposeConv2DtoConv1D> {
+struct LinalgDecomposeConv2DtoConv1D final
+ : public impl::LinalgDecomposeConv2DToConv1DBase<
+ LinalgDecomposeConv2DtoConv1D> {
using Base::Base;
void runOnOperation() override {
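For context, here is a minimal, self-contained C++ sketch (not part of this patch, and not using the MLIR APIs) that checks the identity the pattern relies on: with stride_h = 1 and dilation_h = 1, a conv_2d_nhwc_hwcf equals the accumulation, over the kernel-height index kh, of a conv_1d_nwc_wcf applied to the input shifted by kh along H, with N and H_out acting as a collapsed batch dimension. All sizes and names below are invented for illustration only.

```cpp
// Illustrative only: checks, on plain arrays, the identity used by the
// decomposition (stride_h = 1, dilation_h = 1). Layouts follow NHWC input
// and HWCF filter, as in linalg.conv_2d_nhwc_hwcf.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  const int N = 2, H = 5, W = 6, Cin = 3, Cout = 4;   // arbitrary example sizes
  const int Kh = 3, Kw = 2, strideW = 2, dilW = 1;
  const int Hout = H - Kh + 1;                         // stride_h = 1, dilation_h = 1
  const int Wout = (W - (dilW * (Kw - 1) + 1)) / strideW + 1;

  // Row-major index into a [A, B, C, D] array.
  auto idx4 = [](int a, int b, int c, int d, int B, int C, int D) {
    return ((a * B + b) * C + c) * D + d;
  };

  std::vector<float> in(N * H * W * Cin), flt(Kh * Kw * Cin * Cout);
  for (size_t i = 0; i < in.size(); ++i) in[i] = float(i % 7) - 3.0f;
  for (size_t i = 0; i < flt.size(); ++i) flt[i] = float(i % 5) - 2.0f;

  // Reference: direct conv_2d_nhwc_hwcf.
  std::vector<float> ref(N * Hout * Wout * Cout, 0.0f);
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < Hout; ++h)
      for (int w = 0; w < Wout; ++w)
        for (int f = 0; f < Cout; ++f)
          for (int kh = 0; kh < Kh; ++kh)
            for (int kw = 0; kw < Kw; ++kw)
              for (int c = 0; c < Cin; ++c)
                ref[idx4(n, h, w, f, Hout, Wout, Cout)] +=
                    in[idx4(n, h + kh, w * strideW + kw * dilW, c, H, W, Cin)] *
                    flt[idx4(kh, kw, c, f, Kw, Cin, Cout)];

  // Decomposed form: loop over kh; each step is a conv_1d_nwc_wcf on the
  // input shifted by kh along H, with (N, Hout) as the collapsed batch dim.
  std::vector<float> acc(N * Hout * Wout * Cout, 0.0f);
  for (int kh = 0; kh < Kh; ++kh)
    for (int n = 0; n < N; ++n)
      for (int h = 0; h < Hout; ++h)                   // collapsed batch = N * Hout
        for (int w = 0; w < Wout; ++w)
          for (int f = 0; f < Cout; ++f)
            for (int kw = 0; kw < Kw; ++kw)
              for (int c = 0; c < Cin; ++c)
                acc[idx4(n, h, w, f, Hout, Wout, Cout)] +=
                    in[idx4(n, h + kh, w * strideW + kw * dilW, c, H, W, Cin)] *
                    flt[idx4(kh, kw, c, f, Kw, Cin, Cout)];

  for (size_t i = 0; i < ref.size(); ++i) assert(ref[i] == acc[i]);
  std::printf("conv2d == sum over kh of conv1d: OK\n");
  return 0;
}
```

The shift-by-kh indexing is also why the pattern bails out when the height stride or dilation is not 1: a non-unit stride or dilation along H would break the contiguous H_out window that the extract_slice/collapse_shape pair assumes.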
🐧 Linux x64 Test Results
The build failed before running any tests. Click on a failure below to see the details.
tools/mlir/lib/Dialect/Linalg/Transforms/CMakeFiles/obj.MLIRLinalgTransforms.dir/DecomposeConv2DToConv1D.cpp.o
If these failures are unrelated to your changes (for example tests are broken or flaky at HEAD), please open an issue at https://github.com/llvm/llvm-project/issues and add the
WIP...