Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable KleidiAI for FP32 #3818

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 0 additions & 46 deletions Makefile.FP16Benchmark.aarch64

This file was deleted.

39 changes: 33 additions & 6 deletions include/fbgemm/FbgemmFPCommon.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* Copyright 2024-2025 Arm Limited and/or its affiliates
* <[email protected]> All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -57,6 +58,22 @@ struct GemmParams<float16> {
#endif
};

template <>
struct GemmParams<float> {
uint64_t k;
float* A;
const float* B;
float beta;
float* C;
uint64_t ldc;
uint64_t b_block_cols;
#ifdef FBGEMM_ENABLE_KLEIDIAI
uint64_t lda;
#else
uint64_t b_block_size;
#endif
};

template <typename T>
using funcptr_t = void (*)(GemmParams<T>*);
template <typename T>
Expand Down Expand Up @@ -175,7 +192,9 @@ void cblas_gemm_compute(
assert(kernel_nrows * kb < static_cast<int64_t>(scratchpad->size()));
if (m != 1) {
#ifdef FBGEMM_ENABLE_KLEIDIAI
if constexpr (std::is_same<T, float16>::value) {
if constexpr (
std::is_same<T, float16>::value ||
std::is_same<T, float>::value) {
gp.A = const_cast<float*>(&A[m2 * k + k_ind]);
} else {
#endif
Expand All @@ -201,7 +220,9 @@ void cblas_gemm_compute(
gp.ldc = ldc * sizeof(C[0]);
gp.b_block_cols = nbcol;
#ifdef FBGEMM_ENABLE_KLEIDIAI
if constexpr (std::is_same<T, float16>::value) {
if constexpr (
std::is_same<T, float16>::value ||
std::is_same<T, float>::value) {
gp.lda = k * sizeof(A[0]);
} else {
#endif
Expand All @@ -218,7 +239,9 @@ void cblas_gemm_compute(
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
#ifdef FBGEMM_USE_REF_KERNEL
if constexpr (std::is_same<T, float16>::value) {
if constexpr (
std::is_same<T, float16>::value ||
std::is_same<T, float>::value) {
kernels[kernel_nrows](&gp);
} else {
ref_kernel<T>(kernel_nrows, &gp, C, m, n, simd_width);
Expand All @@ -238,7 +261,9 @@ void cblas_gemm_compute(
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
#ifdef FBGEMM_USE_REF_KERNEL
if constexpr (std::is_same<T, float16>::value) {
if constexpr (
std::is_same<T, float16>::value ||
std::is_same<T, float>::value) {
kernels[kernel_nrows](&gp);
} else {
ref_kernel(kernel_nrows, &gp, C, m, n, simd_width);
Expand Down Expand Up @@ -269,7 +294,9 @@ void cblas_gemm_compute(
gp.ldc = Bp.blockColSize() * sizeof(C[0]);
gp.b_block_cols = 1;
#ifdef FBGEMM_USE_REF_KERNEL
if constexpr (std::is_same<T, float16>::value) {
if constexpr (
std::is_same<T, float16>::value ||
std::is_same<T, float>::value) {
kernels[kernel_nrows](&gp);
} else {
ref_kernel<T>(
Expand Down
12 changes: 9 additions & 3 deletions include/fbgemm/FbgemmPackMatrixB.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* Copyright 2024-2025 Arm Limited and/or its affiliates
* <[email protected]> All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -63,7 +64,7 @@ class PackedGemmMatrixB {
const float* smat,
const int brow = 512)
: nrow_(nrow), ncol_(ncol), brow_(brow), kernel_ncol_blocks_(2) {
#if defined(FBGEMM_ENABLE_KLEIDIAI)
#ifdef FBGEMM_ENABLE_KLEIDIAI
if (std::is_same<T, float16>::value) {
kernel_ncol_blocks_ = 1;
}
Expand Down Expand Up @@ -92,7 +93,7 @@ class PackedGemmMatrixB {
nbcol_(nbcol),
size_(size),
kernel_ncol_blocks_(2) {
#if defined(FBGEMM_ENABLE_KLEIDIAI)
#ifdef FBGEMM_ENABLE_KLEIDIAI
if (std::is_same<T, float16>::value) {
kernel_ncol_blocks_ = 1;
}
Expand Down Expand Up @@ -120,6 +121,11 @@ class PackedGemmMatrixB {
nbcol_(nbcol),
size_(size),
kernel_ncol_blocks_(kernel_ncol_blocks) {
#ifdef FBGEMM_ENABLE_KLEIDIAI
if (std::is_same<T, float16>::value) {
kernel_ncol_blocks_ = 1;
}
#endif
pmat_ = static_cast<T*>(pmat);
packed_ = true;
pmat_passed_in = true;
Expand Down
29 changes: 29 additions & 0 deletions src/fp32/FbgemmFP32.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* Copyright 2025 Arm Limited and/or its affiliates <[email protected]>
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
Expand All @@ -11,9 +12,15 @@
#include <cmath>
#include <utility>

#ifndef __aarch64__
#include "./FbgemmFP32UKernelsAvx2.h"
#include "./FbgemmFP32UKernelsAvx512.h"
#include "./FbgemmFP32UKernelsAvx512_256.h"
#else
#ifdef FBGEMM_ENABLE_KLEIDIAI
#include "./KleidiAIFP32UKernelsNeon.h"
#endif
#endif
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmFPCommon.h"

Expand Down Expand Up @@ -80,6 +87,19 @@ constexpr kernel_array_t<float> kernel_f32_avx512_256 = {
nullptr};
#endif

#ifdef __aarch64__
#ifdef FBGEMM_ENABLE_KLEIDIAI
constexpr kernel_array_t<float> kernel_fp32_neon = {
nullptr,
kleidiai::gemmkernel_1x2_Neon_fp32_fA0fB0fC0,
kleidiai::gemmkernel_2x2_Neon_fp32_fA0fB0fC0,
kleidiai::gemmkernel_3x2_Neon_fp32_fA0fB0fC0,
kleidiai::gemmkernel_4x2_Neon_fp32_fA0fB0fC0,
kleidiai::gemmkernel_5x2_Neon_fp32_fA0fB0fC0,
kleidiai::gemmkernel_6x2_Neon_fp32_fA0fB0fC0,
};
#endif
#endif
} // namespace

template <>
Expand All @@ -90,9 +110,18 @@ const isa_descriptor<float>& getIsaHandlers(inst_set_t isa, float) {
std::make_tuple(kernel_f32_avx512, partition_avx512);
static isa_descriptor<float> avx512_256_descriptor =
std::make_tuple(kernel_f32_avx512_256, partition_avx512);
#ifdef __aarch64__
#ifdef FBGEMM_ENABLE_KLEIDIAI
static isa_descriptor<float> neon_descriptor =
std::make_tuple(kernel_fp32_neon, partition_sve128);
#endif
#endif

switch (isa) {
case inst_set_t::sve:
#ifdef FBGEMM_ENABLE_KLEIDIAI
return neon_descriptor;
#endif
case inst_set_t::anyarch:
case inst_set_t::avx2:
return avx2_descriptor;
Expand Down
Loading