backend-cpu: add online flow for aarch64 Q4_0 GEMV/GEMM kernels
chaxu01 committed Oct 17, 2024
1 parent f010b77 commit c9c1afb
Showing 10 changed files with 261 additions and 90 deletions.
7 changes: 7 additions & 0 deletions common/arg.cpp
@@ -1993,6 +1993,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_timestamps(common_log_main(), true);
}
).set_env("LLAMA_LOG_TIMESTAMPS"));
add_opt(common_arg(
{"-rtrp", "--runtime-repack"},
string_format("Allow runtime requantization and repacking of Q4_0 to enable optimized GEMM and GEMV kernels (default: %d)", params.runtime_repack),
[](common_params & params) {
params.runtime_repack = true;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));

return ctx_arg;
}
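
For context, a minimal caller-side sketch of how the new flag propagates (hypothetical usage, not part of the commit; it relies only on the common helpers changed in common.cpp below, and omits model loading and error handling):

// --runtime-repack / -rtrp simply sets common_params::runtime_repack; the two
// conversion helpers below then forward it and force mmap off for the model.
common_params params;
params.runtime_repack = true;

llama_model_params   mparams = common_model_params_to_llama(params);   // use_mmap becomes false
llama_context_params cparams = common_context_params_to_llama(params); // cparams.runtime_repack = true
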
3 changes: 2 additions & 1 deletion common/common.cpp
@@ -996,7 +996,7 @@ struct llama_model_params common_model_params_to_llama(const common_params & par
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mmap = params.use_mmap && !params.runtime_repack;
mparams.use_mlock = params.use_mlock;
mparams.check_tensors = params.check_tensors;
if (params.kv_overrides.empty()) {
@@ -1066,6 +1066,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.runtime_repack = params.runtime_repack;

if (params.reranking) {
cparams.embeddings = true;
2 changes: 2 additions & 0 deletions common/common.h
@@ -265,6 +265,8 @@ struct common_params {
bool warmup = true; // warmup run
bool check_tensors = false; // validate tensor data

bool runtime_repack = false; // runtime repack weights for optimized kernels

std::string cache_type_k = "f16"; // KV cache data type for the K
std::string cache_type_v = "f16"; // KV cache data type for the V

196 changes: 112 additions & 84 deletions examples/llama-bench/llama-bench.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ggml/include/ggml-backend.h
@@ -310,6 +310,7 @@ extern "C" {
GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
GGML_API void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack);

// Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);
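
For applications that drive the CPU backend directly through this header, a minimal sketch of the new setter in use (hypothetical usage; the thread count and the graph are placeholders):

// Enable online Q4_0 repacking on a CPU backend before computing a graph.
ggml_backend_t backend = ggml_backend_cpu_init();
ggml_backend_cpu_set_n_threads(backend, 4);           // existing setter
ggml_backend_cpu_set_runtime_repack(backend, true);   // setter added by this commit
// ... build a graph, then ggml_backend_graph_compute(backend, graph); ...
ggml_backend_free(backend);
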
99 changes: 99 additions & 0 deletions ggml/src/ggml-aarch64.c
@@ -3207,3 +3207,102 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * restrict s, size_t bs, const void *
}
}
}

static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(t->ne[0] % 8 == 0);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);

// Do in-place transformation. Allocate scratch buffer
size_t size = sizeof(block_q4_0x4) * t->ne[0] / QK4_0;
if (size > *psize) {
uint8_t *new_mem = realloc(*pmem, size);
if (!new_mem) {
return -1;
}
*pmem = new_mem;
*psize = size;
}
block_q4_0x4 *dst = (block_q4_0x4*) *pmem;
block_q4_0 *src = (block_q4_0*) t->data;
block_q4_0 dst_tmp[4];
int n = t->ne[0];
int nrow = t->ne[1]; // Number of rows
int nrows_interleaved = 4;
int nblocks = t->ne[0] / QK4_0;
for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
int cnt = 0;
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
dst_tmp[i] = src[x + i * nblocks];
}
dst[cnt++] = make_block_q4_0x4(dst_tmp, interleave_block, 0x88);
}
memcpy(src, dst, size);
src += cnt * 4;
}
return 0;
}

static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor *t, int interleave_block, uint8_t **pmem, size_t *psize) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_0);
GGML_ASSERT(t->ne[0] % 8 == 0);
GGML_ASSERT(interleave_block == 8);

// Do in-place transformation. Allocate scratch buffer
size_t size = sizeof(block_q4_0x8) * t->ne[0] / QK4_0;
if (size > *psize) {
uint8_t *new_mem = realloc(*pmem, size);
if (!new_mem) {
return -1;
}
*pmem = new_mem;
*psize = size;
}
block_q4_0x8 *dst = (block_q4_0x8*) *pmem;
block_q4_0 *src = (block_q4_0*) t->data;
block_q4_0 dst_tmp[8];
int n = t->ne[0];
int nrow = t->ne[1]; // Number of rows
int nrows_interleaved = 8;
int nblocks = t->ne[0] / QK4_0;
for (int b = 0; b < (nrow * n); b += nrows_interleaved * n) {
int cnt = 0;
for (int64_t x = 0; x < nblocks; x++) {
for (int i = 0; i < nrows_interleaved; i++ ) {
dst_tmp[i] = src[x + i * nblocks];
}
dst[cnt++] = make_block_q4_0x8(dst_tmp, interleave_block, 0x88);
}
memcpy(src, dst, size);
src += cnt * nrows_interleaved; // advance past the 8 rows just rewritten in place
}
return 0;
}

// Prepare for optimized kernels if applicable
void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize) {
UNUSED(cur);
UNUSED(pmem);
UNUSED(psize);

#if defined(__ARM_ARCH)
if (cur->type == GGML_TYPE_Q4_0) {
if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
if (repack_q4_0_to_q4_0_8_bl(cur, 8, pmem, psize) == 0) {
cur->type = GGML_TYPE_Q4_0_8_8;
}
}
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
if (repack_q4_0_to_q4_0_4_bl(cur, 8, pmem, psize) == 0) {
cur->type = GGML_TYPE_Q4_0_4_8;
}
}
else if (ggml_cpu_has_neon()) {
if (repack_q4_0_to_q4_0_4_bl(cur, 4, pmem, psize) == 0) {
cur->type = GGML_TYPE_Q4_0_4_4;
}
}
}
#endif
}
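
A note on the scratch sizing in the two repack helpers above: the buffer only ever holds one interleaved row group (nrows_interleaved rows of nblocks blocks each), i.e. nblocks * sizeof(block_q4_0xN) bytes, and it is grown once and then reused for every group and every tensor. A small worked example, assuming the standard QK4_0 of 32 and the block_q4_0x4 layout of 4 fp16 scales plus 4x16 bytes of packed nibbles; the 4096-column tensor is hypothetical:

// Scratch-buffer arithmetic for repack_q4_0_to_q4_0_4_bl (illustrative values only).
const int    ne0     = 4096;                 // example weight-row length (t->ne[0])
const int    qk      = 32;                   // QK4_0: elements per Q4_0 block
const int    nblocks = ne0 / qk;             // 128 blocks per row
const size_t x4_size = 4 * 2 + 4 * (qk / 2); // sizeof(block_q4_0x4): 8 + 64 = 72 bytes
const size_t scratch = nblocks * x4_size;    // one 4-row group: 128 * 72 = 9216 bytes
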
2 changes: 2 additions & 0 deletions ggml/src/ggml-aarch64.h
@@ -33,6 +33,8 @@ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

void ggml_prepare_optimal_kernel(struct ggml_tensor *cur, uint8_t **pmem, size_t *psize);

#ifdef __cplusplus
}
#endif
26 changes: 26 additions & 0 deletions ggml/src/ggml-backend.cpp
@@ -11,6 +11,7 @@
#include "ggml-backend-impl.h"
#include "ggml-alloc.h"
#include "ggml-impl.h"
#include "ggml-aarch64.h"

#include <assert.h>
#include <limits.h>
@@ -882,6 +883,10 @@ struct ggml_backend_cpu_context {
uint8_t * work_data;
size_t work_size;

bool runtime_repack;
uint8_t * scratch_memory;
size_t scratch_size;

ggml_abort_callback abort_callback;
void * abort_callback_data;
};
@@ -895,6 +900,7 @@ static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
static void ggml_backend_cpu_free(ggml_backend_t backend) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
delete[] cpu_ctx->work_data;
free(cpu_ctx->scratch_memory); // free the scratch memory allocated by the repacking code in ggml-aarch64.c
delete cpu_ctx;
delete backend;
}
@@ -952,6 +958,16 @@ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backe
static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;

if (cpu_ctx->runtime_repack) {
for (int i = 0; i < cgraph->n_nodes; i++) {
struct ggml_tensor * node = cgraph->nodes[i];
if (node->op == GGML_OP_MUL_MAT && node->src[0]->type == GGML_TYPE_Q4_0) {
// Prepare for optimized kernels if applicable.
ggml_prepare_optimal_kernel(node->src[0], &cpu_ctx->scratch_memory, &cpu_ctx->scratch_size);
}
}
}

struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);

if (cpu_ctx->work_size < cplan.work_size) {
@@ -1008,6 +1024,9 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ctx->work_size = 0;
ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL;
ctx->runtime_repack = false;
ctx->scratch_memory = NULL;
ctx->scratch_size = 0;

ggml_backend_t cpu_backend = new ggml_backend {
/* .guid = */ ggml_backend_cpu_guid(),
@@ -1055,6 +1074,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
ctx->abort_callback_data = abort_callback_data;
}

void ggml_backend_cpu_set_runtime_repack(ggml_backend_t backend_cpu, bool runtime_repack) {
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));

struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
ctx->runtime_repack = runtime_repack;
}

ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
GGML_ASSERT((uintptr_t)ptr % TENSOR_ALIGNMENT == 0 && "buffer pointer must be aligned");
return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), ggml_backend_cpu_buffer_from_ptr_i, ptr, size);
11 changes: 6 additions & 5 deletions include/llama.h
@@ -341,11 +341,12 @@ extern "C" {

// Keep the booleans together and at the end of the struct to avoid misalignment during copy-by-value.
// TODO: move at the end of the struct
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool logits_all; // the llama_decode() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead)
bool embeddings; // if true, extract embeddings (together with logits)
bool offload_kqv; // whether to offload the KQV ops (including the KV cache) to GPU
bool flash_attn; // whether to use flash attention [EXPERIMENTAL]
bool no_perf; // whether to measure performance timings
bool runtime_repack; // runtime repack weights for optimized kernels

// Abort callback
// if it returns true, execution of llama_decode() will be aborted
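
From the public API side, opting in amounts to setting the new context flag; a minimal sketch assuming the usual llama.h session setup (the model path is illustrative, and loading/error handling are omitted):

// Hypothetical usage: enable online Q4_0 repacking for a new context.
llama_model * model = llama_load_model_from_file("model-q4_0.gguf", llama_model_default_params());

llama_context_params cparams = llama_context_default_params();
cparams.runtime_repack = true;   // field added by this commit

llama_context * ctx = llama_new_context_with_model(model, cparams);
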
4 changes: 4 additions & 0 deletions src/llama.cpp
@@ -2574,6 +2574,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
bool runtime_repack;

enum llama_pooling_type pooling_type;

@@ -17107,6 +17108,7 @@ static void llama_graph_compute(
ggml_threadpool * threadpool) {
if (lctx.backend_cpu != nullptr) {
ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
ggml_backend_cpu_set_runtime_repack(lctx.backend_cpu, lctx.cparams.runtime_repack);
ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
}

@@ -19034,6 +19036,7 @@ struct llama_context_params llama_context_default_params() {
/*.offload_kqv =*/ true,
/*.flash_attn =*/ false,
/*.no_perf =*/ true,
/*.runtime_repack =*/ false,
/*.abort_callback =*/ nullptr,
/*.abort_callback_data =*/ nullptr,
};
@@ -19292,6 +19295,7 @@ struct llama_context * llama_new_context_with_model(
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.runtime_repack = params.runtime_repack;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
