diff --git a/examples/BuddyNext/makefile b/examples/BuddyNext/makefile index 750bd93f85..641f4c62e8 100644 --- a/examples/BuddyNext/makefile +++ b/examples/BuddyNext/makefile @@ -16,6 +16,52 @@ MLIR_RUNNER_UTILS := ../../llvm/build/lib/libmlir_runner_utils.dylib MLIR_C_RUNNER_UTILS := ../../llvm/build/lib/libmlir_c_runner_utils.dylib MTRIPLE := x86_64-apple-darwin endif +next-silu-run: + @${MLIR_OPT} ./next-silu.mlir \ + -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-tensor,tosa-to-arith))" | \ + ${MLIR_OPT} \ + -arith-expand \ + -eliminate-empty-tensors \ + -empty-tensor-to-alloc-tensor \ + -one-shot-bufferize="bufferize-function-boundaries" \ + -convert-linalg-to-affine-loops \ + -affine-loop-fusion \ + -lower-affine \ + -convert-vector-to-scf \ + -expand-strided-metadata \ + -convert-vector-to-llvm \ + -memref-expand \ + -arith-expand \ + -convert-arith-to-llvm \ + -finalize-memref-to-llvm \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-openmp-to-llvm \ + -convert-arith-to-llvm \ + -convert-math-to-llvm \ + -convert-math-to-libm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} + +next-silu-silu-run: + @${MLIR_OPT} ./next-silu-silu.mlir \ + -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-tensor,tosa-to-arith))" | \ + ${MLIR_OPT} \ + -convert-linalg-to-loops \ + -lower-affine \ + -convert-vector-to-scf \ + -convert-scf-to-cf \ + -convert-cf-to-llvm \ + -convert-vector-to-llvm \ + -finalize-memref-to-llvm \ + -convert-math-to-llvm \ + -convert-arith-to-llvm \ + -convert-func-to-llvm \ + -reconcile-unrealized-casts | \ + ${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \ + -shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS} next-attention-lower: @${MLIR_OPT} ./next-attention.mlir \ diff --git a/examples/BuddyNext/next-silu-silu.mlir b/examples/BuddyNext/next-silu-silu.mlir new file mode 100644 index 0000000000..d09e071d23 --- /dev/null +++ b/examples/BuddyNext/next-silu-silu.mlir @@ -0,0 +1,77 @@ +// RUN: buddy-opt %s \ +// RUN: -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-tensor,tosa-to-arith))" \ +// RUN: | buddy-opt \ +// RUN: -convert-linalg-to-loops \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-cf-to-llvm \ +// RUN: -convert-vector-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s +#map = affine_map<(d0) -> (d0)> + func.func private @rtclock() -> f64 + func.func private @printMemrefF32(%ptr: memref<*xf32>) attributes {llvm.emit_c_interface} + + func.func @kernel(%arg0: memref<1x40x8960xf32>) { + %t_start = call @rtclock() : () -> f64 + + %output = memref.alloc() : memref<1x40x8960xf32> + + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %cst_1f = arith.constant 1.0 : f32 + %vec_1f = vector.broadcast %cst_1f : f32 to vector<8xf32> + %cst_0f = arith.constant 0.0 : f32 // for padding + + %d0 = memref.dim %arg0, %c0 : memref<1x40x8960xf32> + %d1 = memref.dim %arg0, %c1 : memref<1x40x8960xf32> + %d2 = memref.dim %arg0, %c2 : memref<1x40x8960xf32> + + affine.for %i = #map(%c0) to #map(%d0) { + affine.for %j = #map(%c0) to #map(%d1) { + affine.for %k = #map(%c0) to #map(%d2) step 8 { + %x_vec = vector.transfer_read %arg0[%i, %j, %k], %cst_0f : memref<1x40x8960xf32>, vector<8xf32> + %neg_x_vec = arith.negf %x_vec : vector<8xf32> + %exp_neg_x_vec = math.exp %neg_x_vec : vector<8xf32> + %one_plus_exp_vec = arith.addf %vec_1f, %exp_neg_x_vec : vector<8xf32> + %sigmoid_x_vec = arith.divf %vec_1f, %one_plus_exp_vec : vector<8xf32> + %silu_vec = arith.mulf %x_vec, %sigmoid_x_vec : vector<8xf32> + vector.transfer_write %silu_vec, %output[%i, %j, %k] : vector<8xf32>, memref<1x40x8960xf32> + } + } + } + + %t_end = call @rtclock() : () -> f64 + %unranked_result = memref.cast %output : memref<1x40x8960xf32> to memref<*xf32> + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 40, 8960] strides = [358400, 8960, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [2.85772{{(, 2.85772)*}}], + call @printMemrefF32(%unranked_result) : (memref<*xf32>) -> () + memref.dealloc %output : memref<1x40x8960xf32> + + %time = arith.subf %t_end, %t_start : f64 + vector.print %time : f64 + + return + } + + func.func @main() { + %input = memref.alloc() : memref<1x40x8960xf32> + %cst_neg_1_23 = arith.constant 3.0 : f32 + linalg.fill ins(%cst_neg_1_23 : f32) outs(%input : memref<1x40x8960xf32>) + + call @kernel(%input) : (memref<1x40x8960xf32>) -> () + + memref.dealloc %input : memref<1x40x8960xf32> + + return + } diff --git a/examples/BuddyNext/next-silu.mlir b/examples/BuddyNext/next-silu.mlir new file mode 100644 index 0000000000..de6a927ce4 --- /dev/null +++ b/examples/BuddyNext/next-silu.mlir @@ -0,0 +1,64 @@ +// RUN: buddy-opt %s \ +// RUN: -pass-pipeline="builtin.module(func.func(tosa-to-linalg-named,tosa-to-linalg,tosa-to-tensor,tosa-to-arith))" \ +// RUN: | buddy-opt \ +// RUN: -arith-expand \ +// RUN: -eliminate-empty-tensors \ +// RUN: -empty-tensor-to-alloc-tensor \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -convert-linalg-to-affine-loops \ +// RUN: -affine-loop-fusion \ +// RUN: -lower-affine \ +// RUN: -convert-vector-to-scf \ +// RUN: -expand-strided-metadata \ +// RUN: -convert-vector-to-llvm \ +// RUN: -memref-expand \ +// RUN: -arith-expand \ +// RUN: -convert-arith-to-llvm \ +// RUN: -finalize-memref-to-llvm \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-cf-to-llvm \ +// RUN: -convert-openmp-to-llvm \ +// RUN: -convert-arith-to-llvm \ +// RUN: -convert-math-to-llvm \ +// RUN: -convert-math-to-libm \ +// RUN: -convert-func-to-llvm \ +// RUN: -reconcile-unrealized-casts \ +// RUN: | mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \ +// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \ +// RUN: | FileCheck %s + func.func private @rtclock() -> f64 + + func.func @kenerl(%arg0: tensor<1x40x8960xf32>) { + %t_start = call @rtclock() : () -> f64 + + %sigmoid_x = tosa.sigmoid %arg0 : (tensor<1x40x8960xf32>) -> tensor<1x40x8960xf32> + + %silu_result = tosa.mul %arg0, %sigmoid_x {shift = 0 : i8} : (tensor<1x40x8960xf32>, tensor<1x40x8960xf32>) -> tensor<1x40x8960xf32> + + %t_end = call @rtclock() : () -> f64 + %time = arith.subf %t_end, %t_start : f64 + + %unranked_result = tensor.cast %silu_result : tensor<1x40x8960xf32> to tensor<*xf32> + + // All the elements of the MemRef are the same, + // only check the first line to verify the correctness. + // CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 40, 8960] strides = [358400, 8960, 1] data = + // CHECK-NEXT: [ + // CHECK-SAME: [2.85772{{(, 2.85772)*}}], + + // print results. + call @printMemrefF32(%unranked_result) : (tensor<*xf32>) -> () + // print timings. + vector.print %time : f64 + + return + } + + func.func @main() { + %input_tensor = arith.constant dense<3.0> : tensor<1x40x8960xf32> + call @kenerl(%input_tensor) : (tensor<1x40x8960xf32>) -> () + + return + } + func.func private @printMemrefF32(%ptr : tensor<*xf32>) diff --git a/midend/lib/CMakeLists.txt b/midend/lib/CMakeLists.txt index 26477e0e40..965e50012d 100644 --- a/midend/lib/CMakeLists.txt +++ b/midend/lib/CMakeLists.txt @@ -25,6 +25,7 @@ set(LinkedLibs BatchMatMulOptimization MatMulParallelVectorization TransposeOptimization + SiluOptimization ) diff --git a/midend/lib/Conversion/CMakeLists.txt b/midend/lib/Conversion/CMakeLists.txt index 6ad92c6892..48fc3d8aa5 100644 --- a/midend/lib/Conversion/CMakeLists.txt +++ b/midend/lib/Conversion/CMakeLists.txt @@ -14,3 +14,4 @@ add_subdirectory(LowerLinalgToGemmini) add_subdirectory(FuncBufferize) add_subdirectory(DepthwiseConvOptimization) add_subdirectory(MLIRGPU) +add_subdirectory(SiluOptimization) diff --git a/midend/lib/Conversion/SiluOptimization/CMakeLists.txt b/midend/lib/Conversion/SiluOptimization/CMakeLists.txt new file mode 100644 index 0000000000..64696f80c0 --- /dev/null +++ b/midend/lib/Conversion/SiluOptimization/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_library(SiluOptimization + SiluOptimization.cpp + LINK_LIBS PUBLIC + BuddyUtils +) diff --git a/midend/lib/Conversion/SiluOptimization/SiluOptimization.cpp b/midend/lib/Conversion/SiluOptimization/SiluOptimization.cpp new file mode 100644 index 0000000000..f036001649 --- /dev/null +++ b/midend/lib/Conversion/SiluOptimization/SiluOptimization.cpp @@ -0,0 +1,258 @@ +//====- SiluOptimization.cpp - Silu Optimization Pass ---------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the pass that vectorizes the linalg.generic representing +// SiLU. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// Rewrite Pattern +//===----------------------------------------------------------------------===// + +class SiLUVectorizePattern : public ConversionPattern { +public: + explicit SiLUVectorizePattern(MLIRContext *context, int64_t vectorSizeParam) + : ConversionPattern(linalg::GenericOp::getOperationName(), 1, context) { + vectorSize = vectorSizeParam; + } + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + + linalg::GenericOp sigmoidOp = cast(op); + + //--------------sigmoid OP-------- + // Check input/output + if (sigmoidOp.getNumDpsInputs() != 1 || sigmoidOp.getNumDpsInits() != 1) { + return failure(); + } + + // Check the body of the op for sigmoid computation. + // The IR should be: negf, exp, addf, divf, yield. + Block &block = sigmoidOp.getRegion().front(); + if (block.getOperations().size() != 5) { // negf, exp, addf, divf, yield + return failure(); + } + + Operation &negfOp = block.getOperations().front(); + Operation &yieldOp = block.getOperations().back(); + + // Check the type of the two operations. + if (!isa(negfOp) || !isa(yieldOp)) { + return failure(); + } + + //-----------Find the consumer mul operation.------------------------------ + // The result of the sigmoid op must be used by another linalg.generic op. + Value outputBuffer = sigmoidOp.getDpsInitOperand(0)->get(); + + // Iterate over all uses to find a suitable consumer op. + linalg::GenericOp mulOp = nullptr; + + for (auto &use : outputBuffer.getUses()) { + Operation *user = use.getOwner(); + + // It must be a linalg.generic, and the buffer must be an input operand + // (i.e., ins()). + auto linalgOp = dyn_cast(user); + if (!linalgOp) + continue; + + bool foundInInput = false; + for (OpOperand *input : linalgOp.getDpsInputOperands()) { + if (input->get() == outputBuffer) { + foundInInput = true; + break; + } + } + if (!foundInInput) + continue; + + // Check if it contains an arith.mulf operation inside. + for (auto &nestedOp : linalgOp.getRegion().front()) { + if (isa(nestedOp)) { + mulOp = linalgOp; + break; + } + } + + if (mulOp) + break; + } + + if (!mulOp) { + llvm::errs() << "Didn't find a consumer linalg.generic using sigmoid " + "output with mulf.\n"; + return failure(); + } + + // Set the insertion point before the mulOp. This ensures that the new + // affine loop is inserted at a point that is dominated by the allocation of + // the output buffer. rewriter.setInsertionPoint(mulOp); + + // Now we have matched the silu pattern: sigmoid followed by a mul. + // The rewrite logic will be applied to the sigmoidOp, and the mulOp will be + // erased. + + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(mulOp); + Location loc = sigmoidOp.getLoc(); + Value input = sigmoidOp.getDpsInputOperand(0)->get(); + // The final output buffer comes from the mulOp. + Value output = mulOp.getDpsInitOperand(0)->get(); + + auto inputMemRefType = input.getType().cast(); + Type elementType = inputMemRefType.getElementType(); + VectorType vectorType = VectorType::get({vectorSize}, elementType); + + // Define constants. + Value c0 = rewriter.create(loc, 0); + Value cst1f = + rewriter.create(loc, rewriter.getF32FloatAttr(1.0)); + Value vec1f = rewriter.create(loc, vectorType, cst1f); + Value cst0f = + rewriter.create(loc, rewriter.getF32FloatAttr(0.0f)); + + // Get dimensions. + Value d0 = rewriter.create(loc, input, 0); + Value d1 = rewriter.create(loc, input, 1); + Value d2 = rewriter.create(loc, input, 2); + + // Create loop nest. + AffineMap map = rewriter.getDimIdentityMap(); + affine::AffineForOp iLoop = rewriter.create( + loc, ValueRange{c0}, map, ValueRange{d0}, map); + rewriter.setInsertionPointToStart(iLoop.getBody()); + Value iv_i = iLoop.getInductionVar(); + + affine::AffineForOp jLoop = rewriter.create( + loc, ValueRange{c0}, map, ValueRange{d1}, map); + rewriter.setInsertionPointToStart(jLoop.getBody()); + Value iv_j = jLoop.getInductionVar(); + + affine::AffineForOp kLoop = rewriter.create( + loc, ValueRange{c0}, map, ValueRange{d2}, map, vectorSize); + rewriter.setInsertionPointToStart(kLoop.getBody()); + Value iv_k = kLoop.getInductionVar(); + + // --- Process Vector --- + Value x_vec = rewriter.create( + loc, vectorType, input, ValueRange{iv_i, iv_j, iv_k}, cst0f); + Value neg_x_vec = rewriter.create(loc, x_vec); + Value exp_neg_x_vec = rewriter.create(loc, neg_x_vec); + Value one_plus_exp_vec = + rewriter.create(loc, vec1f, exp_neg_x_vec); + Value sigmoid_x_vec = + rewriter.create(loc, vec1f, one_plus_exp_vec); + Value silu_vec = rewriter.create(loc, x_vec, sigmoid_x_vec); + rewriter.create(loc, silu_vec, output, + ValueRange{iv_i, iv_j, iv_k}); + + // Replace the original mulOp with the result from our new computation. + // The 'output' buffer now holds the final result. `replaceOp` will + // replace all uses of mulOp's results with `output` and then erase mulOp. + rewriter.eraseOp(mulOp); + rewriter.eraseOp(sigmoidOp); + + return success(); + } + +private: + int64_t vectorSize; +}; + +//===----------------------------------------------------------------------===// +// Pass +//===----------------------------------------------------------------------===// + +class SiluOptimizationPass + : public PassWrapper> { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SiluOptimizationPass) + StringRef getArgument() const final { return "silu-optimization"; } + StringRef getDescription() const final { + return "Vectorize linalg.generic representing SiLU."; + } + SiluOptimizationPass() = default; + SiluOptimizationPass(const SiluOptimizationPass &) {} + explicit SiluOptimizationPass(int64_t vectorSizeParam) { + vectorSize = vectorSizeParam; + } + + void runOnOperation() override; + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + Option vectorSize{*this, "vector-size", + llvm::cl::desc("Vector size for SiLU."), + llvm::cl::init(8)}; +}; + +void SiluOptimizationPass::runOnOperation() { + MLIRContext *context = &getContext(); + ModuleOp module = getOperation(); + + ConversionTarget target(*context); + target + .addLegalDialect(); + target.addLegalOp(); + + // We will manually mark linalg.generic as illegal if it is part of a SiLU + // pattern. The pattern itself will handle the legality checks and + // replacements. Therefore, we don't need to addIllegalOp() + // here. + + RewritePatternSet patterns(context); + patterns.add(context, vectorSize); + + if (failed(applyPartialConversion(module, target, std::move(patterns)))) + signalPassFailure(); +} +} // end anonymous namespace +namespace mlir { +namespace buddy { +void registerSiluOptimizationPass() { + PassRegistration(); +} +} // namespace buddy +} // namespace mlir \ No newline at end of file diff --git a/midend/lib/InitAll.cpp b/midend/lib/InitAll.cpp index 7d8929e055..8c2ab1de56 100644 --- a/midend/lib/InitAll.cpp +++ b/midend/lib/InitAll.cpp @@ -45,6 +45,7 @@ void registerMatMulOptimizePass(); void registerMatMulParallelVectorizationPass(); void registerMatMulVectorizationPass(); void registerTransposeOptimizationPass(); +void registerSiluOptimizationPass(); } // namespace buddy } // namespace mlir @@ -74,4 +75,5 @@ void mlir::buddy::registerAllPasses() { mlir::buddy::registerMatMulParallelVectorizationPass(); mlir::buddy::registerMatMulVectorizationPass(); mlir::buddy::registerTransposeOptimizationPass(); + mlir::buddy::registerSiluOptimizationPass(); } diff --git a/tests/Conversion/silu-optimization.mlir b/tests/Conversion/silu-optimization.mlir new file mode 100644 index 0000000000..a6cbc8c169 --- /dev/null +++ b/tests/Conversion/silu-optimization.mlir @@ -0,0 +1,56 @@ +// RUN: buddy-opt -silu-optimization="vector-size=8" %s | FileCheck %s + +// CHECK: #map = affine_map<(d0) -> (d0)> +// CHECK: module { +// CHECK: func.func @silu_tosa(%arg0: memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>) -> memref<1x40x8960xf32> { +// CHECK:%cst = arith.constant 1.000000e+00 : f32 +// CHECK: %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32> +// CHECK: %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32> +// CHECK: %c0 = arith.constant 0 : index +// CHECK: %cst_1 = arith.constant 1.000000e+00 : f32 +// CHECK: %0 = vector.broadcast %cst_1 : f32 to vector<8xf32> +// CHECK: %cst_2 = arith.constant 0.000000e+00 : f32 +// CHECK: %c0_3 = arith.constant 0 : index +// CHECK: %dim = memref.dim %arg0, %c0_3 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>> +// CHECK: %c1 = arith.constant 1 : index +// CHECK: %dim_4 = memref.dim %arg0, %c1 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>> +// CHECK: %c2 = arith.constant 2 : index +// CHECK: %dim_5 = memref.dim %arg0, %c2 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>> +// CHECK: affine.for %arg1 = #map(%c0) to #map(%dim) { +// CHECK: affine.for %arg2 = #map(%c0) to #map(%dim_4) { +// CHECK: affine.for %arg3 = #map(%c0) to #map(%dim_5) step 8 { +// CHECK: %1 = vector.transfer_read %arg0[%arg1, %arg2, %arg3], %cst_2 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>, vector<8xf32> +// CHECK: %2 = arith.negf %1 : vector<8xf32> +// CHECK: %3 = math.exp %2 : vector<8xf32> +// CHECK: %4 = arith.addf %0, %3 : vector<8xf32> +// CHECK: %5 = arith.divf %0, %4 : vector<8xf32> +// CHECK: %6 = arith.mulf %1, %5 : vector<8xf32> +// CHECK: vector.transfer_write %6, %alloc_0[%arg1, %arg2, %arg3] : vector<8xf32>, memref<1x40x8960xf32> +// CHECK: } +// CHECK: } +// CHECK: } +// CHECK: return %alloc_0 : memref<1x40x8960xf32> +// CHECK: } +// CHECK: } + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +func.func @silu_tosa(%arg0: memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>) -> memref<1x40x8960xf32> { + %cst = arith.constant 1.000000e+00 : f32 + %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32> + linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>) outs(%alloc : memref<1x40x8960xf32>) { + ^bb0(%in: f32, %out: f32): + %3 = arith.negf %in : f32 + %4 = math.exp %3 : f32 + %5 = arith.addf %4, %cst : f32 + %6 = arith.divf %cst, %5 : f32 + linalg.yield %6 : f32 + } + %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<1x40x8960xf32> + linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0, %alloc : memref<1x40x8960xf32, strided<[?, ?, ?], offset: ?>>, memref<1x40x8960xf32>) outs(%alloc_0 : memref<1x40x8960xf32>) { + ^bb0(%in: f32, %in_1: f32, %out: f32): + %3 = arith.mulf %in, %in_1 : f32 + linalg.yield %3 : f32 + } + return %alloc_0 : memref<1x40x8960xf32> +} diff --git a/tools/buddy-opt/CMakeLists.txt b/tools/buddy-opt/CMakeLists.txt index 46fcfb4a57..19a9ba9724 100644 --- a/tools/buddy-opt/CMakeLists.txt +++ b/tools/buddy-opt/CMakeLists.txt @@ -42,4 +42,5 @@ target_link_libraries(buddy-opt MLIRTransforms MLIRTransformUtils MatMulTransposeBVec + SiluOptimization ) diff --git a/tools/buddy-opt/buddy-opt.cpp b/tools/buddy-opt/buddy-opt.cpp index fb668a391a..a389cd8f22 100644 --- a/tools/buddy-opt/buddy-opt.cpp +++ b/tools/buddy-opt/buddy-opt.cpp @@ -82,6 +82,7 @@ void registerConvertMemcpyToGPUPass(); void registerLegalizeShmemOutliningPass(); void registerMatMulTransposeBVecPass(); void registerLegalizeShmemOutliningPass(); +void registerSiluOptimizationPass(); } // namespace buddy } // namespace mlir @@ -121,6 +122,7 @@ int main(int argc, char **argv) { mlir::buddy::registerDepthwiseConv2DNhwcHwcOptimizePass(); mlir::buddy::registerFuncBufferizeDynamicOffsetPass(); mlir::buddy::registerMatMulTransposeBVecPass(); + mlir::buddy::registerSiluOptimizationPass(); // Register gpu passes mlir::buddy::registerConvertMemcpyToGPUPass();