@@ -225,9 +225,7 @@ def flash_v3(
         q = q.transpose(1, 2).contiguous()
         k = k.transpose(1, 2).contiguous()
         v = v.transpose(1, 2).contiguous()
-        fn = lambda: flashattn_hopper_cuda.fwd(
-            q, k, v, None, self.sm_scale, self.causal
-        )
+        fn = lambda: flash_attn_v3(q, k, v, self.sm_scale, self.causal)
         return fn
 
     @register_benchmark()
@@ -360,6 +358,25 @@ def sdpa_flash_attention(q, k, v):
                 v,
             )
 
+    @register_benchmark()
+    def flex_attention(self, q, k, v):
+        from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        flex_attention = torch.compile(flex_attention, dynamic=False)
+
+        if self.causal:
+            B, H, S, D = q.shape
+            block_mask = create_block_mask(
+                causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S
+            )
+        else:
+            block_mask = None
+
+        return lambda: flex_attention(q, k, v, block_mask=block_mask)
+
     @register_metric()
     def tflops(
         self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
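
For context, here is a minimal standalone sketch of the FlexAttention call that the new benchmark wraps. The shapes, dtype, and CUDA device below are illustrative assumptions, not values from the diff, and it requires a PyTorch build that ships `torch.nn.attention.flex_attention`:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# Illustrative [B, H, S, D] inputs; the benchmark uses whatever shapes it was configured with.
B, H, S, D = 4, 8, 1024, 64
device = "cuda"
q = torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device=device, dtype=torch.bfloat16)

# mask_mod: keep a (q_idx, kv_idx) pair only when it lies on or below the diagonal.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

# Compiling flex_attention fuses the mask logic into the generated attention kernel.
compiled_flex = torch.compile(flex_attention, dynamic=False)

# B=None, H=None broadcasts the same mask over batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S)

out = compiled_flex(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([4, 8, 1024, 64])
```

As in the diff, no explicit scale is passed, so flex_attention falls back to its default 1/sqrt(D) softmax scaling.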