ggml/src/ggml-cuda/fattn-mma-f16.cuh (37 changes: 17 additions & 20 deletions)
@@ -955,43 +955,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
     }

-    for (; kb0 < kb0_stop-1; ++kb0) {
-        constexpr bool last_iter = false;
-        constexpr bool oob_check = false;
-        constexpr int k_VKQ_sup = nbatch_fa;
-        flash_attn_ext_f16_iter
-            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
-            T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
-            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
-            ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-            KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-    }
     // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
     if constexpr (ncols2 == 1) {
-        if (ne11 % nbatch_fa == 0) {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = false;
+        constexpr bool oob_check = true;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
             constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-        } else {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = true;
-            const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        }
+        constexpr bool last_iter = true;
+        const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+            T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+            ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+            KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+    } else {
+        constexpr bool oob_check = false;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
-    } else {
         constexpr bool last_iter = true;
-        constexpr bool oob_check = false;
         constexpr int k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
             <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
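To summarize the control-flow change above: the oob_check template flag is now fixed once per ncols2 branch, the main loop always consumes full tiles of nbatch_fa K/V positions, and in the bounds-checked branch only the final iteration clamps the valid length to ne11 - kb0*nbatch_fa. Below is a minimal host-side C++ sketch of that dispatch pattern, not the actual kernel; process_tiles, iter_tile and the scalar sum are illustrative stand-ins, and kb0_stop is computed as a ceiling division only for the sketch.

// Sketch: a compile-time oob_check flag selects a guarded or unguarded tile loop;
// the last tile clamps the number of valid elements to ne11 - kb0*nbatch_fa.
#include <cstdio>
#include <vector>

constexpr int nbatch_fa = 4; // tile size along the K/V sequence

template <bool oob_check>
void iter_tile(const std::vector<float> & v, int kb0, int k_sup, float & acc) {
    // Process one tile of nbatch_fa elements; with oob_check enabled only the
    // first k_sup elements of the tile are valid.
    for (int i = 0; i < nbatch_fa; ++i) {
        if (oob_check && i >= k_sup) {
            break;
        }
        acc += v[kb0*nbatch_fa + i];
    }
}

template <bool oob_check>
float process_tiles(const std::vector<float> & v, int ne11) {
    float acc = 0.0f;
    const int kb0_stop = (ne11 + nbatch_fa - 1) / nbatch_fa; // ceiling division, sketch only
    int kb0 = 0;
    for (; kb0 < kb0_stop-1; ++kb0) {
        iter_tile<oob_check>(v, kb0, nbatch_fa, acc); // full tiles, guard never triggers
    }
    const int k_sup = oob_check ? ne11 - kb0*nbatch_fa : nbatch_fa; // clamp only the last tile
    iter_tile<oob_check>(v, kb0, k_sup, acc);
    return acc;
}

int main() {
    std::vector<float> v(10, 1.0f); // ne11 = 10 is not a multiple of nbatch_fa
    printf("guarded sum   = %g\n", process_tiles<true>(v, (int) v.size()));
    std::vector<float> w(8, 1.0f);  // padded case: ne11 % nbatch_fa == 0
    printf("unguarded sum = %g\n", process_tiles<false>(w, (int) w.size()));
    return 0;
}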
ggml/src/ggml-cuda/fattn.cu (16 changes: 15 additions & 1 deletion)
@@ -36,12 +36,26 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
     const ggml_tensor * mask = dst->src[3];

     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

-    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+    // are put into the template specialization without GQA optimizations.
+    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                use_gqa_opt = false;
+                break;
+            }
+        }
+    }

     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
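The loop added above disables the GQA-optimized path whenever any byte stride nb[1..GGML_MAX_DIMS-1] of Q, K, V or the mask is not a multiple of 16, in line with the comment about misaligned addresses for large data transfers; a null tensor (e.g. no mask) is simply skipped. A standalone sketch of the same check follows; the tensor_strides struct, MAX_DIMS constant and strides_are_16B_aligned helper are illustrative stand-ins for ggml_tensor, GGML_MAX_DIMS and the inline loop, not part of the ggml API.

// Sketch: return false if any higher-dimension byte stride is not 16-byte aligned.
#include <cstddef>
#include <cstdio>

constexpr size_t MAX_DIMS = 4; // mirrors GGML_MAX_DIMS

struct tensor_strides {
    size_t nb[MAX_DIMS]; // byte strides per dimension, nb[0] is the element stride
};

static bool strides_are_16B_aligned(const tensor_strides * t) {
    if (t == nullptr) {
        return true; // optional tensors (e.g. the mask) are skipped, as in the diff
    }
    for (size_t i = 1; i < MAX_DIMS; ++i) {
        if (t->nb[i] % 16 != 0) {
            return false;
        }
    }
    return true;
}

int main() {
    const tensor_strides aligned   = {{2, 256, 16384, 16384}};
    const tensor_strides unaligned = {{2, 250, 16000, 16000}};
    printf("aligned:   %d\n", strides_are_16B_aligned(&aligned));   // 1
    printf("unaligned: %d\n", strides_are_16B_aligned(&unaligned)); // 0
    printf("mask=null: %d\n", strides_are_16B_aligned(nullptr));    // 1
    return 0;
}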