[Perf] [CPU] eliminate redundant memory access in group query attention #13319

Open · wants to merge 3 commits into master
6 changes: 5 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu.c
@@ -2785,7 +2785,11 @@ struct ggml_cplan ggml_graph_plan(
                     const int64_t ne10 = node->src[1]->ne[0]; // DK
                     const int64_t ne20 = node->src[2]->ne[0]; // DV
 
-                    cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread)
+                    const int64_t ne02 = node->src[0]->ne[2]; // n_head
+                    const int64_t ne12 = node->src[1]->ne[2]; // n_kv_head
+                    const int64_t n_gqa = ne02/ne12;
+
+                    cur = sizeof(float)*n_gqa*(1*ne10 + 2*ne20)*n_tasks; // ngqa * (1x head size K + 2x head size V) (per thread)
                 } break;
             case GGML_OP_FLASH_ATTN_BACK:
                 {
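Note on the sizing above: the workspace is per thread, and it now scales with the GQA ratio, because each thread keeps one DK-sized Q buffer and two DV-sized accumulators per query head in its group. As a worked example with hypothetical sizes (not taken from this PR), DK = ne10 = 128, DV = ne20 = 128, n_gqa = 4, n_tasks = 8:

    old: sizeof(float) * (1*128 + 2*128) * 8     = 12288 bytes
    new: sizeof(float) * 4 * (1*128 + 2*128) * 8 = 49152 bytes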
251 changes: 145 additions & 106 deletions ggml/src/ggml-cpu/ops.cpp
@@ -6854,7 +6854,82 @@ void ggml_compute_forward_argsort(
 }
 
 // ggml_compute_forward_flash_attn_ext
+static inline void ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+        const ggml_fp16_t *Q,
+        const char *K,
+        const char *V,
+        const int64_t DK,
+        const int64_t DV,
+        const float mask_value,
+        const float scale,
+        const float logit_softcap,
+        const enum ggml_type v_type,
+        ggml_vec_dot_t const kq_vec_dot,
+        ggml_to_float_t const v_to_float,
+        ggml_fp16_t *VKQ16,
+        float *VKQ32,
+        float *V32,
+        float *sum,
+        float *max_kq_value) {
+    float s; // KQ value
+    kq_vec_dot(DK, &s, 0, K, 0, Q, 0, 1);
+
+    s = s*scale; // scale KQ value
+
+    if (logit_softcap != 0.0f) {
+        s = logit_softcap*tanhf(s);
+    }
+
+    s += mask_value; // apply mask
+
+    float M = *max_kq_value;
+    const float Mold = M;
+
+    float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+    float vs = 1.0f; // post-softmax KQ value, expf(s - M)
+
+    if (v_type == GGML_TYPE_F16) {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f16(DV, VKQ16, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) V, vs);
+    } else {
+        if (s > M) {
+            // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+            M = s;
+            ms = expf(Mold - M);
+
+            // V = V*expf(Mold - M)
+            ggml_vec_scale_f32(DV, VKQ32, ms);
+        } else {
+            // no new maximum, ms == 1.0f, vs != 1.0f
+            vs = expf(s - M);
+        }
+
+        // V += v*expf(s - M)
+        if (v_to_float) {
+            v_to_float(V, V32, DV);
+            ggml_vec_mad_f32(DV, VKQ32, V32, vs);
+        } else {
+            // V is F32
+            ggml_vec_mad_f32(DV, VKQ32, (const float *) V, vs);
+        }
+    }
+
+    float S = *sum;
+    S = S*ms + vs; // scale and increment sum with partial sum
+
+    *sum = S;
+    *max_kq_value = M;
+}
+
+#define GGML_FLASH_ATTN_EXT_MAX_GQA 16
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
         const ggml_tensor * q,
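For reference, the per-column update that ggml_compute_forward_flash_attn_ext_f16_one_QKV factors out is the streaming-softmax recurrence from the paper cited in the loop below (https://arxiv.org/pdf/2112.05682.pdf). Writing $s$ for the masked, scaled KQ score of the current K/V column $v$, and $(M, S, V)$ for the running maximum, softmax denominator, and unnormalized output:

$$M' = \max(M, s), \qquad S' = S\,e^{M - M'} + e^{s - M'}, \qquad V' = V\,e^{M - M'} + v\,e^{s - M'}$$

so that $ms = e^{M - M'}$ and $vs = e^{s - M'}$ correspond to the variables of the same names in the code.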
@@ -6907,16 +6982,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int64_t rv3 = neq3/nev3;
 
     // parallelize by q rows using ggml_vec_dot_f32
+    const uint32_t n_head = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 
-    // total rows in q
-    const int nr = neq1*neq2*neq3;
+    const uint32_t n_kv_head = nek2;
+    const int n_gqa = n_head / n_kv_head;
+    GGML_ASSERT(n_gqa <= GGML_FLASH_ATTN_EXT_MAX_GQA);
 
-    // rows per thread
-    const int dr = (nr + nth - 1)/nth;
+    // total groups in q
+    const int ng = neq1*neq2*neq3/n_gqa;
 
-    // row range for this thread
-    const int ir0 = dr*ith;
-    const int ir1 = MIN(ir0 + dr, nr);
+    // groups per thread
+    const int dg = (ng + nth - 1)/nth;
+
+    // group range for this thread
+    const int ig0 = dg*ith;
+    const int ig1 = MIN(ig0 + dg, ng);
 
     float scale = 1.0f;
     float max_bias = 0.0f;
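To make the new partitioning concrete (hypothetical shapes, not from this PR): for a single-token decode with neq1 = 1, neq3 = 1, n_head = 32, and n_kv_head = 8, we get n_gqa = 4 and ng = 1*32*1/4 = 8 groups. With nth = 4 threads, dg = (8 + 3)/4 = 2, so thread 0 processes groups [0, 2), thread 1 groups [2, 4), and so on. Each group covers the 4 consecutive query heads that share one KV head, where the old code scheduled all 32 query rows independently (nr = 32, dr = 8 rows per thread).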
@@ -6930,9 +7011,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         scale /= logit_softcap;
     }
 
-    const uint32_t n_head = neq2;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
-
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -6944,28 +7022,42 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT((                            q_to_vec_dot) && "fattn: unsupported K-type");
     GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float  ) && "fattn: unsupported V-type");
 
-    // loop over n_batch and n_head
-    for (int ir = ir0; ir < ir1; ++ir) {
+    float         S[GGML_FLASH_ATTN_EXT_MAX_GQA];     // sum
+    float         M[GGML_FLASH_ATTN_EXT_MAX_GQA];     // maximum KQ value
+    float *       VKQ32[GGML_FLASH_ATTN_EXT_MAX_GQA]; // FP32 VKQ accumulator
+    float *       V32  [GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP32 V buffer
+    ggml_fp16_t * VKQ16[GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) FP16 VKQ accumulator
+    ggml_fp16_t * Q_q  [GGML_FLASH_ATTN_EXT_MAX_GQA]; // (temporary) buffer for Q converted to quantized/FP16
+    float         slope[GGML_FLASH_ATTN_EXT_MAX_GQA];
+
+    for (int ig = ig0; ig < ig1; ++ig) {
+        const int group_index = ig % n_kv_head;
+        const int batch_index = ig / n_kv_head;
+
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        const uint32_t h = iq2; // head index
-        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
-
-        float S = 0.0f;      // sum
-        float M = -INFINITY; // maximum KQ value
-
-        float       * VKQ32 = (float       *) params->wdata + ith*(1*DK + 2*DV + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
-        float       * V32   =                 (VKQ32 + 1*DV); // (temporary) FP32 V buffer
-        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*DV); // (temporary) FP16 VKQ accumulator
-        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*DV); // (temporary) buffer for Q converted to quantized/FP16
-
-        if (v->type == GGML_TYPE_F16) {
-            memset(VKQ16, 0, DV*sizeof(ggml_fp16_t));
-        } else {
-            memset(VKQ32, 0, DV*sizeof(float));
-        }
+        const int iq3 = 0;
+        const int iq2 = group_index * n_gqa; // start head index
+        const int iq1 = batch_index;
+
+        const int single_buffer_size = 1*DK + 2*DV;
+        for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+            S[i_gqa] = 0.0f;
+            M[i_gqa] = -INFINITY;
+
+            VKQ32[i_gqa] = (float *) params->wdata + ith*(single_buffer_size*n_gqa + CACHE_LINE_SIZE_F32) + single_buffer_size*i_gqa;
+            V32  [i_gqa] = (VKQ32[i_gqa] + 1*DV);
+            VKQ16[i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 1*DV);
+            Q_q  [i_gqa] = (ggml_fp16_t *) (VKQ32[i_gqa] + 2*DV);
+
+            if (v->type == GGML_TYPE_F16) {
+                memset(VKQ16[i_gqa], 0, DV*sizeof(ggml_fp16_t));
+            } else {
+                memset(VKQ32[i_gqa], 0, DV*sizeof(float));
+            }
+
+            const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + (iq2 + i_gqa)*nbq2 + iq3*nbq3));
+            q_to_vec_dot(pq, Q_q[i_gqa], DK);
+
+            const uint32_t h = iq2 + i_gqa;
+            slope[i_gqa] = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
@@ -6978,99 +7070,46 @@
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
-        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
-        q_to_vec_dot(pq, Q_q, DK);
-
         // online softmax / attention
         // loop over n_kv and n_head_kv
        // ref: https://arxiv.org/pdf/2112.05682.pdf
        for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
-            if (mv == -INFINITY) {
+            const float mp_value_base = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mp_value_base == -INFINITY) {
                 continue;
             }
 
-            float s; // KQ value
-
+            const char * v_data = (const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3);
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
-            kq_vec_dot(DK, &s, 0, k_data, 0, Q_q, 0, 1);
-
-            s = s*scale; // scale KQ value
-
-            if (logit_softcap != 0.0f) {
-                s = logit_softcap*tanhf(s);
-            }
-
-            s += mv; // apply mask
-
-            const float Mold = M;
-
-            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
-            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
-
-            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
-
-            if (v->type == GGML_TYPE_F16) {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f16(DV, VKQ16, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                ggml_vec_mad_f16(DV, VKQ16, (const ggml_fp16_t *) v_data, vs);
-            } else {
-                if (s > M) {
-                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
-                    M = s;
-                    ms = expf(Mold - M);
-
-                    // V = V*expf(Mold - M)
-                    ggml_vec_scale_f32(DV, VKQ32, ms);
-                } else {
-                    // no new maximum, ms == 1.0f, vs != 1.0f
-                    vs = expf(s - M);
-                }
-
-                // V += v*expf(s - M)
-                if (v_to_float) {
-                    v_to_float(v_data, V32, DV);
-                    ggml_vec_mad_f32(DV, VKQ32, V32, vs);
-                } else {
-                    // V is F32
-                    ggml_vec_mad_f32(DV, VKQ32, (const float *) v_data, vs);
-                }
-            }
-
-            S = S*ms + vs; // scale and increment sum with partial sum
+            for (int i_gqa = 0; i_gqa < n_gqa; ++i_gqa) {
+                const float mv = mp_value_base * slope[i_gqa];
+                ggml_compute_forward_flash_attn_ext_f16_one_QKV(
+                    Q_q[i_gqa], k_data, v_data, DK, DV, mv, scale, logit_softcap, v->type,
+                    kq_vec_dot, v_to_float, VKQ16[i_gqa], VKQ32[i_gqa], V32[i_gqa], S+i_gqa, M+i_gqa);
+            }
         }
 
-        if (v->type == GGML_TYPE_F16) {
-            for (int64_t d = 0; d < DV; ++d) {
-                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
-            }
-        }
-
-        // V /= S
-        const float S_inv = 1.0f/S;
-        ggml_vec_scale_f32(DV, VKQ32, S_inv);
-
-        // dst indices
-        const int i1 = iq1;
-        const int i2 = iq2;
-        const int i3 = iq3;
-
-        // original
-        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
-
-        // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
+        for (int i = 0; i < n_gqa; ++i) {
+            if (v->type == GGML_TYPE_F16) {
+                for (int64_t d = 0; d < DV; ++d) {
+                    VKQ32[i][d] = GGML_FP16_TO_FP32(VKQ16[i][d]);
+                }
+            }
+
+            // V /= S
+            const float S_inv = 1.0f/S[i];
+            ggml_vec_scale_f32(DV, VKQ32[i], S_inv);
+
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2 + i;
+            const int i3 = iq3;
+
+            // original
+            //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+            // permute(0, 2, 1, 3)
+            memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32[i], nb1);
+        }
     }
 }
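The payoff of the restructuring above is the access pattern in the inner loop: each K/V column is now loaded once and reused across all n_gqa query heads that share it, instead of being re-read for every head. A minimal standalone sketch of that pattern (hypothetical sizes, plain C; dot stands in for kq_vec_dot and is not a ggml function):

    #include <stdio.h>

    #define DK    4  // head size (hypothetical)
    #define N_GQA 2  // query heads per KV head (hypothetical)
    #define N_KV  3  // K/V columns (hypothetical)

    static float dot(const float *a, const float *b, int n) {
        float s = 0.0f;
        for (int i = 0; i < n; ++i) s += a[i]*b[i];
        return s;
    }

    int main(void) {
        const float K[N_KV][DK]  = {{1,0,0,0},{0,1,0,0},{0,0,1,0}}; // one KV head
        const float Q[N_GQA][DK] = {{1,2,3,4},{4,3,2,1}};           // heads sharing it

        // before: each q head swept all of K, so K was read N_GQA times
        // after:  sweep K once, apply every q head to the resident column
        for (int ic = 0; ic < N_KV; ++ic) {
            const float *k_data = K[ic];                // column loaded once
            for (int i_gqa = 0; i_gqa < N_GQA; ++i_gqa) { // reused per head
                printf("s[head %d][col %d] = %.1f\n", i_gqa, ic, dot(Q[i_gqa], k_data, DK));
            }
        }
        return 0;
    }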
