
Commit d6e4163

FlashAttention: support qkvpacked and varlen (PaddlePaddle#63289)

* FlashAttention support qkvpacked and varlen
* fix codestyle
* fix codestyle
* FlashAttention kvReduceGQA performance optimization
* fix problem with Windows
* code clean
* update third_party/flashattn
* update errormsg and docs
* update api
* update doc
* update doctest
* update doc, test=document_fix
* update doc, test=document_fix
* Update python/paddle/nn/functional/flash_attention.py (co-authored by zachary sun <[email protected]>)
* Update python/paddle/nn/functional/flash_attention.py (co-authored by zachary sun <[email protected]>)
* update doc

Co-authored-by: zachary sun <[email protected]>
1 parent f692f02 · commit d6e4163

File tree — 12 files changed: +1580 −74 lines


paddle/phi/api/yaml/backward.yaml (+24)
@@ -882,6 +882,18 @@
     func : flash_attn_grad
     data_type: q
 
+- backward_op : flash_attn_qkvpacked_grad
+  forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false)
+  optional : attn_mask
+  output : Tensor(qkv_grad)
+  infer_meta :
+    func : FlashAttnQKVPackedGradInferMeta
+    param : [qkv]
+  kernel :
+    func : flash_attn_qkvpacked_grad
+    data_type: qkv
+
 - backward_op : flash_attn_unpadded_grad
   forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
@@ -894,6 +906,18 @@
     func : flash_attn_unpadded_grad
     data_type: q
 
+- backward_op : flash_attn_varlen_qkvpacked_grad
+  forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true)
+  optional : attn_mask
+  output : Tensor(qkv_grad)
+  infer_meta :
+    func : FlashAttnQKVPackedGradInferMeta
+    param : [qkv]
+  kernel :
+    func : flash_attn_varlen_qkvpacked_grad
+    data_type: qkv
+
 - backward_op : flash_attn_with_sparse_mask_grad
   forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
   args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
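Both new backward ops return a single qkv_grad in the same packed layout their forwards consume: the query-head groups plus one K slot and one V slot stacked on a single axis. A hedged sketch of how such a tensor could be assembled for GQA — this pack_qkv helper is illustrative and not part of the commit, and the exact ordering of query heads within groups is an assumption:

```python
import paddle

def pack_qkv(q, k, v):
    """Pack GQA projections into the qkvpacked layout these ops expect.

    q: [batch, seqlen, nheads,   headdim]
    k: [batch, seqlen, nheads_k, headdim]
    v: [batch, seqlen, nheads_k, headdim]
    returns: [batch, seqlen, nheads // nheads_k + 2, nheads_k, headdim]
    """
    b, s, nheads, d = q.shape
    nheads_k = k.shape[2]
    # Group query heads by the KV head they attend with, then append K and V slots.
    qg = q.reshape([b, s, nheads // nheads_k, nheads_k, d])
    return paddle.concat([qg, k.unsqueeze(2), v.unsqueeze(2)], axis=2)
```

Because FlashAttnQKVPackedGradInferMeta (further down in this commit) shares qkv's meta, the gradient that flows back has exactly this packed shape.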

paddle/phi/api/yaml/ops.yaml (+25)
@@ -1109,6 +1109,18 @@
   backward : flash_attn_grad
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
+- op : flash_attn_qkvpacked
+  args : (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
+  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset, attn_mask
+  infer_meta :
+    func : FlashAttnQKVPackedInferMeta
+    param : [qkv]
+  kernel :
+    func : flash_attn_qkvpacked
+    data_type : qkv
+  backward : flash_attn_qkvpacked_grad
+
 - op : flash_attn_unpadded
   args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
@@ -1122,6 +1134,19 @@
   intermediate : softmax_lse, seed_offset
   backward : flash_attn_unpadded_grad
 
+- op : flash_attn_varlen_qkvpacked
+  args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true)
+  output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
+  optional : fixed_seed_offset, attn_mask
+  infer_meta :
+    func : FlashAttnQKVPackedInferMeta
+    param : [qkv]
+  kernel :
+    func : flash_attn_varlen_qkvpacked
+    data_type : qkv
+  intermediate : softmax_lse, seed_offset
+  backward : flash_attn_varlen_qkvpacked_grad
+
 - op : flash_attn_with_sparse_mask
   args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "")
   output : Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
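The commit also exposes these ops through python/paddle/nn/functional/flash_attention.py (updated later in the same change). A minimal usage sketch follows; the wrapper names mirror the op names above, but the exact Python signatures, defaults, and return values are assumptions rather than text copied from the diff:

```python
import paddle
# Module updated by this commit; wrapper names assumed to mirror the op names.
import paddle.nn.functional.flash_attention as FA

b, s, nheads, nheads_k, d = 2, 128, 8, 2, 64

# Dense (padded-batch) packed attention.
qkv = paddle.randn([b, s, nheads // nheads_k + 2, nheads_k, d]).astype("float16")
out = FA.flash_attn_qkvpacked(qkv, dropout=0.0, causal=True)[0]  # assumed (out, softmax)

# Varlen variant: token-packed qkv plus cumulative sequence lengths.
seqlens = paddle.to_tensor([40, 88], dtype="int32")
cu_seqlens = paddle.concat([paddle.zeros([1], dtype="int32"), paddle.cumsum(seqlens)])
qkv_var = paddle.randn([int(seqlens.sum()), nheads // nheads_k + 2, nheads_k, d]
                       ).astype("float16")
out_var = FA.flash_attn_varlen_qkvpacked(
    qkv_var, cu_seqlens, cu_seqlens,
    max_seqlen_q=88, max_seqlen_k=88, scale=d ** -0.5, causal=True,
    varlen_padded=False,  # assumption: False means token-packed (unpadded) input
)[0]
```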

paddle/phi/infermeta/backward.cc (+6)
@@ -244,6 +244,12 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
   }
 }
 
+void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dqkv) {
+  if (dqkv) {
+    dqkv->share_meta(qkv);
+  }
+}
+
 void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
                                   const MetaTensor& out_grad,
                                   MetaTensor* x_grad,

paddle/phi/infermeta/backward.h (+2)
@@ -207,6 +207,8 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
                             MetaTensor* dk,
                             MetaTensor* dv);
 
+void FlashAttnQKVPackedGradInferMeta(const MetaTensor& qkv, MetaTensor* dq);
+
 void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
                                   const MetaTensor& out_grad,
                                   MetaTensor* x_grad,

paddle/phi/infermeta/ternary.cc (+28)
@@ -17,7 +17,10 @@ limitations under the License. */
 #include "glog/logging.h"
 
 #include "paddle/common/ddim.h"
+#include "paddle/common/errors.h"
 #include "paddle/common/layout.h"
+#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/enforce.h"
 #include "paddle/phi/kernels/funcs/common_shape.h"
 #include "paddle/phi/kernels/impl/box_coder.h"
 
@@ -433,6 +436,31 @@ void FlashAttnInferMeta(const MetaTensor& q,
     seed_offset->set_dims({2});
   }
 }
+void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
+                                 MetaTensor* out,
+                                 MetaTensor* softmax,
+                                 MetaTensor* softmax_lse,
+                                 MetaTensor* seed_offset) {
+  const auto& qkvdims = qkv.dims();
+  PADDLE_ENFORCE(qkvdims.size() == 4 || qkvdims.size() == 5,
+                 phi::errors::InvalidArgument(
+                     "qkv dims must be 4(unpadded) or 5(padded batch)"));
+  // qkv: [total_*, nheads/nheads_k + 2, nheads_k, headdim]
+  auto out_dims = DDim({qkvdims[0], (qkvdims[1] - 2) * qkvdims[2], qkvdims[3]});
+  if (qkvdims.size() == 5) {
+    // qkv: [batchsize, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
+    out_dims =
+        DDim{qkvdims[0], qkvdims[1], (qkvdims[2] - 2) * qkvdims[3], qkvdims[4]};
+  }
+  out->set_dims(out_dims);
+  out->set_dtype(qkv.dtype());
+  out->set_layout(qkv.layout());
+  softmax->set_dtype(qkv.dtype());
+  softmax_lse->set_dtype(qkv.dtype());
+  if (seed_offset) {
+    seed_offset->set_dtype(phi::DataType::INT64);
+  }
+}
 
 void ArangeTensorInferMeta(const MetaTensor& start,
                            const MetaTensor& end,
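The shape rule above is the heart of the packed layout: the third-from-last axis holds nheads/nheads_k query-head groups plus two extra slots for K and V, so the output head count is (groups − 2) * nheads_k. The same computation in plain Python, as an illustration only:

```python
# Sketch of FlashAttnQKVPackedInferMeta's shape rule (illustration, not Paddle API).
def qkvpacked_out_shape(qkv_shape):
    """Map a packed-QKV shape to the attention-output shape.

    4-D (varlen/unpadded): [total_tokens, nheads/nheads_k + 2, nheads_k, headdim]
    5-D (padded batch):    [batch, seqlen, nheads/nheads_k + 2, nheads_k, headdim]
    The two extra slots hold K and V; the rest are query-head groups.
    """
    if len(qkv_shape) == 4:
        total, groups, nheads_k, headdim = qkv_shape
        return [total, (groups - 2) * nheads_k, headdim]
    if len(qkv_shape) == 5:
        batch, seqlen, groups, nheads_k, headdim = qkv_shape
        return [batch, seqlen, (groups - 2) * nheads_k, headdim]
    raise ValueError("qkv dims must be 4 (unpadded) or 5 (padded batch)")

# 8 query heads sharing 2 KV heads -> groups axis of size 8 / 2 + 2 = 6.
assert qkvpacked_out_shape([2, 128, 6, 2, 64]) == [2, 128, 8, 64]
```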

paddle/phi/infermeta/ternary.h (+6)
@@ -115,6 +115,12 @@ void FlashAttnInferMeta(const MetaTensor& q,
                         MetaTensor* softmax_lse,
                         MetaTensor* seed_offset);
 
+void FlashAttnQKVPackedInferMeta(const MetaTensor& qkv,
+                                 MetaTensor* out,
+                                 MetaTensor* softmax,
+                                 MetaTensor* softmax_lse,
+                                 MetaTensor* seed_offset);
+
 void InstanceNormInferMeta(const MetaTensor& x,
                            const MetaTensor& scale,
                            const MetaTensor& bias,
