Skip to content

Commit 79c0133

Browse files
committed
[midend] Add FuseFC pass
Fuse fully connected layer operations on linalg-on-tensor into two linalg.generic ops. These two operations can then be fused into one with affine-loop-fusion. Redundant operations need to be erased with the canonicalize pass. . . .
1 parent 6996a63 commit 79c0133

File tree

8 files changed

+352
-0
lines changed

8 files changed

+352
-0
lines changed

examples/BuddyDeepSeekR1/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ add_custom_command(
3333
${BUDDY_BINARY_DIR}/buddy-opt
3434
-eliminate-empty-tensors
3535
-empty-tensor-to-alloc-tensor
36+
-fuse-fc
37+
-canonicalize
3638
-one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS}
3739
-expand-strided-metadata
3840
-ownership-based-buffer-deallocation
@@ -78,6 +80,8 @@ add_custom_command(
7880
${BUDDY_BINARY_DIR}/buddy-opt
7981
-eliminate-empty-tensors
8082
-empty-tensor-to-alloc-tensor
83+
-fuse-fc
84+
-canonicalize
8185
-convert-elementwise-to-linalg
8286
-one-shot-bufferize=${BUFFERIZE_SIMPLE_OPTS}
8387
-expand-strided-metadata

midend/lib/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ set(LinkedLibs
2525
BatchMatMulOptimization
2626
MatMulParallelVectorization
2727
TransposeOptimization
28+
FullyConnectedFusion
2829
)
2930

3031

midend/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ add_subdirectory(MLIRGPU)
1717
add_subdirectory(VIRToVector)
1818
add_subdirectory(LinalgToVIR)
1919
add_subdirectory(GraphRedundancyElimination)
20+
add_subdirectory(FullyConnectedFusion)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Declares the FullyConnectedFusion MLIR library, built from FuseFC.cpp and
# publicly linked against BuddyUtils.
add_mlir_library(FullyConnectedFusion
2+
FuseFC.cpp
3+
LINK_LIBS PUBLIC
4+
BuddyUtils
5+
)
Lines changed: 336 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,336 @@
1+
//===- FuseFC.cpp - Fully Connected Fusion Pass ------------------------===//
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
//
15+
//===----------------------------------------------------------------------===//
16+
// This pass is going to deal with pattern like:
17+
// %collapsed_29 = tensor.collapse_shape %46 [[0, 1], [2]] : tensor<1x1024x1536xf32> into tensor<1024x1536xf32>
18+
// %51 = bufferization.alloc_tensor() : tensor<1536x256xf32>
19+
// %transposed_30 = linalg.transpose ins(%arg7 : tensor<256x1536xf32>) outs(%51 : tensor<1536x256xf32>) permutation = [1, 0]
20+
// %52 = linalg.matmul {cast = #linalg.type_fn<cast_signed>} ins(%collapsed_29, %transposed_30 : tensor<1024x1536xf32>, tensor<1536x256xf32>) outs(%cst_6 : tensor<1024x256xf32>) -> tensor<1024x256xf32>
21+
// %expanded_31 = tensor.expand_shape %arg8 [[0, 1]] output_shape [1, 256] : tensor<256xf32> into tensor<1x256xf32>
22+
// %53 = bufferization.alloc_tensor() : tensor<1024x256xf32>
23+
// %54 = linalg.generic {indexing_maps = [#map10, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%expanded_31, %52 : tensor<1x256xf32>, tensor<1024x256xf32>) outs(%53 : tensor<1024x256xf32>) {
24+
// ^bb0(%in: f32, %in_1538: f32, %out: f32):
25+
// %3044 = arith.addf %in, %in_1538 : f32
26+
// linalg.yield %3044 : f32
27+
// } -> tensor<1024x256xf32>
28+
// This pass transforms it into two linalg.generic operations so that they can be fused later at the affine level.
29+
// Redundant operations need to be cleaned up with the canonicalization pass.
30+
//===----------------------------------------------------------------------===//
31+
32+
#include "mlir/Dialect/Arith/IR/Arith.h"
33+
#include "mlir/Dialect/Func/IR/FuncOps.h"
34+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
35+
#include "mlir/Dialect/Linalg/Utils/Utils.h"
36+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
37+
#include "mlir/IR/AffineMap.h"
38+
#include "mlir/IR/Builders.h"
39+
#include "mlir/IR/MLIRContext.h"
40+
#include "mlir/IR/PatternMatch.h"
41+
#include "mlir/Pass/Pass.h"
42+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
43+
#include "llvm/Support/Debug.h"
44+
45+
#define DEBUG_TYPE "fuse-fc"
46+
47+
using namespace mlir;
48+
49+
namespace {
50+
51+
// Helper function to check if a linalg.generic op is elementwise
52+
static bool isElementwise(linalg::GenericOp op) {
53+
return llvm::all_of(op.getIteratorTypesArray(),
54+
[](utils::IteratorType type) {
55+
return type == utils::IteratorType::parallel;
56+
});
57+
}
58+
59+
class FuseFCPattern : public OpRewritePattern<linalg::GenericOp> {
60+
public:
61+
using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
62+
63+
LogicalResult matchAndRewrite(linalg::GenericOp addOp,
64+
PatternRewriter &rewriter) const override {
65+
LLVM_DEBUG(llvm::dbgs() << "\n=== FuseFCPattern: Checking GenericOp ===\n");
66+
LLVM_DEBUG(addOp.print(llvm::dbgs()));
67+
LLVM_DEBUG(llvm::dbgs() << "\n");
68+
69+
// 1. Check if the current op is a bias-add (elementwise with 2 loops)
70+
if (!isElementwise(addOp)) {
71+
LLVM_DEBUG(llvm::dbgs() << " -> Not elementwise, skipping\n");
72+
return failure();
73+
}
74+
75+
if (addOp.getNumLoops() != 2) {
76+
LLVM_DEBUG(llvm::dbgs() << " -> Number of loops != 2 (got "
77+
<< addOp.getNumLoops() << "), skipping\n");
78+
return failure();
79+
}
80+
81+
if (addOp.getNumDpsInputs() != 2 || addOp.getNumDpsInits() != 1) {
82+
LLVM_DEBUG(llvm::dbgs() << " -> Wrong number of inputs/outputs (inputs="
83+
<< addOp.getNumDpsInputs() << ", outputs="
84+
<< addOp.getNumDpsInits() << "), skipping\n");
85+
return failure();
86+
}
87+
88+
// Check for add operation in the body
89+
auto &body = addOp.getRegion().front();
90+
bool hasAdd = !body.getOps<arith::AddFOp>().empty() ||
91+
!body.getOps<arith::AddIOp>().empty();
92+
if (!hasAdd) {
93+
LLVM_DEBUG(llvm::dbgs() << " -> Body doesn't contain add operation, skipping\n");
94+
return failure();
95+
}
96+
97+
LLVM_DEBUG(llvm::dbgs() << " -> Confirmed as elementwise add operation\n");
98+
99+
// 2. Find which operand is the matmul result and which is the bias
100+
Value matmulResult, bias;
101+
linalg::MatmulOp matmulOp;
102+
int matmulIdx = -1;
103+
104+
for (int i = 0; i < 2; ++i) {
105+
auto input = addOp.getDpsInputOperand(i)->get();
106+
LLVM_DEBUG(llvm::dbgs() << " -> Checking input " << i << ": ");
107+
LLVM_DEBUG(input.print(llvm::dbgs()));
108+
LLVM_DEBUG(llvm::dbgs() << "\n");
109+
110+
if (auto defOp = input.getDefiningOp<linalg::MatmulOp>()) {
111+
LLVM_DEBUG(llvm::dbgs() << " Found MatmulOp at index " << i << "\n");
112+
matmulOp = defOp;
113+
matmulResult = input;
114+
matmulIdx = i;
115+
bias = addOp.getDpsInputOperand(1 - i)->get();
116+
break;
117+
}
118+
}
119+
120+
if (!matmulOp) {
121+
LLVM_DEBUG(llvm::dbgs() << " -> No MatmulOp input found, skipping\n");
122+
return failure();
123+
}
124+
125+
LLVM_DEBUG(llvm::dbgs() << " -> MatmulOp found:\n");
126+
LLVM_DEBUG(matmulOp.print(llvm::dbgs()));
127+
LLVM_DEBUG(llvm::dbgs() << "\n");
128+
129+
// 3. The second operand of the matmul must be the result of a transpose
130+
auto transposeOp =
131+
matmulOp.getDpsInputOperand(1)->get().getDefiningOp<linalg::TransposeOp>();
132+
if (!transposeOp) {
133+
LLVM_DEBUG(llvm::dbgs() << " -> Second matmul operand is not TransposeOp, skipping\n");
134+
return failure();
135+
}
136+
137+
LLVM_DEBUG(llvm::dbgs() << " -> TransposeOp found:\n");
138+
LLVM_DEBUG(transposeOp.print(llvm::dbgs()));
139+
LLVM_DEBUG(llvm::dbgs() << "\n");
140+
141+
// Check transpose permutation is [1, 0]
142+
auto perm = transposeOp.getPermutation();
143+
if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0) {
144+
LLVM_DEBUG(llvm::dbgs() << " -> Transpose permutation is not [1, 0], skipping\n");
145+
return failure();
146+
}
147+
148+
// 4. The bias must be broadcastable
149+
LLVM_DEBUG(llvm::dbgs() << " -> Checking bias: ");
150+
LLVM_DEBUG(bias.print(llvm::dbgs()));
151+
LLVM_DEBUG(llvm::dbgs() << "\n");
152+
153+
// Trace through expand_shape if present
154+
Value originalBias = bias;
155+
if (auto expandOp = bias.getDefiningOp<tensor::ExpandShapeOp>()) {
156+
LLVM_DEBUG(llvm::dbgs() << " Bias comes from expand_shape, using source\n");
157+
originalBias = expandOp.getSrc();
158+
}
159+
160+
auto biasType = mlir::dyn_cast<RankedTensorType>(originalBias.getType());
161+
if (!biasType) {
162+
LLVM_DEBUG(llvm::dbgs() << " -> Bias is not RankedTensorType, skipping\n");
163+
return failure();
164+
}
165+
166+
LLVM_DEBUG(llvm::dbgs() << " Bias type: ");
167+
LLVM_DEBUG(biasType.print(llvm::dbgs()));
168+
LLVM_DEBUG(llvm::dbgs() << "\n");
169+
170+
// Bias must be rank 1 or rank 2 with one dimension being 1
171+
if (biasType.getRank() != 1 && biasType.getRank() != 2) {
172+
LLVM_DEBUG(llvm::dbgs() << " -> Bias rank is " << biasType.getRank()
173+
<< " (expected 1 or 2), skipping\n");
174+
return failure();
175+
}
176+
177+
if (biasType.getRank() == 2 && biasType.getShape()[0] != 1 &&
178+
biasType.getShape()[1] != 1) {
179+
LLVM_DEBUG(llvm::dbgs() << " -> 2D bias doesn't have dimension of size 1, skipping\n");
180+
return failure();
181+
}
182+
183+
LLVM_DEBUG(llvm::dbgs() << "\n=== Pattern matched! Starting rewrite ===\n");
184+
185+
Location loc = addOp.getLoc();
186+
MLIRContext *ctx = rewriter.getContext();
187+
188+
// Get the operands from the original operations
189+
Value inputA = matmulOp.getDpsInputOperand(0)->get();
190+
Value inputB = transposeOp.getDpsInputOperand(0)->get(); // Before transpose
191+
Value output = addOp.getDpsInitOperand(0)->get();
192+
193+
auto outputType = mlir::cast<RankedTensorType>(output.getType());
194+
Type elementTy = outputType.getElementType();
195+
196+
LLVM_DEBUG(llvm::dbgs() << " Input A type: " << inputA.getType() << "\n");
197+
LLVM_DEBUG(llvm::dbgs() << " Input B type: " << inputB.getType() << "\n");
198+
LLVM_DEBUG(llvm::dbgs() << " Bias type: " << originalBias.getType() << "\n");
199+
LLVM_DEBUG(llvm::dbgs() << " Output type: " << outputType << "\n");
200+
201+
// --- Part 1: Create the bias broadcast operation ---
202+
LLVM_DEBUG(llvm::dbgs() << "\n--- Creating bias broadcast operation ---\n");
203+
204+
SmallVector<AffineMap> biasMaps;
205+
if (biasType.getRank() == 1) {
206+
// from tensor<N> to tensor<M, N>
207+
biasMaps.push_back(AffineMap::get(2, 0, {rewriter.getAffineDimExpr(1)}, ctx));
208+
LLVM_DEBUG(llvm::dbgs() << " Broadcast 1D bias along first dimension\n");
209+
} else { // Rank 2
210+
if (biasType.getShape()[0] == 1) {
211+
// tensor<1, N> -> broadcast along dim 0
212+
biasMaps.push_back(AffineMap::get(2, 0,
213+
{rewriter.getAffineConstantExpr(0), rewriter.getAffineDimExpr(1)}, ctx));
214+
LLVM_DEBUG(llvm::dbgs() << " Broadcast 2D bias (1xN) along first dimension\n");
215+
} else {
216+
// tensor<N, 1> -> broadcast along dim 1
217+
biasMaps.push_back(AffineMap::get(2, 0,
218+
{rewriter.getAffineDimExpr(0), rewriter.getAffineConstantExpr(0)}, ctx));
219+
LLVM_DEBUG(llvm::dbgs() << " Broadcast 2D bias (Nx1) along second dimension\n");
220+
}
221+
}
222+
biasMaps.push_back(rewriter.getMultiDimIdentityMap(2));
223+
224+
SmallVector<utils::IteratorType> biasIteratorTypes = {
225+
utils::IteratorType::parallel, utils::IteratorType::parallel};
226+
227+
auto broadcastedBiasOp = rewriter.create<linalg::GenericOp>(
228+
loc,
229+
/*resultTensorTypes=*/outputType,
230+
/*inputs=*/originalBias,
231+
/*outputs=*/output,
232+
/*indexingMaps=*/biasMaps,
233+
/*iteratorTypes=*/biasIteratorTypes,
234+
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
235+
b.create<linalg::YieldOp>(nestedLoc, args[0]);
236+
});
237+
Value broadcastedBiasResult = broadcastedBiasOp.getResult(0);
238+
239+
LLVM_DEBUG(llvm::dbgs() << " Created bias broadcast op:\n");
240+
LLVM_DEBUG(broadcastedBiasOp.print(llvm::dbgs()));
241+
LLVM_DEBUG(llvm::dbgs() << "\n");
242+
243+
// --- Part 2: Create the matmul with folded transpose ---
244+
LLVM_DEBUG(llvm::dbgs() << "\n--- Creating fused matmul operation ---\n");
245+
246+
// A[i,k] * B[j,k] -> C[i,j] (with B accessed as transpose)
247+
AffineMap mapA = AffineMap::get(3, 0,
248+
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(2)}, ctx);
249+
AffineMap mapB = AffineMap::get(3, 0,
250+
{rewriter.getAffineDimExpr(1), rewriter.getAffineDimExpr(2)}, ctx);
251+
AffineMap mapC = AffineMap::get(3, 0,
252+
{rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1)}, ctx);
253+
254+
LLVM_DEBUG(llvm::dbgs() << " Map A (input): " << mapA << "\n");
255+
LLVM_DEBUG(llvm::dbgs() << " Map B (transposed): " << mapB << "\n");
256+
LLVM_DEBUG(llvm::dbgs() << " Map C (output): " << mapC << "\n");
257+
258+
SmallVector<AffineMap> matmulMaps = {mapA, mapB, mapC};
259+
SmallVector<utils::IteratorType> matmulIteratorTypes = {
260+
utils::IteratorType::parallel, utils::IteratorType::parallel,
261+
utils::IteratorType::reduction};
262+
263+
auto fusedMatmulOp = rewriter.create<linalg::GenericOp>(
264+
loc,
265+
/*resultTensorTypes=*/outputType,
266+
/*inputs=*/ValueRange{inputA, inputB},
267+
/*outputs=*/broadcastedBiasResult,
268+
/*indexingMaps=*/matmulMaps,
269+
/*iteratorTypes=*/matmulIteratorTypes,
270+
[&](OpBuilder &b, Location nestedLoc, ValueRange args) {
271+
// args[0] = A[i,k], args[1] = B[j,k], args[2] = C[i,j] (init with bias)
272+
Value mulResult;
273+
if (isa<FloatType>(elementTy)) {
274+
mulResult = b.create<arith::MulFOp>(nestedLoc, args[0], args[1]);
275+
} else {
276+
mulResult = b.create<arith::MulIOp>(nestedLoc, args[0], args[1]);
277+
}
278+
Value addResult;
279+
if (isa<FloatType>(elementTy)) {
280+
addResult = b.create<arith::AddFOp>(nestedLoc, mulResult, args[2]);
281+
} else {
282+
addResult = b.create<arith::AddIOp>(nestedLoc, mulResult, args[2]);
283+
}
284+
b.create<linalg::YieldOp>(nestedLoc, addResult);
285+
});
286+
287+
LLVM_DEBUG(llvm::dbgs() << " Created fused matmul op:\n");
288+
LLVM_DEBUG(fusedMatmulOp.print(llvm::dbgs()));
289+
LLVM_DEBUG(llvm::dbgs() << "\n");
290+
291+
rewriter.replaceOp(addOp, fusedMatmulOp.getResults());
292+
293+
LLVM_DEBUG(llvm::dbgs() << "=== Rewrite successful ===\n\n");
294+
return success();
295+
}
296+
};
297+
298+
/// Function-level pass, registered as `-fuse-fc`, that greedily applies
/// FuseFCPattern to the `func.func` it runs on. The pass is marked as failed
/// when the greedy driver reports failure.
struct FuseFCPass
    : public PassWrapper<FuseFCPass, OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(FuseFCPass)

  /// Command-line flag used to select this pass.
  StringRef getArgument() const final { return "fuse-fc"; }

  /// One-line summary of the pass.
  StringRef getDescription() const final {
    return "Fuse linalg.transpose + linalg.matmul + bias-add into a "
           "sequence of linalg.generic ops.";
  }

  void runOnOperation() override {
    LLVM_DEBUG(llvm::dbgs() << "\n\n====================================\n");
    LLVM_DEBUG(llvm::dbgs() << "Starting FuseFCPass\n");
    LLVM_DEBUG(llvm::dbgs() << "====================================\n\n");

    MLIRContext &ctx = getContext();
    RewritePatternSet fusionPatterns(&ctx);
    fusionPatterns.add<FuseFCPattern>(&ctx);

    // Bail out early on driver failure so the success banner below is only
    // printed for a clean run.
    const bool driverSucceeded = succeeded(
        applyPatternsGreedily(getOperation(), std::move(fusionPatterns)));
    if (!driverSucceeded) {
      LLVM_DEBUG(llvm::dbgs() << "FuseFCPass: Pattern application failed\n");
      signalPassFailure();
      return;
    }

    LLVM_DEBUG(llvm::dbgs() << "\n====================================\n");
    LLVM_DEBUG(llvm::dbgs() << "FuseFCPass completed successfully\n");
    LLVM_DEBUG(llvm::dbgs() << "====================================\n\n");
  }
};
327+
328+
} // namespace
329+
330+
namespace mlir {
331+
namespace buddy {
332+
void registerFuseFCPass() {
333+
PassRegistration<FuseFCPass>();
334+
}
335+
} // namespace buddy
336+
} // namespace mlir

midend/lib/InitAll.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ void registerMatMulOptimizePass();
4545
void registerMatMulParallelVectorizationPass();
4646
void registerMatMulVectorizationPass();
4747
void registerTransposeOptimizationPass();
48+
void registerFuseFCPass();
4849
} // namespace buddy
4950
} // namespace mlir
5051

@@ -74,4 +75,5 @@ void mlir::buddy::registerAllPasses() {
7475
mlir::buddy::registerMatMulParallelVectorizationPass();
7576
mlir::buddy::registerMatMulVectorizationPass();
7677
mlir::buddy::registerTransposeOptimizationPass();
78+
mlir::buddy::registerFuseFCPass();
7779
}

tools/buddy-opt/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,5 @@ target_link_libraries(buddy-opt
4545
MatMulTransposeBVec
4646
LinalgToVIRPass
4747
VIRToVectorPass
48+
FullyConnectedFusion
4849
)

0 commit comments

Comments
 (0)