@@ -1,3 +1,24 @@
+//===- SiluOptimization.cpp - Silu Optimization Pass ----------------------===//
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the pass that vectorizes the linalg.generic ops
+// representing SiLU.
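+//
+// For reference: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), which is the
+// elementwise computation the vectorized loop nest below emits.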
+//
+//===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -13,8 +34,6 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 
-
-
 using namespace mlir;
 
 namespace {
@@ -35,40 +54,40 @@ class SiLUVectorizePattern : public ConversionPattern {
                   ConversionPatternRewriter &rewriter) const override {
 
     linalg::GenericOp sigmoidOp = cast<linalg::GenericOp>(op);
-
+
     // -------------------- Sigmoid op --------------------
-    // 1. Check input/output
-    if (sigmoidOp.getNumDpsInputs() != 1 || sigmoidOp.getNumDpsInits() != 1){
-      llvm::errs() << "1\n";
-      return failure(); }
+    // Check input/output.
+    if (sigmoidOp.getNumDpsInputs() != 1 || sigmoidOp.getNumDpsInits() != 1) {
+      return failure();
+    }
 
     // Check the body of the op for sigmoid computation.
     // The IR should be: negf, exp, addf, divf, yield.
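+    // For reference, the expected body is roughly the following sketch
+    // (%c1 is a constant 1.0 defined outside the region):
+    //   ^bb0(%in: f32, %out: f32):
+    //     %0 = arith.negf %in : f32
+    //     %1 = math.exp %0 : f32
+    //     %2 = arith.addf %1, %c1 : f32
+    //     %3 = arith.divf %c1, %2 : f32
+    //     linalg.yield %3 : f32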
     Block &block = sigmoidOp.getRegion().front();
-    if (block.getOperations().size() != 5) // negf, exp, addf, divf, yield
-    { llvm::errs() << "4\n";
-      return failure(); }
+    if (block.getOperations().size() != 5) { // negf, exp, addf, divf, yield
+      return failure();
+    }
 
     Operation &negfOp = block.getOperations().front();
     Operation &yieldOp = block.getOperations().back();
 
     // Check the type of the two operations.
-    if (!isa<arith::NegFOp>(negfOp) || !isa<linalg::YieldOp>(yieldOp))
-    { llvm::errs() << "5\n";
-      return failure(); }
-
+    if (!isa<arith::NegFOp>(negfOp) || !isa<linalg::YieldOp>(yieldOp)) {
+      return failure();
+    }
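+    // Note: only the first and last ops are type-checked; together with the
+    // size-5 test above this is a cheap structural match, not a full
+    // verification of the sigmoid body.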
 
     // ---------- Find the consumer mul operation. ----------
     // The result of the sigmoid op must be used by another linalg.generic op.
     Value outputBuffer = sigmoidOp.getDpsInitOperand(0)->get();
 
-    // 遍历所有 uses,寻找满足条件的 consumer op
+    // Iterate over all uses to find a suitable consumer op.
     linalg::GenericOp mulOp = nullptr;
 
     for (auto &use : outputBuffer.getUses()) {
       Operation *user = use.getOwner();
 
-      // 要求是 linalg.generic,且 %alloc 是 input operand(即 ins())
+      // It must be a linalg.generic, and the buffer must be an input operand
+      // (i.e., ins()).
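+      // e.g., a consumer of the form (operands illustrative):
+      //   linalg.generic ins(%x, %sigmoid_buf : ...) outs(%result : ...)
+      //     { ... arith.mulf ... }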
       auto linalgOp = dyn_cast<linalg::GenericOp>(user);
       if (!linalgOp)
         continue;
@@ -83,7 +102,7 @@ class SiLUVectorizePattern : public ConversionPattern {
       if (!foundInInput)
         continue;
 
-      // 检查其内部是否有 arith.mulf 操作
+      // Check if it contains an arith.mulf operation inside.
       for (auto &nestedOp : linalgOp.getRegion().front()) {
         if (isa<arith::MulFOp>(nestedOp)) {
           mulOp = linalgOp;
@@ -96,14 +115,14 @@ class SiLUVectorizePattern : public ConversionPattern {
     }
 
     if (!mulOp) {
-      llvm::errs() << "Didn't find a consumer linalg.generic using sigmoid output with mulf.\n";
+      llvm::errs() << "Didn't find a consumer linalg.generic using sigmoid "
+                      "output with mulf.\n";
       return failure();
     }
 
-    // Set the insertion point before the mulOp. This ensures that the new affine
-    // loop is inserted at a point that is dominated by the allocation of the
-    // output buffer.
-    // rewriter.setInsertionPoint(mulOp);
+    // Set the insertion point before the mulOp. This ensures that the new
+    // affine loop is inserted at a point that is dominated by the allocation
+    // of the output buffer.
+    // rewriter.setInsertionPoint(mulOp);
 
     // Now we have matched the silu pattern: sigmoid followed by a mul.
     // The rewrite logic will be applied to the sigmoidOp, and the mulOp will be
@@ -122,57 +141,46 @@ class SiLUVectorizePattern : public ConversionPattern {
 
     // Define constants.
     Value c0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-    Value c1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-    const int64_t unrollFactor = 2;
-    Value cUnrollVec =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize * unrollFactor);
-    Value cst1f = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getF32FloatAttr(1.0));
+    Value cst1f =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(1.0));
     Value vec1f = rewriter.create<vector::BroadcastOp>(loc, vectorType, cst1f);
+    Value cst0f =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getF32FloatAttr(0.0f));
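+    // cst0f is the padding value for the vector.transfer_read below: lanes
+    // that fall past the end of the innermost dimension read as 0.0f.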
 
     // Get dimensions.
     Value d0 = rewriter.create<memref::DimOp>(loc, input, 0);
     Value d1 = rewriter.create<memref::DimOp>(loc, input, 1);
     Value d2 = rewriter.create<memref::DimOp>(loc, input, 2);
 
     // Create loop nest.
-    scf::ForOp iLoop = rewriter.create<scf::ForOp>(loc, c0, d0, c1);
+    AffineMap map = rewriter.getDimIdentityMap();
+    affine::AffineForOp iLoop = rewriter.create<affine::AffineForOp>(
+        loc, ValueRange{c0}, map, ValueRange{d0}, map);
     rewriter.setInsertionPointToStart(iLoop.getBody());
     Value iv_i = iLoop.getInductionVar();
 
-    scf::ForOp jLoop = rewriter.create<scf::ForOp>(loc, c0, d1, c1);
+    affine::AffineForOp jLoop = rewriter.create<affine::AffineForOp>(
+        loc, ValueRange{c0}, map, ValueRange{d1}, map);
     rewriter.setInsertionPointToStart(jLoop.getBody());
     Value iv_j = jLoop.getInductionVar();
 
-    scf::ForOp kLoop = rewriter.create<scf::ForOp>(loc, c0, d2, cUnrollVec);
+    affine::AffineForOp kLoop = rewriter.create<affine::AffineForOp>(
+        loc, ValueRange{c0}, map, ValueRange{d2}, map, vectorSize);
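+    // The trailing builder argument is the loop step: the k loop advances by
+    // vectorSize per iteration, one vector per trip.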
     rewriter.setInsertionPointToStart(kLoop.getBody());
     Value iv_k = kLoop.getInductionVar();
 
-    // Prefetch
-    Value k_next = rewriter.create<arith::AddIOp>(loc, iv_k, cUnrollVec);
-    rewriter.create<memref::PrefetchOp>(loc, input, ValueRange{iv_i, iv_j, k_next},
-                                        /*isWrite=*/false, /*localityHint=*/3,
-                                        /*isDataCache=*/true);
-
-    // Unrolled loop body
-    for (int i = 0; i < unrollFactor; ++i) {
-      Value k_offset =
-          rewriter.create<arith::ConstantIndexOp>(loc, i * vectorSize);
-      Value k_i = rewriter.create<arith::AddIOp>(loc, iv_k, k_offset);
-
-      // --- Process Vector ---
-      Value x_vec = rewriter.create<vector::LoadOp>(
-          loc, vectorType, input, ValueRange{iv_i, iv_j, k_i});
-      Value neg_x_vec = rewriter.create<arith::NegFOp>(loc, x_vec);
-      Value exp_neg_x_vec = rewriter.create<math::ExpOp>(loc, neg_x_vec);
-      Value one_plus_exp_vec =
-          rewriter.create<arith::AddFOp>(loc, vec1f, exp_neg_x_vec);
-      Value sigmoid_x_vec =
-          rewriter.create<arith::DivFOp>(loc, vec1f, one_plus_exp_vec);
-      Value silu_vec = rewriter.create<arith::MulFOp>(loc, x_vec, sigmoid_x_vec);
-      rewriter.create<vector::StoreOp>(loc, silu_vec, output,
-                                       ValueRange{iv_i, iv_j, k_i});
-    }
+    // --- Process Vector ---
+    Value x_vec = rewriter.create<vector::TransferReadOp>(
+        loc, vectorType, input, ValueRange{iv_i, iv_j, iv_k}, cst0f);
+    Value neg_x_vec = rewriter.create<arith::NegFOp>(loc, x_vec);
+    Value exp_neg_x_vec = rewriter.create<math::ExpOp>(loc, neg_x_vec);
+    Value one_plus_exp_vec =
+        rewriter.create<arith::AddFOp>(loc, vec1f, exp_neg_x_vec);
+    Value sigmoid_x_vec =
+        rewriter.create<arith::DivFOp>(loc, vec1f, one_plus_exp_vec);
+    Value silu_vec = rewriter.create<arith::MulFOp>(loc, x_vec, sigmoid_x_vec);
+    rewriter.create<vector::TransferWriteOp>(loc, silu_vec, output,
+                                             ValueRange{iv_i, iv_j, iv_k});
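+    // Per vector slice this computes silu(x) = x * (1 / (1 + exp(-x)));
+    // transfer_read/transfer_write (with the cst0f padding) replace the
+    // earlier unroll-and-prefetch scheme and are intended to handle a
+    // trailing slice that is not a multiple of vectorSize.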
 
     // Replace the original mulOp with the result from our new computation.
     // The 'output' buffer now holds the final result. `replaceOp` will
@@ -214,36 +222,37 @@ class SiluOptimizationPass
   }
 
   Option<int64_t> vectorSize{*this, "vector-size",
-                            llvm::cl::desc("Vector size for SiLU."),
-                            llvm::cl::init(8)};
+                             llvm::cl::desc("Vector size for SiLU."),
+                             llvm::cl::init(8)};
 };
 
 void SiluOptimizationPass::runOnOperation() {
   MLIRContext *context = &getContext();
   ModuleOp module = getOperation();
 
   ConversionTarget target(*context);
-  target.addLegalDialect<arith::ArithDialect, affine::AffineDialect,
-                         memref::MemRefDialect, vector::VectorDialect,
-                         func::FuncDialect, math::MathDialect,
-                         scf::SCFDialect>();
+  target
+      .addLegalDialect<arith::ArithDialect, affine::AffineDialect,
+                       memref::MemRefDialect, vector::VectorDialect,
+                       func::FuncDialect, math::MathDialect, scf::SCFDialect>();
   target.addLegalOp<ModuleOp, func::FuncOp>();
-
-  // We will manually mark linalg.generic as illegal if it is part of a SiLU pattern.
-  // The pattern itself will handle the legality checks and replacements.
-  // Therefore, we don't need to addIllegalOp<linalg::GenericOp>() here.
-
+
+  // We will manually mark linalg.generic as illegal if it is part of a SiLU
+  // pattern. The pattern itself will handle the legality checks and
+  // replacements. Therefore, we don't need to addIllegalOp<linalg::GenericOp>()
+  // here.
+
   RewritePatternSet patterns(context);
   patterns.add<SiLUVectorizePattern>(context, vectorSize);
 
   if (failed(applyPartialConversion(module, target, std::move(patterns))))
     signalPassFailure();
 }
-} // end anonymous namespace
+} // end anonymous namespace
 namespace mlir {
 namespace buddy {
 void registerSiluOptimizationPass() {
   PassRegistration<SiluOptimizationPass>();
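+  // Hypothetical invocation sketch, assuming the pass registers under a
+  // "silu-optimization" flag and is linked into a buddy-opt-style driver:
+  //   buddy-opt input.mlir -silu-optimization="vector-size=16"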
247256}
248257} // namespace buddy
249- } // namespace mlir
258+ } // namespace mlir