29
29
30
30
31
31
@script ()
32
- def _unmasked_pre_div_sdpa_script (query , key , value , mask ):
32
+ def _unmasked_pre_div_sdpa_script (query , key , value ):
33
33
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
34
34
divisor = op .Constant (value_float = SQRT_SCALE_FACTOR )
35
35
scaled_query = op .Div (query , divisor )
@@ -41,7 +41,7 @@ def _unmasked_pre_div_sdpa_script(query, key, value, mask):
41
41
42
42
43
43
@script ()
44
- def _unmasked_pre_mul_sdpa_script (query , key , value , mask ):
44
+ def _unmasked_pre_mul_sdpa_script (query , key , value ):
45
45
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
46
46
multiplier = op .Constant (value_float = SQRT_MUL_SCALE_FACTOR )
47
47
scaled_query = op .Mul (query , multiplier )
@@ -53,7 +53,7 @@ def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
53
53
54
54
55
55
@script ()
56
- def _unmasked_post_div_sdpa_script (query , key , value , mask ):
56
+ def _unmasked_post_div_sdpa_script (query , key , value ):
57
57
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
58
58
divisor = op .Constant (value_float = SCALE_FACTOR )
59
59
attn_score = op .MatMul (query , key_transposed )
@@ -64,7 +64,7 @@ def _unmasked_post_div_sdpa_script(query, key, value, mask):
64
64
65
65
66
66
@script ()
67
- def _unmasked_post_mul_sdpa_script (query , key , value , mask ):
67
+ def _unmasked_post_mul_sdpa_script (query , key , value ):
68
68
key_transposed = op .Transpose (key , perm = [0 , 1 , 3 , 2 ])
69
69
multiplier = op .Constant (value_float = MUL_SCALE_FACTOR )
70
70
attn_score = op .MatMul (query , key_transposed )
0 commit comments