66 changes: 66 additions & 0 deletions examples/BuddyNext/makefile
@@ -811,3 +811,69 @@ next-compass-run:
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} \
-shared-libs=${MLIR_C_RUNNER_UTILS}

next-rmsnorm-run:
@${MLIR_OPT} ./next-rmsnorm.mlir \
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
${MLIR_OPT} \
-arith-expand \
-eliminate-empty-tensors \
-empty-tensor-to-alloc-tensor \
-one-shot-bufferize \
-convert-linalg-to-affine-loops \
-affine-loop-fusion \
-lower-affine \
-func-bufferize \
-arith-bufferize \
-tensor-bufferize \
-buffer-deallocation \
-finalizing-bufferize \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-arith-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

next-rmsnorm-manual-run:
@${MLIR_OPT} ./next-rmsnorm-manual.mlir \
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
${MLIR_OPT} \
-arith-expand \
-eliminate-empty-tensors \
-empty-tensor-to-alloc-tensor \
-one-shot-bufferize \
-convert-linalg-to-affine-loops \
-affine-loop-fusion \
-lower-affine \
-func-bufferize \
-arith-bufferize \
-tensor-bufferize \
-buffer-deallocation \
-finalizing-bufferize \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-arith-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
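Both new targets use the same lowering pipeline as the existing next-*-run rules above; only the input file differs. Assuming MLIR_OPT, MLIR_CPU_RUNNER, and the runner-utils library paths are configured as for the other targets, they can be invoked from examples/BuddyNext with:

    make next-rmsnorm-run
    make next-rmsnorm-manual-run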
207 changes: 207 additions & 0 deletions examples/BuddyNext/next-rmsnorm-manual.mlir
@@ -0,0 +1,207 @@
// RUN: buddy-opt %s \
// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
// RUN: | buddy-opt \
// RUN: -arith-expand \
// RUN: -eliminate-empty-tensors \
// RUN: -empty-tensor-to-alloc-tensor \
// RUN: -one-shot-bufferize \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -affine-loop-fusion \
// RUN: -lower-affine \
// RUN: -func-bufferize \
// RUN: -arith-bufferize \
// RUN: -tensor-bufferize \
// RUN: -buffer-deallocation \
// RUN: -finalizing-bufferize \
// RUN: -convert-vector-to-scf \
// RUN: -expand-strided-metadata \
// RUN: -convert-vector-to-llvm \
// RUN: -memref-expand \
// RUN: -arith-expand \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-scf-to-cf \
// RUN: -convert-openmp-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func.func private @rtclock() -> f64
func.func private @printMemrefF32(memref<*xf32>)
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func.func @kernel(%t0: tensor<1x40x1536xf32>, %t1: tensor<1536xf32>) {
%t_start = call @rtclock() : () -> f64

%idx_0 = arith.constant 0 : index
%idx_1 = arith.constant 1 : index
%idx_40 = arith.constant 40 : index

%idx_128 = arith.constant 128 : index
%idx_256 = arith.constant 256 : index
%idx_384 = arith.constant 384 : index
%idx_512 = arith.constant 512 : index
%idx_640 = arith.constant 640 : index
%idx_768 = arith.constant 768 : index
%idx_896 = arith.constant 896 : index
%idx_1024 = arith.constant 1024 : index
%idx_1152 = arith.constant 1152 : index
%idx_1280 = arith.constant 1280 : index
%idx_1408 = arith.constant 1408 : index

%memref_t0 = bufferization.to_memref %t0 : memref<1x40x1536xf32>

%memref_t1 = bufferization.to_memref %t1 : memref<1536xf32>
%weight_0 = vector.load %memref_t1[%idx_0] : memref<1536xf32>, vector<128xf32>
%weight_1 = vector.load %memref_t1[%idx_128] : memref<1536xf32>, vector<128xf32>
%weight_2 = vector.load %memref_t1[%idx_256] : memref<1536xf32>, vector<128xf32>
%weight_3 = vector.load %memref_t1[%idx_384] : memref<1536xf32>, vector<128xf32>
%weight_4 = vector.load %memref_t1[%idx_512] : memref<1536xf32>, vector<128xf32>
%weight_5 = vector.load %memref_t1[%idx_640] : memref<1536xf32>, vector<128xf32>
%weight_6 = vector.load %memref_t1[%idx_768] : memref<1536xf32>, vector<128xf32>
%weight_7 = vector.load %memref_t1[%idx_896] : memref<1536xf32>, vector<128xf32>
%weight_8 = vector.load %memref_t1[%idx_1024] : memref<1536xf32>, vector<128xf32>
%weight_9 = vector.load %memref_t1[%idx_1152] : memref<1536xf32>, vector<128xf32>
%weight_10 = vector.load %memref_t1[%idx_1280] : memref<1536xf32>, vector<128xf32>
%weight_11 = vector.load %memref_t1[%idx_1408] : memref<1536xf32>, vector<128xf32>

%zero = arith.constant 0.0 : f32
%rsqrt_eps = arith.constant 9.99999997E-7 : f32
%scale = arith.constant 1536.0 : f32
%result_memref = memref.alloc() : memref<1x40x1536xf32>

scf.parallel (%i) = (%idx_0) to (%idx_40) step (%idx_1) {
%vec_0 = vector.load %memref_t0[%idx_0, %i, %idx_0] : memref<1x40x1536xf32>, vector<128xf32>
%vec_1 = vector.load %memref_t0[%idx_0, %i, %idx_128] : memref<1x40x1536xf32>, vector<128xf32>
%vec_2 = vector.load %memref_t0[%idx_0, %i, %idx_256] : memref<1x40x1536xf32>, vector<128xf32>
%vec_3 = vector.load %memref_t0[%idx_0, %i, %idx_384] : memref<1x40x1536xf32>, vector<128xf32>
%square_0 = arith.mulf %vec_0, %vec_0 : vector<128xf32>
%square_1 = arith.mulf %vec_1, %vec_1 : vector<128xf32>
%square_2 = arith.mulf %vec_2, %vec_2 : vector<128xf32>
%square_3 = arith.mulf %vec_3, %vec_3 : vector<128xf32>
%sum_0 = vector.reduction <add>, %square_0 : vector<128xf32> into f32
%sum_1 = vector.reduction <add>, %square_1 : vector<128xf32> into f32
%sum_2 = vector.reduction <add>, %square_2 : vector<128xf32> into f32
%sum_3 = vector.reduction <add>, %square_3 : vector<128xf32> into f32

%vec_4 = vector.load %memref_t0[%idx_0, %i, %idx_512] : memref<1x40x1536xf32>, vector<128xf32>
%vec_5 = vector.load %memref_t0[%idx_0, %i, %idx_640] : memref<1x40x1536xf32>, vector<128xf32>
%vec_6 = vector.load %memref_t0[%idx_0, %i, %idx_768] : memref<1x40x1536xf32>, vector<128xf32>
%vec_7 = vector.load %memref_t0[%idx_0, %i, %idx_896] : memref<1x40x1536xf32>, vector<128xf32>
%square_4 = arith.mulf %vec_4, %vec_4 : vector<128xf32>
%square_5 = arith.mulf %vec_5, %vec_5 : vector<128xf32>
%square_6 = arith.mulf %vec_6, %vec_6 : vector<128xf32>
%square_7 = arith.mulf %vec_7, %vec_7 : vector<128xf32>
%sum_4 = vector.reduction <add>, %square_4 : vector<128xf32> into f32
%sum_5 = vector.reduction <add>, %square_5 : vector<128xf32> into f32
%sum_6 = vector.reduction <add>, %square_6 : vector<128xf32> into f32
%sum_7 = vector.reduction <add>, %square_7 : vector<128xf32> into f32

%vec_8 = vector.load %memref_t0[%idx_0, %i, %idx_1024] : memref<1x40x1536xf32>, vector<128xf32>
%vec_9 = vector.load %memref_t0[%idx_0, %i, %idx_1152] : memref<1x40x1536xf32>, vector<128xf32>
%vec_10 = vector.load %memref_t0[%idx_0, %i, %idx_1280] : memref<1x40x1536xf32>, vector<128xf32>
%vec_11 = vector.load %memref_t0[%idx_0, %i, %idx_1408] : memref<1x40x1536xf32>, vector<128xf32>
%square_8 = arith.mulf %vec_8, %vec_8 : vector<128xf32>
%square_9 = arith.mulf %vec_9, %vec_9 : vector<128xf32>
%square_10 = arith.mulf %vec_10, %vec_10 : vector<128xf32>
%square_11 = arith.mulf %vec_11, %vec_11 : vector<128xf32>
%sum_8 = vector.reduction <add>, %square_8 : vector<128xf32> into f32
%sum_9 = vector.reduction <add>, %square_9 : vector<128xf32> into f32
%sum_10 = vector.reduction <add>, %square_10 : vector<128xf32> into f32
%sum_11 = vector.reduction <add>, %square_11 : vector<128xf32> into f32

// level 1
%l1_0 = arith.addf %sum_0, %sum_1 : f32
%l1_1 = arith.addf %sum_2, %sum_3 : f32
%l1_2 = arith.addf %sum_4, %sum_5 : f32
%l1_3 = arith.addf %sum_6, %sum_7 : f32
%l1_4 = arith.addf %sum_8, %sum_9 : f32
%l1_5 = arith.addf %sum_10, %sum_11 : f32
// level 2
%l2_0 = arith.addf %l1_0, %l1_1 : f32
%l2_1 = arith.addf %l1_2, %l1_3 : f32
%l2_2 = arith.addf %l1_4, %l1_5 : f32
// level 3
%l3_0 = arith.addf %l2_0, %l2_1 : f32
// final sum
%sum_all = arith.addf %l3_0, %l2_2 : f32

%mean = arith.divf %sum_all, %scale : f32
%var = arith.addf %mean, %rsqrt_eps : f32
%inv_std = math.rsqrt %var : f32
%inv_std_vec = vector.splat %inv_std : vector<128xf32>

%vec_0_new = arith.mulf %vec_0, %inv_std_vec : vector<128xf32>
%vec_1_new = arith.mulf %vec_1, %inv_std_vec : vector<128xf32>
%vec_2_new = arith.mulf %vec_2, %inv_std_vec : vector<128xf32>
%vec_3_new = arith.mulf %vec_3, %inv_std_vec : vector<128xf32>
%vec_0_result = arith.mulf %vec_0_new, %weight_0 : vector<128xf32>
%vec_1_result = arith.mulf %vec_1_new, %weight_1 : vector<128xf32>
%vec_2_result = arith.mulf %vec_2_new, %weight_2 : vector<128xf32>
%vec_3_result = arith.mulf %vec_3_new, %weight_3 : vector<128xf32>
vector.store %vec_0_result, %result_memref[%idx_0, %i, %idx_0] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_1_result, %result_memref[%idx_0, %i, %idx_128] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_2_result, %result_memref[%idx_0, %i, %idx_256] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_3_result, %result_memref[%idx_0, %i, %idx_384] : memref<1x40x1536xf32>, vector<128xf32>

%vec_4_new = arith.mulf %vec_4, %inv_std_vec : vector<128xf32>
%vec_5_new = arith.mulf %vec_5, %inv_std_vec : vector<128xf32>
%vec_6_new = arith.mulf %vec_6, %inv_std_vec : vector<128xf32>
%vec_7_new = arith.mulf %vec_7, %inv_std_vec : vector<128xf32>
%vec_4_result = arith.mulf %vec_4_new, %weight_4 : vector<128xf32>
%vec_5_result = arith.mulf %vec_5_new, %weight_5 : vector<128xf32>
%vec_6_result = arith.mulf %vec_6_new, %weight_6 : vector<128xf32>
%vec_7_result = arith.mulf %vec_7_new, %weight_7 : vector<128xf32>
vector.store %vec_4_result, %result_memref[%idx_0, %i, %idx_512] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_5_result, %result_memref[%idx_0, %i, %idx_640] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_6_result, %result_memref[%idx_0, %i, %idx_768] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_7_result, %result_memref[%idx_0, %i, %idx_896] : memref<1x40x1536xf32>, vector<128xf32>

%vec_8_new = arith.mulf %vec_8, %inv_std_vec : vector<128xf32>
%vec_9_new = arith.mulf %vec_9, %inv_std_vec : vector<128xf32>
%vec_10_new = arith.mulf %vec_10, %inv_std_vec : vector<128xf32>
%vec_11_new = arith.mulf %vec_11, %inv_std_vec : vector<128xf32>
%vec_8_result = arith.mulf %vec_8_new, %weight_8 : vector<128xf32>
%vec_9_result = arith.mulf %vec_9_new, %weight_9 : vector<128xf32>
%vec_10_result = arith.mulf %vec_10_new, %weight_10 : vector<128xf32>
%vec_11_result = arith.mulf %vec_11_new, %weight_11 : vector<128xf32>
vector.store %vec_8_result, %result_memref[%idx_0, %i, %idx_1024] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_9_result, %result_memref[%idx_0, %i, %idx_1152] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_10_result, %result_memref[%idx_0, %i, %idx_1280] : memref<1x40x1536xf32>, vector<128xf32>
vector.store %vec_11_result, %result_memref[%idx_0, %i, %idx_1408] : memref<1x40x1536xf32>, vector<128xf32>
}

%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64
%print_result = memref.cast %result_memref : memref<1x40x1536xf32> to memref<*xf32>

// All elements of the result MemRef are the same, so checking
// the first row is enough to verify correctness.
// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 40, 1536] strides = [61440, 1536, 1] data =
// CHECK-NEXT: [
// CHECK-SAME: [
// CHECK-SAME: [2{{(, 2)*}}],

// Print results.
call @printMemrefF32(%print_result) : (memref<*xf32>) -> ()
// Print timings.
vector.print %time : f64

return
}

func.func @main() {

%c0 = arith.constant dense<3.0> : tensor<1x40x1536xf32>
%c1 = arith.constant dense<2.0> : tensor<1536xf32>

call @kernel(%c0, %c1) : (tensor<1x40x1536xf32>, tensor<1536xf32>) -> ()

return
}
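The manual kernel above is a hand-vectorized RMSNorm over the last dimension (1536 = 12 blocks of 128 lanes): square and reduce each 128-wide chunk, tree-sum the partial results, divide by 1536, add the epsilon, take rsqrt, and rescale the input by that factor and by the weight vector. A minimal NumPy sketch of the same computation (illustrative only; the function name and layout here are not from the PR):

    import numpy as np

    def rmsnorm_blocked(x, w, eps=1e-6, block=128):
        # Mirrors the manual kernel: per-row, 12 x 128-lane sums of squares,
        # then mean, rsqrt, and rescaling by the weight vector.
        out = np.empty_like(x)
        for i in range(x.shape[1]):
            row = x[0, i]
            partial = [np.sum(row[j:j + block] ** 2) for j in range(0, row.size, block)]
            inv_std = 1.0 / np.sqrt(sum(partial) / row.size + eps)   # rsqrt(mean + eps)
            out[0, i] = row * inv_std * w
        return out

    x = np.full((1, 40, 1536), 3.0, dtype=np.float32)
    w = np.full(1536, 2.0, dtype=np.float32)
    print(rmsnorm_blocked(x, w)[0, 0, :4])   # ~[2. 2. 2. 2.], matching the CHECK lines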
91 changes: 91 additions & 0 deletions examples/BuddyNext/next-rmsnorm.mlir
@@ -0,0 +1,91 @@
// RUN: buddy-opt %s \
// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
// RUN: | buddy-opt \
// RUN: -arith-expand \
// RUN: -eliminate-empty-tensors \
// RUN: -empty-tensor-to-alloc-tensor \
// RUN: -one-shot-bufferize \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -affine-loop-fusion \
// RUN: -lower-affine \
// RUN: -func-bufferize \
// RUN: -arith-bufferize \
// RUN: -tensor-bufferize \
// RUN: -buffer-deallocation \
// RUN: -finalizing-bufferize \
// RUN: -convert-vector-to-scf \
// RUN: -expand-strided-metadata \
// RUN: -convert-vector-to-llvm \
// RUN: -memref-expand \
// RUN: -arith-expand \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-scf-to-cf \
// RUN: -convert-openmp-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func.func private @rtclock() -> f64
func.func private @printMemrefF32(%ptr : tensor<*xf32>)

#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

func.func @kernel(%t0: tensor<1x40x1536xf32>, %t1: tensor<1536xf32>) {
%t_start = call @rtclock() : () -> f64

%54 = tensor.empty() : tensor<1x40x1536xf32>
// %55 = %t0^2 (element-wise square)
%c2_i32 = arith.constant 2 : i32
%55 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel"]} ins(%t0 : tensor<1x40x1536xf32>) outs(%54 : tensor<1x40x1536xf32>) {
^bb0(%in: f32, %out: f32):
%3879 = math.fpowi %in, %c2_i32 : f32, i32
linalg.yield %3879 : f32
} -> tensor<1x40x1536xf32>
%56 = tosa.reduce_sum %55 {axis = 2 : i32} : (tensor<1x40x1536xf32>) -> tensor<1x40x1xf32>
%57 = "tosa.const"() <{value = dense<1.536000e+03> : tensor<1xf32>}> : () -> tensor<1xf32>
%58 = tosa.reciprocal %57 : (tensor<1xf32>) -> tensor<1xf32>
%59 = tosa.mul %58, %56 {shift = 0 : i8} : (tensor<1xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
%60 = "tosa.const"() <{value = dense<9.99999997E-7> : tensor<1x40x1xf32>}> : () -> tensor<1x40x1xf32>
%61 = tosa.add %59, %60 : (tensor<1x40x1xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
%62 = tosa.rsqrt %61 : (tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
%63 = tosa.mul %t0, %62 {shift = 0 : i8} : (tensor<1x40x1536xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1536xf32>
// [NOTE]: %t1 is the RMSNorm weight.
%64 = tosa.reshape %t1 {new_shape = array<i64: 1, 1, 1536>} : (tensor<1536xf32>) -> tensor<1x1x1536xf32>
%65 = tosa.mul %64, %63 {shift = 0 : i8} : (tensor<1x1x1536xf32>, tensor<1x40x1536xf32>) -> tensor<1x40x1536xf32>

%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64

%tensor_unranked = tensor.cast %65 : tensor<1x40x1536xf32> to tensor<*xf32>

// All elements of the result MemRef are the same, so checking
// the first row is enough to verify correctness.
// CHECK: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [1, 40, 1536] strides = [61440, 1536, 1] data =
// CHECK-NEXT: [
// CHECK-SAME: [
// CHECK-SAME: [2{{(, 2)*}}],

// Print results.
call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
// Print timings.
vector.print %time : f64

return
}

func.func @main() {

%c0 = arith.constant dense<3.0> : tensor<1x40x1536xf32>
%c1 = arith.constant dense<2.0> : tensor<1536xf32>

call @kernel(%c0, %c1) : (tensor<1x40x1536xf32>, tensor<1536xf32>) -> ()

return
}
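Both files implement the same operation: y = x * rsqrt(mean(x^2, axis=-1) + 1e-6) * weight. With the test input of all 3.0 and weights of all 2.0, mean(x^2) = 9 and rsqrt(9 + 1e-6) is approximately 1/3, so every output element is 3 * (1/3) * 2 = 2, which is what both FileCheck patterns assert. A NumPy check of that arithmetic (illustrative only, not part of the change):

    import numpy as np
    x, w = np.full((1, 40, 1536), 3.0), np.full(1536, 2.0)
    y = x * (1.0 / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + 1e-6)) * w
    print(y[0, 0, :3])   # [2. 2. 2.]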