@@ -955,43 +955,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
955955 (K_h2 + int64_t (kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
956956 }
957957
958- for (; kb0 < kb0_stop-1 ; ++kb0) {
959- constexpr bool last_iter = false ;
960- constexpr bool oob_check = false ;
961- constexpr int k_VKQ_sup = nbatch_fa;
962- flash_attn_ext_f16_iter
963- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
964- T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
965- (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
966- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
967- KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
968- }
969958 // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
970959 if constexpr (ncols2 == 1 ) {
971- if (ne11 % nbatch_fa == 0 ) {
972- constexpr bool last_iter = true ;
973- constexpr bool oob_check = false ;
960+ constexpr bool oob_check = true ;
961+ for (; kb0 < kb0_stop- 1 ; ++kb0) {
962+ constexpr bool last_iter = false ;
974963 constexpr int k_VKQ_sup = nbatch_fa;
975964 flash_attn_ext_f16_iter
976965 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
977966 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
978967 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
979968 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
980969 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
981- } else {
982- constexpr bool last_iter = true ;
983- constexpr bool oob_check = true ;
984- const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
970+ }
971+ constexpr bool last_iter = true ;
972+ const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
973+ flash_attn_ext_f16_iter
974+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
975+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
976+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
977+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
978+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
979+ } else {
980+ constexpr bool oob_check = false ;
981+ for (; kb0 < kb0_stop-1 ; ++kb0) {
982+ constexpr bool last_iter = false ;
983+ constexpr int k_VKQ_sup = nbatch_fa;
985984 flash_attn_ext_f16_iter
986985 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
987986 T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
988987 (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
989988 ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
990989 KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
991990 }
992- } else {
993991 constexpr bool last_iter = true ;
994- constexpr bool oob_check = false ;
995992 constexpr int k_VKQ_sup = nbatch_fa;
996993 flash_attn_ext_f16_iter
997994 <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
0 commit comments