Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions examples/BuddyNext/makefile
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,69 @@ tosa-matmul-transpose2-vec-run:
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

next-softmax-run:
@${MLIR_OPT} ./next-softmax.mlir \
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
${BUDDY_OPT} \
-arith-expand \
-eliminate-empty-tensors \
-one-shot-bufferize \
-convert-linalg-to-affine-loops \
-affine-loop-fusion \
-lower-affine \
-func-bufferize \
-arith-bufferize \
-tensor-bufferize \
-buffer-deallocation \
-finalizing-bufferize \
-convert-vector-to-scf \
-expand-strided-metadata \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-openmp-to-llvm \
-convert-arith-to-llvm \
-convert-math-to-llvm \
-convert-math-to-libm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}

next-softmax-lower:
@${MLIR_OPT} ./next-softmax.mlir \
-pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" | \
${BUDDY_OPT} \
-arith-expand \
-eliminate-empty-tensors \
-one-shot-bufferize \
-convert-linalg-to-affine-loops \
-affine-loop-fusion \
-lower-affine \
-func-bufferize \
-arith-bufferize \
-tensor-bufferize \
-buffer-deallocation \
-finalizing-bufferize \
-convert-vector-to-scf \
-expand-strided-metadata \
-reconcile-unrealized-casts -o next-softmax-lower.mlir

