Skip to content

Commit a833507

Browse files
shubhambhokare1gramalingam
authored andcommitted
Use self._use_mask in sdpa rewrite call (microsoft#2135)
Co-authored-by: G. Ramalingam <[email protected]>
1 parent ed7e08e commit a833507

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

onnxscript/rewriter/ort_fusions/sdpa.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,10 @@ def check(self, op, query, key_transposed, value, mask, query_scale, key_scale,
6060
return True
6161

6262
def rewrite(self, op, query, key_transposed, value, mask, **_):
63-
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
63+
if self._use_mask:
64+
return op.SDPA(query, key_transposed, value, mask, _domain="ai.onnxruntime.fusion")
65+
else:
66+
return op.SDPA(query, key_transposed, value, _domain="ai.onnxruntime.fusion")
6467

6568

6669
masked_pre_mul_sdpa_rule = SDPA.rule("masked_pre_mul_sdpa", use_mask=True, pre_scale=True)

onnxscript/rewriter/ort_fusions/sdpa_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030

3131
@script()
32-
def _unmasked_pre_div_sdpa_script(query, key, value, mask):
32+
def _unmasked_pre_div_sdpa_script(query, key, value):
3333
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
3434
divisor = op.Constant(value_float=SQRT_SCALE_FACTOR)
3535
scaled_query = op.Div(query, divisor)
@@ -41,7 +41,7 @@ def _unmasked_pre_div_sdpa_script(query, key, value, mask):
4141

4242

4343
@script()
44-
def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
44+
def _unmasked_pre_mul_sdpa_script(query, key, value):
4545
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
4646
multiplier = op.Constant(value_float=SQRT_MUL_SCALE_FACTOR)
4747
scaled_query = op.Mul(query, multiplier)
@@ -53,7 +53,7 @@ def _unmasked_pre_mul_sdpa_script(query, key, value, mask):
5353

5454

5555
@script()
56-
def _unmasked_post_div_sdpa_script(query, key, value, mask):
56+
def _unmasked_post_div_sdpa_script(query, key, value):
5757
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
5858
divisor = op.Constant(value_float=SCALE_FACTOR)
5959
attn_score = op.MatMul(query, key_transposed)
@@ -64,7 +64,7 @@ def _unmasked_post_div_sdpa_script(query, key, value, mask):
6464

6565

6666
@script()
67-
def _unmasked_post_mul_sdpa_script(query, key, value, mask):
67+
def _unmasked_post_mul_sdpa_script(query, key, value):
6868
key_transposed = op.Transpose(key, perm=[0, 1, 3, 2])
6969
multiplier = op.Constant(value_float=MUL_SCALE_FACTOR)
7070
attn_score = op.MatMul(query, key_transposed)

0 commit comments

Comments
 (0)