//===- FuseFC.cpp - Fully Connected Fusion Pass ------------------------===//
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
//===----------------------------------------------------------------------===//
// This pass matches patterns of the form:
//   %collapsed_29 = tensor.collapse_shape %46 [[0, 1], [2]] : tensor<1x1024x1536xf32> into tensor<1024x1536xf32>
//   %51 = bufferization.alloc_tensor() : tensor<1536x256xf32>
//   %transposed_30 = linalg.transpose ins(%arg7 : tensor<256x1536xf32>) outs(%51 : tensor<1536x256xf32>) permutation = [1, 0]
//   %52 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%collapsed_29, %transposed_30 : tensor<1024x1536xf32>, tensor<1536x256xf32>) outs(%cst_6 : tensor<1024x256xf32>) -> tensor<1024x256xf32>
//   %expanded_31 = tensor.expand_shape %arg8 [[0, 1]] output_shape [1, 256] : tensor<256xf32> into tensor<1x256xf32>
//   %53 = bufferization.alloc_tensor() : tensor<1024x256xf32>
//   %54 = linalg.generic {indexing_maps = [#map10, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%expanded_31, %52 : tensor<1x256xf32>, tensor<1024x256xf32>) outs(%53 : tensor<1024x256xf32>) {
//   ^bb0(%in: f32, %in_1538: f32, %out: f32):
//     %3044 = arith.addf %in, %in_1538 : f32
//     linalg.yield %3044 : f32
//   } -> tensor<1024x256xf32>
//
// It rewrites this into two linalg.generic operations so that they can be
// fused later at the affine level. Redundant operations need to be cleaned up
// with a canonicalization pass afterwards.
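//
// After the rewrite, the IR looks roughly like the sketch below (SSA names
// and shapes are illustrative, taken from the example above; they are not
// produced literally by the pass). Note that the rewrite seeds the matmul
// accumulator with the broadcast bias, which assumes the original matmul
// accumulated into a zero-filled init tensor (%cst_6 above).
//
//   #bias = affine_map<(d0, d1) -> (d1)>
//   #id   = affine_map<(d0, d1) -> (d0, d1)>
//   #mapA = affine_map<(d0, d1, d2) -> (d0, d2)>
//   #mapB = affine_map<(d0, d1, d2) -> (d1, d2)>
//   #mapC = affine_map<(d0, d1, d2) -> (d0, d1)>
//   %bcast = linalg.generic {indexing_maps = [#bias, #id],
//       iterator_types = ["parallel", "parallel"]}
//       ins(%arg8 : tensor<256xf32>) outs(%53 : tensor<1024x256xf32>) {
//     ^bb0(%in: f32, %out: f32):
//       linalg.yield %in : f32
//   } -> tensor<1024x256xf32>
//   %fused = linalg.generic {indexing_maps = [#mapA, #mapB, #mapC],
//       iterator_types = ["parallel", "parallel", "reduction"]}
//       ins(%collapsed_29, %arg7 : tensor<1024x1536xf32>, tensor<256x1536xf32>)
//       outs(%bcast : tensor<1024x256xf32>) {
//     ^bb0(%a: f32, %b: f32, %acc: f32):
//       %m = arith.mulf %a, %b : f32
//       %s = arith.addf %m, %acc : f32
//       linalg.yield %s : f32
//   } -> tensor<1024x256xf32>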
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "fuse-fc"

using namespace mlir;

namespace {

// Helper function to check if a linalg.generic op is elementwise.
static bool isElementwise(linalg::GenericOp op) {
  return llvm::all_of(op.getIteratorTypesArray(),
                      [](utils::IteratorType type) {
                        return type == utils::IteratorType::parallel;
                      });
}

class FuseFCPattern : public OpRewritePattern<linalg::GenericOp> {
public:
  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::GenericOp addOp,
                                PatternRewriter &rewriter) const override {
    LLVM_DEBUG(llvm::dbgs() << "\n=== FuseFCPattern: Checking GenericOp ===\n");
    LLVM_DEBUG(addOp.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    // 1. Check that the current op is a bias-add (elementwise with 2 loops).
    if (!isElementwise(addOp)) {
      LLVM_DEBUG(llvm::dbgs() << " -> Not elementwise, skipping\n");
      return failure();
    }

    if (addOp.getNumLoops() != 2) {
      LLVM_DEBUG(llvm::dbgs() << " -> Number of loops != 2 (got "
                              << addOp.getNumLoops() << "), skipping\n");
      return failure();
    }

    if (addOp.getNumDpsInputs() != 2 || addOp.getNumDpsInits() != 1) {
      LLVM_DEBUG(llvm::dbgs() << " -> Wrong number of inputs/outputs (inputs="
                              << addOp.getNumDpsInputs() << ", outputs="
                              << addOp.getNumDpsInits() << "), skipping\n");
      return failure();
    }

    // Check for an add operation in the body.
    auto &body = addOp.getRegion().front();
    bool hasAdd = !body.getOps<arith::AddFOp>().empty() ||
                  !body.getOps<arith::AddIOp>().empty();
    if (!hasAdd) {
      LLVM_DEBUG(llvm::dbgs()
                 << " -> Body doesn't contain an add operation, skipping\n");
      return failure();
    }

    LLVM_DEBUG(llvm::dbgs() << " -> Confirmed as elementwise add operation\n");

    // 2. Find which operand is the matmul result and which is the bias.
    linalg::MatmulOp matmulOp;
    Value bias;

    for (int i = 0; i < 2; ++i) {
      Value input = addOp.getDpsInputOperand(i)->get();
      LLVM_DEBUG(llvm::dbgs() << " -> Checking input " << i << ": ");
      LLVM_DEBUG(input.print(llvm::dbgs()));
      LLVM_DEBUG(llvm::dbgs() << "\n");

      if (auto defOp = input.getDefiningOp<linalg::MatmulOp>()) {
        LLVM_DEBUG(llvm::dbgs() << "    Found MatmulOp at index " << i << "\n");
        matmulOp = defOp;
        bias = addOp.getDpsInputOperand(1 - i)->get();
        break;
      }
    }

    if (!matmulOp) {
      LLVM_DEBUG(llvm::dbgs() << " -> No MatmulOp input found, skipping\n");
      return failure();
    }

    LLVM_DEBUG(llvm::dbgs() << " -> MatmulOp found:\n");
    LLVM_DEBUG(matmulOp.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    // 3. The second operand of the matmul must be the result of a transpose.
    auto transposeOp = matmulOp.getDpsInputOperand(1)
                           ->get()
                           .getDefiningOp<linalg::TransposeOp>();
    if (!transposeOp) {
      LLVM_DEBUG(llvm::dbgs()
                 << " -> Second matmul operand is not TransposeOp, skipping\n");
      return failure();
    }

    LLVM_DEBUG(llvm::dbgs() << " -> TransposeOp found:\n");
    LLVM_DEBUG(transposeOp.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    // Check that the transpose permutation is [1, 0].
    auto perm = transposeOp.getPermutation();
    if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0) {
      LLVM_DEBUG(llvm::dbgs()
                 << " -> Transpose permutation is not [1, 0], skipping\n");
      return failure();
    }

    // 4. The bias must be broadcastable.
    LLVM_DEBUG(llvm::dbgs() << " -> Checking bias: ");
    LLVM_DEBUG(bias.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    // Trace through expand_shape if present.
    Value originalBias = bias;
    if (auto expandOp = bias.getDefiningOp<tensor::ExpandShapeOp>()) {
      LLVM_DEBUG(llvm::dbgs() << "    Bias comes from expand_shape, using source\n");
      originalBias = expandOp.getSrc();
    }

    auto biasType = mlir::dyn_cast<RankedTensorType>(originalBias.getType());
    if (!biasType) {
      LLVM_DEBUG(llvm::dbgs() << " -> Bias is not a RankedTensorType, skipping\n");
      return failure();
    }

    LLVM_DEBUG(llvm::dbgs() << "    Bias type: ");
    LLVM_DEBUG(biasType.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    // The bias must be rank 1, or rank 2 with one dimension of size 1.
    if (biasType.getRank() != 1 && biasType.getRank() != 2) {
      LLVM_DEBUG(llvm::dbgs() << " -> Bias rank is " << biasType.getRank()
                              << " (expected 1 or 2), skipping\n");
      return failure();
    }

    if (biasType.getRank() == 2 && biasType.getShape()[0] != 1 &&
        biasType.getShape()[1] != 1) {
      LLVM_DEBUG(llvm::dbgs()
                 << " -> 2D bias doesn't have a dimension of size 1, skipping\n");
      return failure();
    }
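
    // At this point the accepted bias shapes are, e.g. (sizes illustrative,
    // taken from the header example): tensor<256xf32>, tensor<1x256xf32>, or
    // tensor<256x1xf32>.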

    LLVM_DEBUG(llvm::dbgs() << "\n=== Pattern matched! Starting rewrite ===\n");

    Location loc = addOp.getLoc();
    MLIRContext *ctx = rewriter.getContext();

    // Get the operands from the original operations.
    Value inputA = matmulOp.getDpsInputOperand(0)->get();
    Value inputB = transposeOp.getDpsInputOperand(0)->get(); // Before transpose.
    Value output = addOp.getDpsInitOperand(0)->get();

    auto outputType = mlir::cast<RankedTensorType>(output.getType());
    Type elementTy = outputType.getElementType();

    LLVM_DEBUG(llvm::dbgs() << "  Input A type: " << inputA.getType() << "\n");
    LLVM_DEBUG(llvm::dbgs() << "  Input B type: " << inputB.getType() << "\n");
    LLVM_DEBUG(llvm::dbgs() << "  Bias type: " << originalBias.getType() << "\n");
    LLVM_DEBUG(llvm::dbgs() << "  Output type: " << outputType << "\n");

    // --- Part 1: Create the bias broadcast operation ---
    LLVM_DEBUG(llvm::dbgs() << "\n--- Creating bias broadcast operation ---\n");

    SmallVector<AffineMap> biasMaps;
    if (biasType.getRank() == 1) {
      // From tensor<N> to tensor<M, N>.
      biasMaps.push_back(AffineMap::get(2, 0, {rewriter.getAffineDimExpr(1)}, ctx));
      LLVM_DEBUG(llvm::dbgs() << "  Broadcast 1D bias along first dimension\n");
    } else { // Rank 2.
      if (biasType.getShape()[0] == 1) {
        // tensor<1, N> -> broadcast along dim 0.
        biasMaps.push_back(AffineMap::get(
            2, 0, {rewriter.getAffineConstantExpr(0), rewriter.getAffineDimExpr(1)},
            ctx));
        LLVM_DEBUG(llvm::dbgs() << "  Broadcast 2D bias (1xN) along first dimension\n");
      } else {
        // tensor<N, 1> -> broadcast along dim 1.
        biasMaps.push_back(AffineMap::get(
            2, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineConstantExpr(0)},
            ctx));
        LLVM_DEBUG(llvm::dbgs() << "  Broadcast 2D bias (Nx1) along second dimension\n");
      }
    }
    biasMaps.push_back(rewriter.getMultiDimIdentityMap(2));
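
    // For a 1-D bias, biasMaps now holds
    // {(d0, d1) -> (d1), (d0, d1) -> (d0, d1)}; in the 1xN / Nx1 cases the
    // unit dimension is pinned to the constant 0 instead of being dropped.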

    SmallVector<utils::IteratorType> biasIteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel};

    auto broadcastedBiasOp = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/outputType,
        /*inputs=*/originalBias,
        /*outputs=*/output,
        /*indexingMaps=*/biasMaps,
        /*iteratorTypes=*/biasIteratorTypes,
        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
          b.create<linalg::YieldOp>(nestedLoc, args[0]);
        });
    Value broadcastedBiasResult = broadcastedBiasOp.getResult(0);

    LLVM_DEBUG(llvm::dbgs() << "  Created bias broadcast op:\n");
    LLVM_DEBUG(broadcastedBiasOp.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    // --- Part 2: Create the matmul with the transpose folded in ---
    LLVM_DEBUG(llvm::dbgs() << "\n--- Creating fused matmul operation ---\n");

    // A[i,k] * B[j,k] -> C[i,j] (with B accessed as its transpose).
    AffineMap mapA = AffineMap::get(
        3, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)}, ctx);
    AffineMap mapB = AffineMap::get(
        3, 0, {rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(2)}, ctx);
    AffineMap mapC = AffineMap::get(
        3, 0, {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, ctx);
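
    // In affine_map notation these are mapA = (d0, d1, d2) -> (d0, d2),
    // mapB = (d0, d1, d2) -> (d1, d2), and mapC = (d0, d1, d2) -> (d0, d1),
    // with d0 = M, d1 = N, d2 = K; B stays in its original N x K layout, so
    // no materialized transpose is needed.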

    LLVM_DEBUG(llvm::dbgs() << "  Map A (input): " << mapA << "\n");
    LLVM_DEBUG(llvm::dbgs() << "  Map B (transposed): " << mapB << "\n");
    LLVM_DEBUG(llvm::dbgs() << "  Map C (output): " << mapC << "\n");

    SmallVector<AffineMap> matmulMaps = {mapA, mapB, mapC};
    SmallVector<utils::IteratorType> matmulIteratorTypes = {
        utils::IteratorType::parallel, utils::IteratorType::parallel,
        utils::IteratorType::reduction};

    auto fusedMatmulOp = rewriter.create<linalg::GenericOp>(
        loc,
        /*resultTensorTypes=*/outputType,
        /*inputs=*/ValueRange{inputA, inputB},
        /*outputs=*/broadcastedBiasResult,
        /*indexingMaps=*/matmulMaps,
        /*iteratorTypes=*/matmulIteratorTypes,
        [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
          // args[0] = A[i,k], args[1] = B[j,k], args[2] = C[i,j]
          // (initialized with the broadcast bias).
          Value mulResult, addResult;
          if (isa<FloatType>(elementTy)) {
            mulResult = b.create<arith::MulFOp>(nestedLoc, args[0], args[1]);
            addResult = b.create<arith::AddFOp>(nestedLoc, mulResult, args[2]);
          } else {
            mulResult = b.create<arith::MulIOp>(nestedLoc, args[0], args[1]);
            addResult = b.create<arith::AddIOp>(nestedLoc, mulResult, args[2]);
          }
          b.create<linalg::YieldOp>(nestedLoc, addResult);
        });

    LLVM_DEBUG(llvm::dbgs() << "  Created fused matmul op:\n");
    LLVM_DEBUG(fusedMatmulOp.print(llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "\n");

    rewriter.replaceOp(addOp, fusedMatmulOp.getResults());

    LLVM_DEBUG(llvm::dbgs() << "=== Rewrite successful ===\n\n");
    return success();
  }
};

struct FuseFCPass
    : public PassWrapper<FuseFCPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseFCPass)

  StringRef getArgument() const final { return "fuse-fc"; }

  StringRef getDescription() const final {
    return "Fuse linalg.transpose + linalg.matmul + bias-add into a "
           "sequence of linalg.generic ops.";
  }

  void runOnOperation() override {
    LLVM_DEBUG(llvm::dbgs() << "\n\n====================================\n");
    LLVM_DEBUG(llvm::dbgs() << "Starting FuseFCPass\n");
    LLVM_DEBUG(llvm::dbgs() << "====================================\n\n");

    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add<FuseFCPattern>(context);

    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
      LLVM_DEBUG(llvm::dbgs() << "FuseFCPass: Pattern application failed\n");
      signalPassFailure();
    } else {
      LLVM_DEBUG(llvm::dbgs() << "\n====================================\n");
      LLVM_DEBUG(llvm::dbgs() << "FuseFCPass completed successfully\n");
      LLVM_DEBUG(llvm::dbgs() << "====================================\n\n");
    }
  }
};

} // namespace

namespace mlir {
namespace buddy {
void registerFuseFCPass() { PassRegistration<FuseFCPass>(); }
} // namespace buddy
} // namespace mlir
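
// Usage sketch: a driver that calls mlir::buddy::registerFuseFCPass() before
// parsing its pass pipeline can then run (the tool name below is an
// assumption; the "fuse-fc" argument and the follow-up canonicalization
// requirement come from this file):
//   <opt-tool> input.mlir --fuse-fc --canonicalize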