next-softmax-manual-run:
@${BUDDY_OPT} ./next-softmax-manual.mlir \
-convert-vector-to-llvm \
-memref-expand \
-arith-expand \
-convert-arith-to-llvm \
-finalize-memref-to-llvm \
-convert-scf-to-cf \
-convert-arith-to-llvm \
-convert-math-to-llvm \
-convert-func-to-llvm \
-reconcile-unrealized-casts | \
${MLIR_CPU_RUNNER} ${OPT_FLAG} -e main -entry-point-result=void \
-shared-libs=${MLIR_RUNNER_UTILS} -shared-libs=${MLIR_C_RUNNER_UTILS}
126 changes: 126 additions & 0 deletions examples/BuddyNext/next-softmax-manual.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
module {
memref.global "private" constant @__constant_1x40x151936xf32 : memref<1x40x151936xf32> = dense<2.000000e+00> {alignment = 64 : i64}
func.func private @rtclock() -> f64
func.func private @printMemrefF32(memref<*xf32>)

func.func @softmax_kernel(%arg0: memref<1x40x151936xf32>) {
%c151936 = arith.constant 151936 : index
%c40 = arith.constant 40 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%c8 = arith.constant 8 : index
%cst = arith.constant 0.000000e+00 : f32
%cst_neg_inf = arith.constant -3.40282347E+38 : f32

%0 = call @rtclock() : () -> f64

%result = memref.alloc() {alignment = 64 : i64} : memref<1x40x151936xf32>

scf.for %row = %c0 to %c40 step %c1 {
%max_scalar = memref.alloc() : memref<f32>
memref.store %cst_neg_inf, %max_scalar[] : memref<f32>

%remainder_max = arith.remsi %c151936, %c8 : index
%vectorized_end_max = arith.subi %c151936, %remainder_max : index

// max(x)
%neg_inf_vec = vector.splat %cst_neg_inf : vector<8xf32>
%max_vec = memref.alloc() : memref<vector<8xf32>>
memref.store %neg_inf_vec, %max_vec[] : memref<vector<8xf32>>

scf.for %col = %c0 to %vectorized_end_max step %c8 {
%input_vec = vector.load %arg0[%c0, %row, %col] : memref<1x40x151936xf32>, vector<8xf32>
%current_max = memref.load %max_vec[] : memref<vector<8xf32>>
%new_max = arith.maximumf %input_vec, %current_max : vector<8xf32>
memref.store %new_max, %max_vec[] : memref<vector<8xf32>>
}
scf.for %col = %vectorized_end_max to %c151936 step %c1 {
%val = memref.load %arg0[%c0, %row, %col] : memref<1x40x151936xf32>
%current_max = memref.load %max_scalar[] : memref<f32>
%new_max = arith.maximumf %val, %current_max : f32
memref.store %new_max, %max_scalar[] : memref<f32>
}

%final_max_vec = memref.load %max_vec[] : memref<vector<8xf32>>
%max_scalar_from_vec = vector.reduction <maximumf>, %final_max_vec : vector<8xf32> into f32
%current_scalar_max = memref.load %max_scalar[] : memref<f32>
%final_max = arith.maximumf %max_scalar_from_vec, %current_scalar_max : f32

// sum(exp(x-max))
%sum_scalar = memref.alloc() : memref<f32>
memref.store %cst, %sum_scalar[] : memref<f32>

%max_vec_broadcast = vector.splat %final_max : vector<8xf32>
%zero_vec = vector.splat %cst : vector<8xf32>
%sum_vec = memref.alloc() : memref<vector<8xf32>>
memref.store %zero_vec, %sum_vec[] : memref<vector<8xf32>>

scf.for %col = %c0 to %vectorized_end_max step %c8 {
%input_vec = vector.load %arg0[%c0, %row, %col] : memref<1x40x151936xf32>, vector<8xf32>
%sub_vec = arith.subf %input_vec, %max_vec_broadcast : vector<8xf32>
%exp_vec = math.exp %sub_vec : vector<8xf32>

vector.store %exp_vec, %result[%c0, %row, %col] : memref<1x40x151936xf32>, vector<8xf32>

%current_sum = memref.load %sum_vec[] : memref<vector<8xf32>>
%new_sum = arith.addf %current_sum, %exp_vec : vector<8xf32>
memref.store %new_sum, %sum_vec[] : memref<vector<8xf32>>
}
scf.for %col = %vectorized_end_max to %c151936 step %c1 {
%val = memref.load %arg0[%c0, %row, %col] : memref<1x40x151936xf32>
%sub_val = arith.subf %val, %final_max : f32
%exp_val = math.exp %sub_val : f32
memref.store %exp_val, %result[%c0, %row, %col] : memref<1x40x151936xf32>

%current_sum = memref.load %sum_scalar[] : memref<f32>
%new_sum = arith.addf %current_sum, %exp_val : f32
memref.store %new_sum, %sum_scalar[] : memref<f32>
}

%final_sum_vec = memref.load %sum_vec[] : memref<vector<8xf32>>
%sum_from_vec = vector.reduction <add>, %final_sum_vec : vector<8xf32> into f32
%scalar_sum = memref.load %sum_scalar[] : memref<f32>
%total_sum = arith.addf %sum_from_vec, %scalar_sum : f32

// log_sum_exp = max + log(sum)
// exp(x - log_sum_exp)
%log_sum = math.log %total_sum : f32
%log_sum_exp = arith.addf %final_max, %log_sum : f32
%log_sum_exp_vec = vector.splat %log_sum_exp : vector<8xf32>

scf.for %col = %c0 to %vectorized_end_max step %c8 {
%input_vec = vector.load %arg0[%c0, %row, %col] : memref<1x40x151936xf32>, vector<8xf32>
%sub_vec = arith.subf %input_vec, %log_sum_exp_vec : vector<8xf32>
%softmax_vec = math.exp %sub_vec : vector<8xf32>
vector.store %softmax_vec, %result[%c0, %row, %col] : memref<1x40x151936xf32>, vector<8xf32>
}

scf.for %col = %vectorized_end_max to %c151936 step %c1 {
%input_val = memref.load %arg0[%c0, %row, %col] : memref<1x40x151936xf32>
%sub_val = arith.subf %input_val, %log_sum_exp : f32
%softmax_val = math.exp %sub_val : f32
memref.store %softmax_val, %result[%c0, %row, %col] : memref<1x40x151936xf32>
}

memref.dealloc %max_scalar : memref<f32>
memref.dealloc %max_vec : memref<vector<8xf32>>
memref.dealloc %sum_scalar : memref<f32>
memref.dealloc %sum_vec : memref<vector<8xf32>>
}

%1 = call @rtclock() : () -> f64
%2 = arith.subf %1, %0 : f64
%cast = memref.cast %result : memref<1x40x151936xf32> to memref<*xf32>
// call @printMemrefF32(%cast) : (memref<*xf32>) -> ()

memref.dealloc %result : memref<1x40x151936xf32>
vector.print %2 : f64
return
}

func.func @main() {
%0 = memref.get_global @__constant_1x40x151936xf32 : memref<1x40x151936xf32>
call @softmax_kernel(%0) : (memref<1x40x151936xf32>) -> ()
return
}
}
66 changes: 66 additions & 0 deletions examples/BuddyNext/next-softmax.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: buddy-opt %s \
// RUN: -pass-pipeline "builtin.module(func.func(tosa-to-linalg-named),func.func(tosa-to-linalg),func.func(tosa-to-tensor),func.func(tosa-to-arith))" \
// RUN: | buddy-opt \
// RUN: -arith-expand \
// RUN: -eliminate-empty-tensors \
// RUN: -one-shot-bufferize \
// RUN: -convert-linalg-to-affine-loops \
// RUN: -affine-loop-fusion \
// RUN: -lower-affine \
// RUN: -func-bufferize \
// RUN: -arith-bufferize \
// RUN: -tensor-bufferize \
// RUN: -buffer-deallocation \
// RUN: -finalizing-bufferize \
// RUN: -convert-vector-to-scf \
// RUN: -expand-strided-metadata \
// RUN: -convert-vector-to-llvm \
// RUN: -memref-expand \
// RUN: -arith-expand \
// RUN: -convert-arith-to-llvm \
// RUN: -finalize-memref-to-llvm \
// RUN: -convert-scf-to-cf \
// RUN: -convert-openmp-to-llvm \
// RUN: -convert-arith-to-llvm \
// RUN: -convert-math-to-llvm \
// RUN: -convert-math-to-libm \
// RUN: -convert-func-to-llvm \
// RUN: -reconcile-unrealized-casts \
// RUN: | mlir-cpu-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext \
// RUN: -shared-libs=%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext \
// RUN: | FileCheck %s

