diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
index ade0773dad8..d51537f7d04 100644
--- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh
+++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -955,22 +955,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
             (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
     }
-    for (; kb0 < kb0_stop-1; ++kb0) {
-        constexpr bool last_iter = false;
-        constexpr bool oob_check = false;
-        constexpr int k_VKQ_sup = nbatch_fa;
-        flash_attn_ext_f16_iter
-            (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
-             KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
-    }
     // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
     if constexpr (ncols2 == 1) {
-        if (ne11 % nbatch_fa == 0) {
-            constexpr bool last_iter = true;
-            constexpr bool oob_check = false;
+        constexpr bool oob_check = true;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
             constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
+                (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+    } else {
+        constexpr bool oob_check = false;
+        for (; kb0 < kb0_stop-1; ++kb0) {
+            constexpr bool last_iter = false;
+            constexpr int k_VKQ_sup = nbatch_fa;
             flash_attn_ext_f16_iter
@@ -989,9 +988,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
                 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
         }
-    } else {
         constexpr bool last_iter = true;
-        constexpr bool oob_check = false;
         constexpr int k_VKQ_sup = nbatch_fa;
         flash_attn_ext_f16_iter
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];
     float max_bias = 0.0f;
     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
-    const bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    // Edge cases like no mask, ALiBi, unpadded K/V, or misaligned addresses for large data transfers
+    // are put into the template specialization without GQA optimizations.
+    bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                use_gqa_opt = false;
+                break;
+            }
+        }
+    }
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
     const int gqa_ratio = Q->ne[2] / K->ne[2];
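Note on the kernel-side hunks above: as the visible lines show, the per-tile loop now sits inside both branches of if constexpr (ncols2 == 1), with oob_check fixed to true in the ncols2 == 1 branch and to false otherwise, so the bounds check stays a compile-time template parameter of flash_attn_ext_f16_iter and the fast path is compiled with the check removed. A minimal stand-alone sketch of that pattern follows; the names sum_tile, tile_size, and n_valid are illustrative only and do not appear in the kernel.

// Toy illustration of keeping a bounds check as a compile-time template
// parameter: the checked loop is only instantiated for the unpadded case,
// while the padded case compiles with the check eliminated.
#include <cstdio>

template <bool oob_check>
static float sum_tile(const float * data, int tile_size, int n_valid) {
    float sum = 0.0f;
    for (int i = 0; i < tile_size; ++i) {
        if (oob_check && i >= n_valid) { // trivially false when oob_check == false
            break;
        }
        sum += data[i];
    }
    return sum;
}

int main() {
    const float data[6] = {1, 2, 3, 4, 5, 6};
    const int tile_size = 4;
    const int n = 6; // total length, not a multiple of tile_size

    // Full tiles can use the unchecked instantiation ...
    float s = sum_tile<false>(data, tile_size, tile_size);
    // ... the final, partial tile uses the checked one.
    s += sum_tile<true>(data + tile_size, tile_size, n - tile_size);

    std::printf("%g\n", s); // 21
    return 0;
}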
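The host-side hunk tightens use_gqa_opt: in addition to requiring a mask, max_bias == 0.0f (no ALiBi), and K->ne[1] padded to FATTN_KQ_STRIDE, every byte stride nb[1..GGML_MAX_DIMS-1] of Q, K, V, and the mask must now be a multiple of 16; otherwise the case is routed to the specialization without GQA optimizations, since per the added comment misaligned addresses break its large data transfers. A self-contained sketch of that stride test follows; tensor_strides, MAX_DIMS, and the example numbers are stand-ins for ggml_tensor, GGML_MAX_DIMS, and real cache shapes, not part of the patch.

// Stand-alone sketch of the stride-alignment test. tensor_strides is a
// simplified stand-in for ggml_tensor (only the byte strides nb[] matter);
// in ggml, nb[0] is the element size and nb[1..3] are row/plane/batch strides.
#include <cstddef>
#include <cstdio>

constexpr int MAX_DIMS = 4; // mirrors GGML_MAX_DIMS

struct tensor_strides {
    size_t nb[MAX_DIMS];
};

// True if every higher-dimension stride is a multiple of 16 bytes, so each
// row/plane/batch keeps the 16-byte alignment of the base allocation.
static bool strides_16B_aligned(const tensor_strides & t) {
    for (int i = 1; i < MAX_DIMS; ++i) {
        if (t.nb[i] % 16 != 0) {
            return false;
        }
    }
    return true;
}

int main() {
    // F16 rows of 128 values: nb[1] = 256 bytes -> aligned, GQA opt allowed.
    const tensor_strides padded   = {{2, 256, 256*32, 256*32*8}};
    // F16 rows of 125 values: nb[1] = 250 bytes, 250 % 16 != 0 -> GQA opt disabled.
    const tensor_strides unpadded = {{2, 250, 250*32, 250*32*8}};

    std::printf("padded:   %d\n", strides_16B_aligned(padded));   // 1
    std::printf("unpadded: %d\n", strides_16B_aligned(unpadded)); // 0
    return 0;
}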