Skip to content

Commit 5634377

Browse files
committed
[midend] Add FuseFC pass
Fuse fully connected layer operations on linalg-on-tensor into two linalg.generic. These two operations could be fused into on with affine-loop-fusion. Redundent operations needs to be erased with canonicalize pass. .
1 parent 6996a63 commit 5634377

File tree

8 files changed

+353
-0
lines changed

8 files changed

+353
-0
lines changed

examples/BuddyDeepSeekR1/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ add_custom_command(
3333
${BUDDY_BINARY_DIR}/buddy-opt
3434
-eliminate-empty-tensors
3535
-empty-tensor-to-alloc-tensor
36+
-fuse-fc
37+
-canonicalize
3638
-one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS}
3739
-expand-strided-metadata
3840
-ownership-based-buffer-deallocation
@@ -78,6 +80,8 @@ add_custom_command(
7880
${BUDDY_BINARY_DIR}/buddy-opt
7981
-eliminate-empty-tensors
8082
-empty-tensor-to-alloc-tensor
83+
-fuse-fc
84+
-canonicalize
8185
-convert-elementwise-to-linalg
8286
-one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS}
8387
-expand-strided-metadata

midend/lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ set(LinkedLibs
2525
BatchMatMulOptimization
2626
MatMulParallelVectorization
2727
TransposeOptimization
28+
FullyConnectedFusion
2829
)
2930

3031

