Commit 611bf70

bertmaher authored and facebook-github-bot committed
Restore FlexAttention and FlashV3 backward (#2473)
Summary:
Pull Request resolved: #2473
Reviewed By: xuzhao9
Differential Revision: D63543625
Pulled By: bertmaher
fbshipit-source-id: 1693e15875544bda0f5f6c69daa5597fffd80509
Parent: 0f05015


1 file changed: torchbenchmark/operators/flash_attention/operator.py (+20, -3)

@@ -225,9 +225,7 @@ def flash_v3(
         q = q.transpose(1, 2).contiguous()
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
-        fn = lambda: flashattn_hopper_cuda.fwd(
-            q, k, v, None, self.sm_scale, self.causal
-        )
+        fn = lambda: flash_attn_v3(q, k, v, self.sm_scale, self.causal)
         return fn
 
     @register_benchmark()
@@ -360,6 +358,25 @@ def sdpa_flash_attention(q, k, v):
             v,
         )
 
+    @register_benchmark()
+    def flex_attention(self, q, k, v):
+        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        flex_attention = torch.compile(flex_attention, dynamic=False)
+
+        if self.causal:
+            B, H, S, D = q.shape
+            block_mask = create_block_mask(
+                causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S
+            )
+        else:
+            block_mask = None
+
+        return lambda: flex_attention(q, k, v, block_mask=block_mask)
+
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
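
For context, the following is a minimal standalone sketch of the FlexAttention pattern that the restored flex_attention benchmark exercises: build a causal block mask with create_block_mask, compile flex_attention, and call it with the mask. The shapes, dtype, and CUDA device below are illustrative assumptions, not values taken from the benchmark harness.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    # Each query position may attend to itself and to earlier key positions.
    return q_idx >= kv_idx

# Illustrative shapes and dtype; the benchmark supplies its own q, k, v.
B, H, S, D = 4, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# B=None / H=None broadcast the mask over batch and heads, matching the diff above.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)

# torch.compile lowers flex_attention into fused kernels; eager execution
# falls back to a reference implementation.
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
out = compiled_flex_attention(q, k, v, block_mask=block_mask)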
