Commit a42eded

Add Silu optimization and test
1 parent 366073f commit a42eded

File tree

3 files changed: +134 -70 lines changed


midend/lib/Conversion/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -16,5 +16,4 @@ add_subdirectory(LowerSche)
 add_subdirectory(FuncBufferize)
 add_subdirectory(DepthwiseConvOptimization)
 add_subdirectory(MLIRGPU)
-add_subdirectory(MLIRGPU)
 add_subdirectory(SiluOptimization)
midend/lib/Conversion/SiluOptimization/SiluOptimization.cpp

Lines changed: 78 additions & 69 deletions
@@ -1,3 +1,24 @@
+//====- SiluOptimization.cpp - Silu Optimization Pass ---------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the pass that vectorizes the linalg.generic representing
+// SiLU.
+//
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -13,8 +34,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-
-
 using namespace mlir;
 
 namespace {
@@ -35,40 +54,40 @@ class SiLUVectorizePattern : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
 
     linalg::GenericOp sigmoidOp = cast<linalg::GenericOp>(op);
-
+
     //--------------sigmoid OP--------
-    // 1. Check input/output
-    if (sigmoidOp.getNumDpsInputs() != 1 || sigmoidOp.getNumDpsInits() != 1){
-      llvm::errs() << "1\n";
-      return failure();}
+    // Check input/output
+    if (sigmoidOp.getNumDpsInputs() != 1 || sigmoidOp.getNumDpsInits() != 1) {
+      return failure();
+    }
 
     // Check the body of the op for sigmoid computation.
     // The IR should be: negf, exp, addf, divf, yield.
     Block &block = sigmoidOp.getRegion().front();
-    if (block.getOperations().size() != 5) // negf, exp, addf, divf, yield
-    {llvm::errs() << "4\n";
-    return failure();}
+    if (block.getOperations().size() != 5) { // negf, exp, addf, divf, yield
+      return failure();
+    }
 
     Operation &negfOp = block.getOperations().front();
     Operation &yieldOp = block.getOperations().back();
 
     // Check the type of the two operations.
-    if (!isa<arith::NegFOp>(negfOp) || !isa<linalg::YieldOp>(yieldOp))
-    {llvm::errs() << "5\n";
-    return failure();}
-
+    if (!isa<arith::NegFOp>(negfOp) || !isa<linalg::YieldOp>(yieldOp)) {
+      return failure();
+    }
 
     //-----------Find the consumer mul operation.------------------------------
     // The result of the sigmoid op must be used by another linalg.generic op.
    Value outputBuffer = sigmoidOp.getDpsInitOperand(0)->get();
 
-    // 遍历所有 uses,寻找满足条件的 consumer op
+    // Iterate over all uses to find a suitable consumer op.
     linalg::GenericOp mulOp = nullptr;
 
     for (auto &use : outputBuffer.getUses()) {
       Operation *user = use.getOwner();
 
-      // 要求是 linalg.generic,且 %alloc 是 input operand(即 ins())
+      // It must be a linalg.generic, and the buffer must be an input operand
+      // (i.e., ins()).
       auto linalgOp = dyn_cast<linalg::GenericOp>(user);
       if (!linalgOp)
         continue;
@@ -83,7 +102,7 @@ class SiLUVectorizePattern : public ConversionPattern {
       if (!foundInInput)
         continue;
 
-      // 检查其内部是否有 arith.mulf 操作
+      // Check if it contains an arith.mulf operation inside.
      for (auto &nestedOp : linalgOp.getRegion().front()) {
         if (isa<arith::MulFOp>(nestedOp)) {
           mulOp = linalgOp;
@@ -96,14 +115,14 @@ class SiLUVectorizePattern : public ConversionPattern {
     }
 
     if (!mulOp) {
-      llvm::errs() << "Didn't find a consumer linalg.generic using sigmoid output with mulf.\n";
+      llvm::errs() << "Didn't find a consumer linalg.generic using sigmoid "
+                      "output with mulf.\n";
       return failure();
     }
 
-    // Set the insertion point before the mulOp. This ensures that the new affine
-    // loop is inserted at a point that is dominated by the allocation of the
-    // output buffer.
-    // rewriter.setInsertionPoint(mulOp);
+    // Set the insertion point before the mulOp. This ensures that the new
+    // affine loop is inserted at a point that is dominated by the allocation of
+    // the output buffer. rewriter.setInsertionPoint(mulOp);
 
     // Now we have matched the silu pattern: sigmoid followed by a mul.
     // The rewrite logic will be applied to the sigmoidOp, and the mulOp will be
@@ -122,57 +141,46 @@ class SiLUVectorizePattern : public ConversionPattern {
 
     // Define constants.
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    const int64_t unrollFactor = 2;
-    Value cUnrollVec =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize * unrollFactor);
-    Value cst1f = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getF32FloatAttr(1.0));
+    Value cst1f =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(1.0));
     Value vec1f = rewriter.create<vector::BroadcastOp>(loc, vectorType, cst1f);
+    Value cst0f =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
 
     // Get dimensions.
     Value d0 = rewriter.create<memref::DimOp>(loc, input, 0);
     Value d1 = rewriter.create<memref::DimOp>(loc, input, 1);
     Value d2 = rewriter.create<memref::DimOp>(loc, input, 2);
 
     // Create loop nest.
-    scf::ForOp iLoop = rewriter.create<scf::ForOp>(loc, c0, d0, c1);
+    AffineMap map = rewriter.getDimIdentityMap();
+    affine::AffineForOp iLoop = rewriter.create<affine::AffineForOp>(
+        loc, ValueRange{c0}, map, ValueRange{d0}, map);
     rewriter.setInsertionPointToStart(iLoop.getBody());
     Value iv_i = iLoop.getInductionVar();
 
-    scf::ForOp jLoop = rewriter.create<scf::ForOp>(loc, c0, d1, c1);
+    affine::AffineForOp jLoop = rewriter.create<affine::AffineForOp>(
+        loc, ValueRange{c0}, map, ValueRange{d1}, map);
     rewriter.setInsertionPointToStart(jLoop.getBody());
     Value iv_j = jLoop.getInductionVar();
 
-    scf::ForOp kLoop = rewriter.create<scf::ForOp>(loc, c0, d2, cUnrollVec);
+    affine::AffineForOp kLoop = rewriter.create<affine::AffineForOp>(
+        loc, ValueRange{c0}, map, ValueRange{d2}, map, vectorSize);
     rewriter.setInsertionPointToStart(kLoop.getBody());
     Value iv_k = kLoop.getInductionVar();
 
-    // Prefetch
-    Value k_next = rewriter.create<arith::AddIOp>(loc, iv_k, cUnrollVec);
-    rewriter.create<memref::PrefetchOp>(loc, input, ValueRange{iv_i, iv_j, k_next},
-                                        /*isWrite=*/false, /*localityHint=*/3,
-                                        /*isDataCache=*/true);
-
-    // Unrolled loop body
-    for (int i = 0; i < unrollFactor; ++i) {
-      Value k_offset =
-          rewriter.create<arith::ConstantIndexOp>(loc, i * vectorSize);
-      Value k_i = rewriter.create<arith::AddIOp>(loc, iv_k, k_offset);
-
-      // --- Process Vector ---
-      Value x_vec = rewriter.create<vector::LoadOp>(
-          loc, vectorType, input, ValueRange{iv_i, iv_j, k_i});
-      Value neg_x_vec = rewriter.create<arith::NegFOp>(loc, x_vec);
-      Value exp_neg_x_vec = rewriter.create<math::ExpOp>(loc, neg_x_vec);
-      Value one_plus_exp_vec =
-          rewriter.create<arith::AddFOp>(loc, vec1f, exp_neg_x_vec);
-      Value sigmoid_x_vec =
-          rewriter.create<arith::DivFOp>(loc, vec1f, one_plus_exp_vec);
-      Value silu_vec = rewriter.create<arith::MulFOp>(loc, x_vec, sigmoid_x_vec);
-      rewriter.create<vector::StoreOp>(loc, silu_vec, output,
-                                       ValueRange{iv_i, iv_j, k_i});
-    }
+    // --- Process Vector ---
+    Value x_vec = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType, input, ValueRange{iv_i, iv_j, iv_k}, cst0f);
+    Value neg_x_vec = rewriter.create<arith::NegFOp>(loc, x_vec);
+    Value exp_neg_x_vec = rewriter.create<math::ExpOp>(loc, neg_x_vec);
+    Value one_plus_exp_vec =
+        rewriter.create<arith::AddFOp>(loc, vec1f, exp_neg_x_vec);
+    Value sigmoid_x_vec =
+        rewriter.create<arith::DivFOp>(loc, vec1f, one_plus_exp_vec);
+    Value silu_vec = rewriter.create<arith::MulFOp>(loc, x_vec, sigmoid_x_vec);
+    rewriter.create<vector::TransferWriteOp>(loc, silu_vec, output,
+                                             ValueRange{iv_i, iv_j, iv_k});
 
     // Replace the original mulOp with the result from our new computation.
     // The 'output' buffer now holds the final result. `replaceOp` will
@@ -214,36 +222,37 @@ class SiluOptimizationPass
   }
 
   Option<int64_t> vectorSize{*this, "vector-size",
-                            llvm::cl::desc("Vector size for SiLU."),
-                            llvm::cl::init(8)};
+                             llvm::cl::desc("Vector size for SiLU."),
+                             llvm::cl::init(8)};
 };
 
 void SiluOptimizationPass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
   ConversionTarget target(*context);
-  target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
-                         memref::MemRefDialect, vector::VectorDialect,
-                         func::FuncDialect, math::MathDialect,
-                         scf::SCFDialect>();
+  target
+      .addLegalDialect<arith::ArithDialect, affine::AffineDialect,
+                       memref::MemRefDialect, vector::VectorDialect,
+                       func::FuncDialect, math::MathDialect, scf::SCFDialect>();
   target.addLegalOp<ModuleOp, func::FuncOp>();
-
-  // We will manually mark linalg.generic as illegal if it is part of a SiLU pattern.
-  // The pattern itself will handle the legality checks and replacements.
-  // Therefore, we don't need to addIllegalOp<linalg::GenericOp>() here.
-
+
+  // We will manually mark linalg.generic as illegal if it is part of a SiLU
+  // pattern. The pattern itself will handle the legality checks and
+  // replacements. Therefore, we don't need to addIllegalOp<linalg::GenericOp>()
+  // here.
+
  RewritePatternSet patterns(context);
  patterns.add<SiLUVectorizePattern>(context, vectorSize);
 
  if (failed(applyPartialConversion(module, target, std::move(patterns))))
    signalPassFailure();
 }
-} // end anonymous namespace
+} // end anonymous namespace
 namespace mlir {
 namespace buddy {
 void registerSiluOptimizationPass() {
   PassRegistration<SiluOptimizationPass>();
 }
 } // namespace buddy
-} // namespace mlir
+} // namespace mlir
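For reference, the pattern matched above (negf, exp, addf, divf, linalg.yield, followed by a consumer linalg.generic containing arith.mulf) and the vector body emitted in its place both evaluate the SiLU activation; the identity the pass relies on is

    sigmoid(x) = 1 / (1 + exp(-x))
    SiLU(x)    = x * sigmoid(x) = x / (1 + exp(-x))

Inside the innermost affine loop, the transfer_read / negf / exp / addf / divf / mulf / transfer_write chain is the vectorSize-wide form of exactly this expression: vec1f supplies the broadcast constant 1.0, and cst0f is the padding value used by vector.transfer_read for out-of-bounds lanes.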
Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+// RUN: buddy-opt -silu-optimization="vector-size=8" %s | FileCheck %s
+
+// CHECK: #map = affine_map<(d0) -> (d0)>
+// CHECK: module {
+// CHECK: func.func @silu_tosa(%arg0: memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>) -> memref<1x40x8960xf32> {
+// CHECK:%cst = arith.constant 1.000000e+00 : f32
+// CHECK: %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32>
+// CHECK: %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32>
+// CHECK: %c0 = arith.constant 0 : index
+// CHECK: %cst_1 = arith.constant 1.000000e+00 : f32
+// CHECK: %0 = vector.broadcast %cst_1 : f32 to vector<8xf32>
+// CHECK: %cst_2 = arith.constant 0.000000e+00 : f32
+// CHECK: %c0_3 = arith.constant 0 : index
+// CHECK: %dim = memref.dim %arg0, %c0_3 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %c1 = arith.constant 1 : index
+// CHECK: %dim_4 = memref.dim %arg0, %c1 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: %c2 = arith.constant 2 : index
+// CHECK: %dim_5 = memref.dim %arg0, %c2 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>
+// CHECK: affine.for %arg1 = #map(%c0) to #map(%dim) {
+// CHECK: affine.for %arg2 = #map(%c0) to #map(%dim_4) {
+// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim_5) step 8 {
+// CHECK: %1 = vector.transfer_read %arg0[%arg1, %arg2, %arg3], %cst_2 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>, vector<8xf32>
+// CHECK: %2 = arith.negf %1 : vector<8xf32>
+// CHECK: %3 = math.exp %2 : vector<8xf32>
+// CHECK: %4 = arith.addf %0, %3 : vector<8xf32>
+// CHECK: %5 = arith.divf %0, %4 : vector<8xf32>
+// CHECK: %6 = arith.mulf %1, %5 : vector<8xf32>
+// CHECK: vector.transfer_write %6, %alloc_0[%arg1, %arg2, %arg3] : vector<8xf32>, memref<1x40x8960xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: return %alloc_0 : memref<1x40x8960xf32>
+// CHECK: }
+// CHECK: }
+
+#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func.func @silu_tosa(%arg0: memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>) -> memref<1x40x8960xf32> {
+  %cst = arith.constant 1.000000e+00 : f32
+  %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32>
+  linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>) outs(%alloc : memref<1x40x8960xf32>) {
+  ^bb0(%in: f32, %out: f32):
+    %3 = arith.negf %in : f32
+    %4 = math.exp %3 : f32
+    %5 = arith.addf %4, %cst : f32
+    %6 = arith.divf %cst, %5 : f32
+    linalg.yield %6 : f32
+  }
+  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32>
+  linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %alloc : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>, memref<1x40x8960xf32>) outs(%alloc_0 : memref<1x40x8960xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %3 = arith.mulf %in, %in_1 : f32
+    linalg.yield %3 : f32
+  }
+  return %alloc_0 : memref<1x40x8960xf32>
+}
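Under lit, the %s in the RUN line above expands to the path of this test file, so the check can also be driven by hand. A minimal sketch, with <silu-test>.mlir standing in for the test's path (the filename is not shown in this diff view):

    buddy-opt -silu-optimization="vector-size=8" <silu-test>.mlir | FileCheck <silu-test>.mlir

FileCheck then verifies that the pass rewrites the two linalg.generic ops into the affine loop nest of 8-wide vector ops listed in the CHECK lines.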