midend/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ add_subdirectory(MLIRGPU)
1717
add_subdirectory(VIRToVector)
1818
add_subdirectory(LinalgToVIR)
1919
add_subdirectory(GraphRedundancyElimination)
20+
add_subdirectory(FullyConnectedFusion)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
add_mlir_library(FullyConnectedFusion
2+
FuseFC.cpp
3+
LINK_LIBS PUBLIC
4+
BuddyUtils
5+
)
Lines changed: 337 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,337 @@
1+
//===- FuseFC.cpp - Fully Connected Fusion Pass ------------------------===//
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
// This pass is going to deal with pattern like:
17+
// %collapsed_29 = tensor.collapse_shape %46 [[0, 1], [2]] : tensor<1x1024x1536xf32> into tensor<1024x1536xf32>
18+
// %51 = bufferization.alloc_tensor() : tensor<1536x256xf32>
19+
// %transposed_30 = linalg.transpose ins(%arg7 : tensor<256x1536xf32>) outs(%51 : tensor<1536x256xf32>) permutation = [1, 0]
20+
// %52 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%collapsed_29, %transposed_30 : tensor<1024x1536xf32>, tensor<1536x256xf32>) outs(%cst_6 : tensor<1024x256xf32>) -> tensor<1024x256xf32>
21+
// %expanded_31 = tensor.expand_shape %arg8 [[0, 1]] output_shape [1, 256] : tensor<256xf32> into tensor<1x256xf32>
22+
// %53 = bufferization.alloc_tensor() : tensor<1024x256xf32>
23+
// %54 = linalg.generic {indexing_maps = [#map10, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%expanded_31, %52 : tensor<1x256xf32>, tensor<1024x256xf32>) outs(%53 : tensor<1024x256xf32>) {
24+
// ^bb0(%in: f32, %in_1538: f32, %out: f32):
25+
// %3044 = arith.addf %in, %in_1538 : f32
26+
// linalg.yield %3044 : f32
27+
// } -> tensor<1024x256xf32>
28+
// This pass will transform it into two linalg.generic Operations. So that they could be fused later in affine level.
29+
// Redundent operations needs to be cleaned with canonicalization pass.
30+
//===----------------------------------------------------------------------===//
31+
32+
#include "mlir/Dialect/Arith/IR/Arith.h"
33+
#include "mlir/Dialect/Func/IR/FuncOps.h"
34+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
35+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
36+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
37+
#include "mlir/IR/AffineMap.h"
38+
#include "mlir/IR/Builders.h"
39+
#include "mlir/IR/MLIRContext.h"
40+
#include "mlir/IR/PatternMatch.h"
41+
#include "mlir/Pass/Pass.h"
42+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
43+
#include "llvm/Support/Debug.h"
44+
45+
#define DEBUG_TYPE "fuse-fc"
46+
47+
using namespace mlir;
48+
using namespace llvm::mlir; // For silencing errors emitted by LSP
49+
50+
namespace {
51+
52+
// Helper function to check if a linalg.generic op is elementwise
53+
static bool isElementwise(linalg::GenericOp op) {
54+
return llvm::all_of(op.getIteratorTypesArray(),
55+
[](utils::IteratorType type) {
56+
return type == utils::IteratorType::parallel;
57+
});
58+
}
59+
60+
class FuseFCPattern : public OpRewritePattern<linalg::GenericOp> {
61+
public:
62+
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
63+
64+
LogicalResult matchAndRewrite(linalg::GenericOp addOp,
65+
PatternRewriter &rewriter) const override {
66+
LLVM_DEBUG(llvm::dbgs() << "\n=== FuseFCPattern: Checking GenericOp ===\n");
67+
LLVM_DEBUG(addOp.print(llvm::dbgs()));
68+
LLVM_DEBUG(llvm::dbgs() << "\n");
69+
70+
// 1. Check if the current op is a bias-add (elementwise with 2 loops)
71+
if (!isElementwise(addOp)) {
72+
LLVM_DEBUG(llvm::dbgs() << " -> Not elementwise, skipping\n");
73+
return failure();
74+
}
75+
76+
if (addOp.getNumLoops() != 2) {
77+
LLVM_DEBUG(llvm::dbgs() << " -> Number of loops != 2 (got "
78+
<< addOp.getNumLoops() << "), skipping\n");
79+
return failure();
80+
}
81+
82+
if (addOp.getNumDpsInputs() != 2 || addOp.getNumDpsInits() != 1) {
83+
LLVM_DEBUG(llvm::dbgs() << " -> Wrong number of inputs/outputs (inputs="
84+
<< addOp.getNumDpsInputs() << ", outputs="
85+
<< addOp.getNumDpsInits() << "), skipping\n");
86+
return failure();
87+
}
88+
89+
// Check for add operation in the body
90+
auto &body = addOp.getRegion().front();
91+
bool hasAdd = !body.getOps<arith::AddFOp>().empty() ||
92+
!body.getOps<arith::AddIOp>().empty();
93+
if (!hasAdd) {
94+
LLVM_DEBUG(llvm::dbgs() << " -> Body doesn't contain add operation, skipping\n");
95+
return failure();
96+
}
97+
98+
LLVM_DEBUG(llvm::dbgs() << " -> Confirmed as elementwise add operation\n");
99+
100+
// 2. Find which operand is the matmul result and which is the bias
101+
Value matmulResult, bias;
102+
linalg::MatmulOp matmulOp;
103+
int matmulIdx = -1;
104+
105+
for (int i = 0; i < 2; ++i) {
106+
auto input = addOp.getDpsInputOperand(i)->get();
107+
LLVM_DEBUG(llvm::dbgs() << " -> Checking input " << i << ": ");
108+
LLVM_DEBUG(input.print(llvm::dbgs()));
109+
LLVM_DEBUG(llvm::dbgs() << "\n");
110+
111+
if (auto defOp = input.getDefiningOp<linalg::MatmulOp>()) {
112+
LLVM_DEBUG(llvm::dbgs() << " Found MatmulOp at index " << i << "\n");
113+
matmulOp = defOp;
114+
matmulResult = input;
115+
matmulIdx = i;
116+
bias = addOp.getDpsInputOperand(1 - i)->get();
117+
break;
118+
}
119+
}
120+
121+
if (!matmulOp) {
122+
LLVM_DEBUG(llvm::dbgs() << " -> No MatmulOp input found, skipping\n");
123+
return failure();
124+
}
125+
126+
LLVM_DEBUG(llvm::dbgs() << " -> MatmulOp found:\n");
127+
LLVM_DEBUG(matmulOp.print(llvm::dbgs()));
128+
LLVM_DEBUG(llvm::dbgs() << "\n");
129+
130+
// 3. The second operand of the matmul must be the result of a transpose
131+
auto transposeOp =
132+
matmulOp.getDpsInputOperand(1)->get().getDefiningOp<linalg::TransposeOp>();
133+
if (!transposeOp) {
134+
LLVM_DEBUG(llvm::dbgs() << " -> Second matmul operand is not TransposeOp, skipping\n");
135+
return failure();
136+
}
137+
138+
LLVM_DEBUG(llvm::dbgs() << " -> TransposeOp found:\n");
139+
LLVM_DEBUG(transposeOp.print(llvm::dbgs()));
140+
LLVM_DEBUG(llvm::dbgs() << "\n");
141+
142+
// Check transpose permutation is [1, 0]
143+
auto perm = transposeOp.getPermutation();
144+
if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0) {
145+
LLVM_DEBUG(llvm::dbgs() << " -> Transpose permutation is not [1, 0], skipping\n");
146+
return failure();
147+
}
148+
149+
// 4. The bias must be broadcastable
150+
LLVM_DEBUG(llvm::dbgs() << " -> Checking bias: ");
151+
LLVM_DEBUG(bias.print(llvm::dbgs()));
152+
LLVM_DEBUG(llvm::dbgs() << "\n");
153+
154+
// Trace through expand_shape if present
155+
Value originalBias = bias;
156+
if (auto expandOp = bias.getDefiningOp<tensor::ExpandShapeOp>()) {
157+
LLVM_DEBUG(llvm::dbgs() << " Bias comes from expand_shape, using source\n");
158+
originalBias = expandOp.getSrc();
159+
}
160+
161+
auto biasType = mlir::dyn_cast<RankedTensorType>(originalBias.getType());
162+
if (!biasType) {
163+
LLVM_DEBUG(llvm::dbgs() << " -> Bias is not RankedTensorType, skipping\n");
164+
return failure();
165+
}
166+
167+
LLVM_DEBUG(llvm::dbgs() << " Bias type: ");
168+
LLVM_DEBUG(biasType.print(llvm::dbgs()));
169+
LLVM_DEBUG(llvm::dbgs() << "\n");
170+
171+
// Bias must be rank 1 or rank 2 with one dimension being 1
172+
if (biasType.getRank() != 1 && biasType.getRank() != 2) {
173+
LLVM_DEBUG(llvm::dbgs() << " -> Bias rank is " << biasType.getRank()
174+
<< " (expected 1 or 2), skipping\n");
175+
return failure();
176+
}
177+
178+
if (biasType.getRank() == 2 && biasType.getShape()[0] != 1 &&
179+
biasType.getShape()[1] != 1) {
180+
LLVM_DEBUG(llvm::dbgs() << " -> 2D bias doesn't have dimension of size 1, skipping\n");
181+
return failure();
182+
}
183+
184+
LLVM_DEBUG(llvm::dbgs() << "\n=== Pattern matched! Starting rewrite ===\n");
185+
186+
Location loc = addOp.getLoc();
187+
MLIRContext *ctx = rewriter.getContext();
188+
189+
// Get the operands from the original operations
190+
Value inputA = matmulOp.getDpsInputOperand(0)->get();
191+
Value inputB = transposeOp.getDpsInputOperand(0)->get(); // Before transpose
192+
Value output = addOp.getDpsInitOperand(0)->get();
193+
194+
auto outputType = mlir::cast<RankedTensorType>(output.getType());
195+
Type elementTy = outputType.getElementType();
196+
197+
LLVM_DEBUG(llvm::dbgs() << " Input A type: " << inputA.getType() << "\n");
198+
LLVM_DEBUG(llvm::dbgs() << " Input B type: " << inputB.getType() << "\n");
199+
LLVM_DEBUG(llvm::dbgs() << " Bias type: " << originalBias.getType() << "\n");
200+
LLVM_DEBUG(llvm::dbgs() << " Output type: " << outputType << "\n");
201+
202+
// --- Part 1: Create the bias broadcast operation ---
203+
LLVM_DEBUG(llvm::dbgs() << "\n--- Creating bias broadcast operation ---\n");
204+
205+
SmallVector<AffineMap> biasMaps;
206+
if (biasType.getRank() == 1) {
207+
// from tensor<N> to tensor<M, N>
208+
biasMaps.push_back(AffineMap::get(2, 0, {rewriter.getAffineDimExpr(1)}, ctx));
209+
LLVM_DEBUG(llvm::dbgs() << " Broadcast 1D bias along first dimension\n");
210+
} else { // Rank 2
211+
if (biasType.getShape()[0] == 1) {
212+
// tensor<1, N> -> broadcast along dim 0
213+
biasMaps.push_back(AffineMap::get(2, 0,
214+
{rewriter.getAffineConstantExpr(0), rewriter.getAffineDimExpr(1)}, ctx));
215+
LLVM_DEBUG(llvm::dbgs() << " Broadcast 2D bias (1xN) along first dimension\n");
216+
} else {
217+
// tensor<N, 1> -> broadcast along dim 1
218+
biasMaps.push_back(AffineMap::get(2, 0,
219+
{rewriter.getAffineDimExpr(0), rewriter.getAffineConstantExpr(0)}, ctx));
220+
LLVM_DEBUG(llvm::dbgs() << " Broadcast 2D bias (Nx1) along second dimension\n");
221+
}
222+
}
223+
biasMaps.push_back(rewriter.getMultiDimIdentityMap(2));
224+
225+
SmallVector<linalg::utils::IteratorType> biasIteratorTypes = {
226+
linalg::utils::IteratorType::parallel, linalg::utils::IteratorType::parallel};
227+
228+
auto broadcastedBiasOp = rewriter.create<linalg::GenericOp>(
229+
loc,
230+
/*resultTensorTypes=*/outputType,
231+
/*inputs=*/originalBias,
232+
/*outputs=*/output,
233+
/*indexingMaps=*/biasMaps,
234+
/*iteratorTypes=*/biasIteratorTypes,
235+
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
236+
b.create<linalg::YieldOp>(nestedLoc, args[0]);
237+
});
238+
Value broadcastedBiasResult = broadcastedBiasOp.getResult(0);
239+
240+
LLVM_DEBUG(llvm::dbgs() << " Created bias broadcast op:\n");
241+
LLVM_DEBUG(broadcastedBiasOp.print(llvm::dbgs()));
242+
LLVM_DEBUG(llvm::dbgs() << "\n");
243+
244+
// --- Part 2: Create the matmul with folded transpose ---
245+
LLVM_DEBUG(llvm::dbgs() << "\n--- Creating fused matmul operation ---\n");
246+
247+
// A[i,k] * B[j,k] -> C[i,j] (with B accessed as transpose)
248+
AffineMap mapA = AffineMap::get(3, 0,
249+
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)}, ctx);
250+
AffineMap mapB = AffineMap::get(3, 0,
251+
{rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(2)}, ctx);
252+
AffineMap mapC = AffineMap::get(3, 0,
253+
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, ctx);
254+
255+
LLVM_DEBUG(llvm::dbgs() << " Map A (input): " << mapA << "\n");
256+
LLVM_DEBUG(llvm::dbgs() << " Map B (transposed): " << mapB << "\n");
257+
LLVM_DEBUG(llvm::dbgs() << " Map C (output): " << mapC << "\n");
258+
259+
SmallVector<AffineMap> matmulMaps = {mapA, mapB, mapC};
260+
SmallVector<linalg::utils::IteratorType> matmulIteratorTypes = {
261+
linalg::utils::IteratorType::parallel, linalg::utils::IteratorType::parallel,
262+
linalg::utils::IteratorType::reduction};
263+
264+
auto fusedMatmulOp = rewriter.create<linalg::GenericOp>(
265+
loc,
266+
/*resultTensorTypes=*/outputType,
267+
/*inputs=*/ValueRange{inputA, inputB},
268+
/*outputs=*/broadcastedBiasResult,
269+
/*indexingMaps=*/matmulMaps,
270+
/*iteratorTypes=*/matmulIteratorTypes,
271+
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
272+
// args[0] = A[i,k], args[1] = B[j,k], args[2] = C[i,j] (init with bias)
273+
Value mulResult;
274+
if (isa<FloatType>(elementTy)) {
275+
mulResult = b.create<arith::MulFOp>(nestedLoc, args[0], args[1]);
276+
} else {
277+
mulResult = b.create<arith::MulIOp>(nestedLoc, args[0], args[1]);
278+
}
279+
Value addResult;
280+
if (isa<FloatType>(elementTy)) {
281+
addResult = b.create<arith::AddFOp>(nestedLoc, mulResult, args[2]);
282+
} else {
283+
addResult = b.create<arith::AddIOp>(nestedLoc, mulResult, args[2]);
284+
}
285+
b.create<linalg::YieldOp>(nestedLoc, addResult);
286+
});
287+
288+
LLVM_DEBUG(llvm::dbgs() << " Created fused matmul op:\n");
289+
LLVM_DEBUG(fusedMatmulOp.print(llvm::dbgs()));
290+
LLVM_DEBUG(llvm::dbgs() << "\n");
291+
292+
rewriter.replaceOp(addOp, fusedMatmulOp.getResults());
293+
294+
LLVM_DEBUG(llvm::dbgs() << "=== Rewrite successful ===\n\n");
295+
return success();
296+
}
297+
};
298+
299+
struct FuseFCPass : public PassWrapper<FuseFCPass, OperationPass<func::FuncOp>> {
300+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseFCPass)
301+
302+
StringRef getArgument() const final { return "fuse-fc"; }
303+
304+
StringRef getDescription() const final {
305+
return "Fuse linalg.transpose + linalg.matmul + bias-add into a "
306+
"sequence of linalg.generic ops.";
307+
}
308+
309+
void runOnOperation() override {
310+
LLVM_DEBUG(llvm::dbgs() << "\n\n====================================\n");
311+
LLVM_DEBUG(llvm::dbgs() << "Starting FuseFCPass\n");
312+
LLVM_DEBUG(llvm::dbgs() << "====================================\n\n");
313+
314+
MLIRContext *context = &getContext();
315+
RewritePatternSet patterns(context);
316+
patterns.add<FuseFCPattern>(context);
317+
318+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
319+
LLVM_DEBUG(llvm::dbgs() << "FuseFCPass: Pattern application failed\n");
320+
signalPassFailure();
321+
} else {
322+
LLVM_DEBUG(llvm::dbgs() << "\n====================================\n");
323+
LLVM_DEBUG(llvm::dbgs() << "FuseFCPass completed successfully\n");
324+
LLVM_DEBUG(llvm::dbgs() << "====================================\n\n");
325+
}
326+
}
327+
};
328+
329+
} // namespace
330+
331+
namespace mlir {
332+
namespace buddy {
333+
void registerFuseFCPass() {
334+
PassRegistration<FuseFCPass>();
335+
}
336+
} // namespace buddy
337+
} // namespace mlir

midend/lib/InitAll.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ void registerMatMulOptimizePass();
4545
void registerMatMulParallelVectorizationPass();
4646
void registerMatMulVectorizationPass();
4747
void registerTransposeOptimizationPass();
48+
void registerFuseFCPass();
4849
} // namespace buddy
4950
} // namespace mlir
5051

@@ -74,4 +75,5 @@ void mlir::buddy::registerAllPasses() {
7475
mlir::buddy::registerMatMulParallelVectorizationPass();
7576
mlir::buddy::registerMatMulVectorizationPass();
7677
mlir::buddy::registerTransposeOptimizationPass();
78+
mlir::buddy::registerFuseFCPass();
7779
}

tools/buddy-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ target_link_libraries(buddy-opt
4545
MatMulTransposeBVec
4646
LinalgToVIRPass
4747
VIRToVectorPass
48+
FullyConnectedFusion
4849
)

0 commit comments

Comments
 (0)