func.func private @rtclock() -> f64
func.func private @printMemrefF32(%ptr : tensor<*xf32>)

#map = affine_map<(d0, d1) -> (d0, d1)>

func.func @softmax_kernel(%input: tensor<1x40x151936xf32>) {
%t_start = call @rtclock() : () -> f64

%max = tosa.reduce_max %input {axis = 2 : i32} : (tensor<1x40x151936xf32>) -> tensor<1x40x1xf32>
%sub = tosa.sub %input, %max : (tensor<1x40x151936xf32>, tensor<1x40x1xf32>) -> tensor<1x40x151936xf32>
%exp = tosa.exp %sub : (tensor<1x40x151936xf32>) -> tensor<1x40x151936xf32>
%sum = tosa.reduce_sum %exp {axis = 2 : i32} : (tensor<1x40x151936xf32>) -> tensor<1x40x1xf32>
%logsum = tosa.log %sum : (tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
%add = tosa.add %max, %logsum : (tensor<1x40x1xf32>, tensor<1x40x1xf32>) -> tensor<1x40x1xf32>
%sub2 = tosa.sub %input, %add : (tensor<1x40x151936xf32>, tensor<1x40x1xf32>) -> tensor<1x40x151936xf32>
%softmax = tosa.exp %sub2 : (tensor<1x40x151936xf32>) -> tensor<1x40x151936xf32>

%t_end = call @rtclock() : () -> f64
%time = arith.subf %t_end, %t_start : f64

%tensor_unranked = tensor.cast %softmax : tensor<1x40x151936xf32> to tensor<*xf32>

// call @printMemrefF32(%tensor_unranked) : (tensor<*xf32>) -> ()
vector.print %time : f64

return
}

func.func @main() {
%c0 = arith.constant dense<2.0> : tensor<1x40x151936xf32>
call @softmax_kernel(%c0) : (tensor<1x40x151936xf32>) -> ()
return
}
Loading