|
882 | 882 | func : flash_attn_grad
|
883 | 883 | data_type: q
|
884 | 884 |
|
| 885 | +- backward_op : flash_attn_qkvpacked_grad |
| 886 | + forward : flash_attn_qkvpacked (Tensor qkv, Tensor fixed_seed_offset, Tensor attn_mask, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) |
| 887 | + args : (Tensor qkv, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, float dropout = 0.0, bool causal = false) |
| 888 | + optional : attn_mask |
| 889 | + output : Tensor(qkv_grad) |
| 890 | + infer_meta : |
| 891 | + func : FlashAttnQKVPackedGradInferMeta |
| 892 | + param : [qkv] |
| 893 | + kernel : |
| 894 | + func : flash_attn_qkvpacked_grad |
| 895 | + data_type: qkv |
| 896 | + |
885 | 897 | - backward_op : flash_attn_unpadded_grad
|
886 | 898 | forward : flash_attn_unpadded (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
|
887 | 899 | args : (Tensor q, Tensor k, Tensor v, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false)
|
|
894 | 906 | func : flash_attn_unpadded_grad
|
895 | 907 | data_type: q
|
896 | 908 |
|
| 909 | +- backward_op : flash_attn_varlen_qkvpacked_grad |
| 910 | + forward : flash_attn_varlen_qkvpacked (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor fixed_seed_offset, Tensor attn_mask, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool return_softmax = false, bool is_test = false, str rng_name = "", bool varlen_padded = true) -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset) |
| 911 | + args : (Tensor qkv, Tensor cu_seqlens_q, Tensor cu_seqlens_k, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor attn_mask, Tensor out_grad, int64_t max_seqlen_q, int64_t max_seqlen_k, float scale, float dropout = 0.0, bool causal = false, bool varlen_padded = true) |
| 912 | + optional : attn_mask |
| 913 | + output : Tensor(qkv_grad) |
| 914 | + infer_meta : |
| 915 | + func : FlashAttnQKVPackedGradInferMeta |
| 916 | + param : [qkv] |
| 917 | + kernel : |
| 918 | + func : flash_attn_varlen_qkvpacked_grad |
| 919 | + data_type: qkv |
| 920 | + |
897 | 921 | - backward_op : flash_attn_with_sparse_mask_grad
|
898 | 922 | forward : flash_attn_with_sparse_mask (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor fixed_seed_offset, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0, bool return_softmax = false, bool is_test = false, str rng_name = "") -> Tensor(out), Tensor(softmax), Tensor(softmax_lse), Tensor(seed_offset)
|
899 | 923 | args : (Tensor q, Tensor k, Tensor v, Tensor attn_mask_start_row_indices, Tensor out, Tensor softmax_lse, Tensor seed_offset, Tensor out_grad, float dropout = 0.0, bool causal = false, int attn_mask_start_row = 0)
|
|
0 commit comments