Skip to content

Commit c5cf58c

Browse files
Use self._use_mask in sdpa rewrite call (#2135)
Co-authored-by: G. Ramalingam <[email protected]>
1 parent a1c9380 commit c5cf58c

File tree

2 files changed: 8 additions and 5 deletions.

onnxscript/rewriter/ort_fusions/sdpa.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -70,7 +70,10 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
         return True

     def rewrite(self, op, query, key_transposed, value, mask, **_):
-        return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
+        if self._use_mask:
+            return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
+        else:
+            return op.SDPA(query, key_transposed, value, _domain="ai.onnxruntime.fusion")


 # Rules for SDPA without mask

onnxscript/rewriter/ort_fusions/sdpa_test.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@


 @script()
-def _unmasked_pre_div_sdpa_script(query, key, value, mask):
+def _unmasked_pre_div_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
     scaled_query = op.Div(query, divisor)
@@ -41,7 +41,7 @@ def _unmasked_pre_div_sdpa_script(query, key, value, mask):


 @script()
-def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
+def _unmasked_pre_mul_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
     scaled_query = op.Mul(query, multiplier)
@@ -53,7 +53,7 @@ def _unmasked_pre_mul_sdpa_script(query, key, value, mask):


 @script()
-def _unmasked_post_div_sdpa_script(query, key, value, mask):
+def _unmasked_post_div_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     divisor = op.Constant(value_float=SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)
@@ -64,7 +64,7 @@ def _unmasked_post_div_sdpa_script(query, key, value, mask):


 @script()
-def _unmasked_post_mul_sdpa_script(query, key, value, mask):
+def _unmasked_post_mul_sdpa_script(query, key, value):
     key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
     multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
     attn_score = op.MatMul(query, key_transposed)

0 commit comments

Comments
 (0)