
Commit 8b836a9

CUDA: fix unpadded strides in MMA FA kernel
1 parent 6b82eb7 commit 8b836a9

File tree

2 files changed: +29 −21 lines

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 17 additions & 20 deletions
@@ -955,43 +955,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
     }
 
-    for (; kb0 < kb0_stop-1; ++kb0) {
-        constexpr bool last_iter = false;
-        constexpr bool oob_check = false;
-        constexpr int k_VKQ_sup = nbatch_fa;
-        flash_attn_ext_f16_iter
-            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
-             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
-            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-    }
     // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
     if constexpr (ncols2 == 1) {
-        if (ne11 % nbatch_fa == 0) {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = false;
+        constexpr bool oob_check = true;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
             constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                  KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-        } else {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = true;
-            const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        }
+        constexpr bool last_iter = true;
+        const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+        flash_attn_ext_f16_iter
+            <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+             T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+    } else {
+        constexpr bool oob_check = false;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
                 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
                  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
                 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
                  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                  KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
-    } else {
         constexpr bool last_iter = true;
-        constexpr bool oob_check = false;
         constexpr int k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
             <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
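The change above restructures control flow rather than math: previously the main tile loop always ran without bounds checks and only the final iteration branched on whether ne11 is a multiple of nbatch_fa; now the ncols2 == 1 path hoists oob_check = true over the whole loop, so every iteration tolerates an unpadded K/V length, while the other path keeps the fully unchecked fast loop. The following is a minimal host-side C++ sketch of that restructuring, not the actual CUDA kernel: process_batch and process_all are hypothetical stand-ins for flash_attn_ext_f16_iter and the surrounding tile loop, and the last_iter/fixup machinery is omitted.

// Sketch only: illustrates hoisting a compile-time oob_check flag over the tile loop
// so an unpadded total length (ne11 not a multiple of nbatch_fa) is handled safely.
#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for flash_attn_ext_f16_iter: processes one batch of nbatch_fa elements.
// k_sup is the number of valid elements in this batch; with oob_check the work is
// clamped to k_sup, without it the full batch is assumed to be valid.
template <int nbatch_fa, bool oob_check>
static float process_batch(const std::vector<float> & data, int kb0, int k_sup) {
    float sum = 0.0f;
    const int limit = oob_check ? std::min(k_sup, nbatch_fa) : nbatch_fa;
    for (int i = 0; i < limit; ++i) {
        sum += data[kb0*nbatch_fa + i];
    }
    return sum;
}

// Mirrors the new structure: when the length may be unpadded, every iteration runs with
// oob_check = true and the last iteration receives the remaining element count; otherwise
// the fully unchecked loop is kept.
template <int nbatch_fa, bool may_be_unpadded>
static float process_all(const std::vector<float> & data, const int ne11) {
    const int kb0_stop = (ne11 + nbatch_fa - 1) / nbatch_fa;
    float sum = 0.0f;
    int kb0 = 0;
    if constexpr (may_be_unpadded) {
        constexpr bool oob_check = true;
        for (; kb0 < kb0_stop - 1; ++kb0) {
            sum += process_batch<nbatch_fa, oob_check>(data, kb0, nbatch_fa);
        }
        sum += process_batch<nbatch_fa, oob_check>(data, kb0, ne11 - kb0*nbatch_fa);
    } else {
        constexpr bool oob_check = false;
        for (; kb0 < kb0_stop; ++kb0) {
            sum += process_batch<nbatch_fa, oob_check>(data, kb0, nbatch_fa);
        }
    }
    return sum;
}

int main() {
    const std::vector<float> data(70, 1.0f);               // ne11 = 70, not a multiple of 32
    std::printf("%f\n", process_all<32, true>(data, 70));  // prints 70.000000
}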

ggml/src/ggml-cuda/fattn.cu

Lines changed: 12 additions & 1 deletion
@@ -36,12 +36,23 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
 
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
 
-    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+    // are put into the template specialization without GQA optimizations.
+    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                use_gqa_opt = false;
+                break;
+            }
+        }
+    }
 
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
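The host-side change disables the GQA-optimized template specialization whenever any higher-dimension stride of Q, K, V or the mask is not a multiple of 16 bytes, presumably because that path relies on large aligned transfers, as the added comment about "misaligned addresses for large data transfers" suggests. Below is a simplified standalone sketch of that check; SimpleTensor, MAX_DIMS and strides_aligned_16 are illustrative stand-ins, not the real ggml_tensor/GGML_MAX_DIMS definitions or any function in the codebase.

// Sketch only: returns true when every non-leading stride of every tensor is 16-byte
// aligned, mirroring the loop over {Q, K, V, mask} in the commit above.
#include <cstddef>
#include <cstdio>
#include <initializer_list>

constexpr std::size_t MAX_DIMS = 4;  // stand-in for GGML_MAX_DIMS

struct SimpleTensor {                // stand-in for ggml_tensor
    std::size_t nb[MAX_DIMS];        // strides in bytes, nb[0] = element size
};

static bool strides_aligned_16(std::initializer_list<const SimpleTensor *> tensors) {
    for (const SimpleTensor * t : tensors) {
        for (std::size_t i = 1; i < MAX_DIMS; ++i) {
            if (t->nb[i] % 16 != 0) {
                return false;        // an unpadded/misaligned stride rules out the fast path
            }
        }
    }
    return true;
}

int main() {
    const SimpleTensor padded   = {{2, 256, 8192, 65536}}; // all strides multiples of 16
    const SimpleTensor unpadded = {{2, 200, 8200, 65600}}; // nb[1] = 200 is not
    std::printf("padded:   %d\n", strides_aligned_16({&padded}));            // 1
    std::printf("unpadded: %d\n", strides_aligned_16({&padded, &unpadded})); // 0
}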
