Skip to content

Commit 0d81ebe

Browse files
Make the return type of rewrite check functions a MatchResult object (#2138)
- Check function returns a MatchResult object instead of bool - This allows propagating the failure reason to the tracer to help in debugging
1 parent af49eff commit 0d81ebe

10 files changed

+220
-112
lines changed

onnxscript/rewriter/llama_rule_sets.py

+61-41
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,12 @@ def pattern(self, op, x):
2626
def rewrite(self, op, x: ir.Value):
2727
return op.Identity(x)
2828

29-
def check(self, context, x) -> bool:
29+
def check(self, context, x) -> orp.MatchResult:
3030
del context # Unused
31-
return ir_utils.has_rank(x, 1)
31+
check_result = orp.MatchResult()
32+
if not ir_utils.has_rank(x, 1):
33+
return check_result.fail("Input is not 1D")
34+
return check_result
3235

3336

3437
class CastIdentity(orp.RewriteRuleAsClass):
@@ -43,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr):
4346
return op.Identity(x)
4447

4548
@classmethod
46-
def check(cls, context, x, to) -> bool:
47-
return x.dtype == to.value
49+
def check(cls, context, x, to) -> orp.MatchResult:
50+
check_result = orp.MatchResult()
51+
if x.dtype != to.value:
52+
return check_result.fail("Input and output types are not the same")
53+
return check_result
4854

4955

5056
class CastCast(orp.RewriteRuleAsClass):
@@ -62,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored):
6268
return op.Cast(op.Cast(x, to=to_ignored), to=to)
6369

6470
@classmethod
65-
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> bool:
66-
return (
67-
to.value in cls._allowed_tensor_types
68-
and to_ignored.value in cls._allowed_tensor_types
69-
)
71+
def check(cls, context, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr) -> orp.MatchResult:
72+
check_result = orp.MatchResult()
73+
if to.value not in cls._allowed_tensor_types:
74+
return check_result.fail(f"Output type {to.value} is not allowed")
75+
if to_ignored.value not in cls._allowed_tensor_types:
76+
return check_result.fail(f"Ignored type {to_ignored.value} is not allowed")
77+
return check_result
7078

7179
@classmethod
7280
def rewrite(cls, op, x: ir.Value, to: ir.Attr, to_ignored: ir.Attr):
@@ -85,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value):
8593
return op.Identity(x)
8694

8795
@classmethod
88-
def check(cls, context, x, shape) -> bool:
96+
def check(cls, context, x, shape) -> orp.MatchResult:
97+
check_result = orp.MatchResult()
8998
if shape.const_value is None:
9099
# Shape is not a constant and cannot be guessed.
91-
return False
100+
return check_result.fail("Shape is not a constant and cannot be guessed.")
92101
if (x_shape := x.shape) is None:
93102
# We don't know the shape of the input
94-
return False
95-
return x_shape.dims == tuple(shape.const_value.numpy().tolist())
103+
return check_result.fail("Input shape is not known.")
104+
if x_shape.dims != tuple(shape.const_value.numpy().tolist()):
105+
return check_result.fail(
106+
f"Input shape {x_shape.dims} does not match the shape {shape.const_value.numpy().tolist()}."
107+
)
108+
return check_result
96109

97110

98111
class ReshapeReshape(orp.RewriteRuleAsClass):
@@ -110,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
110123
return op.Reshape(x, shape)
111124

112125
@classmethod
113-
def check(cls, context, x, shape_ignored, shape) -> bool:
114-
if shape_ignored.const_value is None or shape.const_value is None:
115-
return False
126+
def check(cls, context, x, shape_ignored, shape) -> orp.MatchResult:
127+
check_result = orp.MatchResult()
128+
if shape_ignored.const_value is None:
129+
return check_result.fail("Shape ignored is not a constant.")
130+
if shape.const_value is None:
131+
return check_result.fail("Shape is not a constant.")
116132
if shape.const_value.numpy().min() <= 0:
117-
return False
118-
return True
133+
return check_result.fail("Shape has non-positive values.")
134+
return check_result
119135

120136

121137
class SlicesSplit(orp.RewriteRuleAsClass):
@@ -128,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
128144
return op.Slice(x, begin0, end0, axes0), op.Slice(x, begin1, end1, axes1)
129145

130146
@classmethod
131-
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> bool:
147+
def check(cls, context, x, begin0, end0, axes0, begin1, end1, axes1) -> orp.MatchResult:
148+
check_result = orp.MatchResult()
132149
if (
133150
axes0.const_value is None
134151
or axes1.const_value is None
135152
or axes0.const_value.numpy().tolist() != axes1.const_value.numpy().tolist()
136153
):
137-
return False
154+
return check_result.fail("Axes are not equal or not constant.")
138155
axes = axes0.const_value.numpy().tolist()
139156
if len(axes) != 1:
140-
return False
157+
return check_result.fail("Axes has more than one dimension.")
141158
if x.shape:
142159
rk = len(x.shape)
143160
else:
144161
rk = x.rank
145162
if axes[0] != -1 and axes[0] != rk - 1:
146-
return False
163+
return check_result.fail("Axes is not -1 or last dimension.")
147164
if (
148165
begin0.const_value is None
149166
or end0.const_value is None
150167
or begin1.const_value is None
151168
or end1.const_value is None
152169
):
153-
return False
170+
return check_result.fail("Begin or end are not constant values.")
154171
if begin0.const_value.numpy().tolist() != [0]:
155-
return False
172+
return check_result.fail("First begin value is not 0.")
156173
e0, b1, e1 = (
157174
end0.const_value.numpy().tolist(),
158175
begin1.const_value.numpy().tolist(),
159176
end1.const_value.numpy().tolist(),
160177
)
161178
if e0[0] != b1[0]:
162-
return False
179+
return check_result.fail("End0 is not equal to Begin1.")
163180
shape = x.shape
164181
if shape is None:
165-
return False
182+
return check_result.fail("Shape is not known.")
166183
last_dim = shape[-1]
167184
if not isinstance(last_dim, int):
168-
return False
185+
return check_result.fail("Last dimension is not known.")
169186
if last_dim != e1[0]:
170-
return False
187+
return check_result.fail("Last dimension is not equal to End1.")
171188
if last_dim // 2 != b1[0]:
172-
return False
173-
return True
189+
return check_result.fail("Last dimension is not equal to Begin1.")
190+
return check_result
174191

175192
@classmethod
176193
def rewrite(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
@@ -187,13 +204,14 @@ def pattern(cls, op, x, perm):
187204
return op.Transpose(x, perm=perm)
188205

189206
@classmethod
190-
def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
207+
def check(cls, context, x: ir.Value, perm: ir.Attr) -> orp.MatchResult:
208+
check_result = orp.MatchResult()
191209
if isinstance(perm, ir.RefAttr):
192-
return False
210+
return check_result.fail("Permutation is a reference attribute.")
193211
if perm.type == ir.AttributeType.INTS:
194212
if perm.value == list(range(len(perm.value))):
195-
return True
196-
return False
213+
return check_result
214+
return check_result.fail("Permutation is not identity.")
197215

198216
@classmethod
199217
def rewrite(cls, op, x: ir.Value, perm: ir.Attr):
@@ -210,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2):
210228
return op.Transpose(op.Transpose(x, perm=perm1), perm=perm2)
211229

212230
@classmethod
213-
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> bool:
231+
def check(cls, context, x: ir.Value, perm1: ir.Attr, perm2: ir.Attr) -> orp.MatchResult:
232+
check_result = orp.MatchResult()
214233
if isinstance(perm1, ir.RefAttr) or isinstance(perm2, ir.RefAttr):
215-
return False
216-
return True
234+
return check_result.fail("Permutation is a reference attribute.")
235+
return check_result
217236

218237
@classmethod
219238
def _apply_transpose(cls, perm: tuple[int, ...], on: list[int]) -> list[int]:
@@ -257,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
257276
return op.Unsqueeze(x, op.Constant(value=ir.tensor(axes, dtype=ir.DataType.INT64)))
258277

259278
@classmethod
260-
def check(cls, context, x, axes1, axes2) -> bool:
279+
def check(cls, context, x, axes1, axes2) -> orp.MatchResult:
280+
check_result = orp.MatchResult()
261281
del context # Unused
262282
del x # Unused
263283
# Currently restricted to single element positive axis
264284
v1 = ir_utils.get_singleton_value(axes1)
265285
v2 = ir_utils.get_singleton_value(axes2)
266286
if v1 is None or v2 is None:
267-
return False
287+
return check_result.fail("Axes are not constant.")
268288
if (v1 < 0) or (v2 < 0):
269-
return False
270-
return True
289+
return check_result.fail("Axes are negative.")
290+
return check_result
271291

272292

273293
cast_cast_rule = orp.make_rewrite_rule_from_class(CastCast)

onnxscript/rewriter/ort_fusions/cos_sin_cache.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -96,24 +96,32 @@ def pattern(
9696
_domain="ai.onnxruntime.fusion",
9797
)
9898

99-
def check(self, context, inv_freq, position_ids, freqs, extra_dims, **_):
99+
def check(
100+
self, context, inv_freq, position_ids, freqs, extra_dims, **_
101+
) -> pattern.MatchResult: # type: ignore[name-defined]
102+
check_result = pattern.MatchResult()
100103
# TODO(rama): handle redundant reshape/expand
101104
if self._const_freqs:
102-
return (freqs.const_value is not None) and _ir_utils.has_rank(freqs, 3)
105+
if (freqs.const_value is None) or not _ir_utils.has_rank(freqs, 3):
106+
return check_result.fail("freqs is not a constant or not 3D.", freqs)
107+
else:
108+
return check_result
103109
if (
104110
_ir_utils.has_rank(position_ids, 2) and _ir_utils.is_singleton_value(extra_dims, 1)
105111
) or (
106112
_ir_utils.has_rank(position_ids, 1) and _ir_utils.is_1d_value(extra_dims, [0, 1])
107113
):
108114
pass
109115
else:
110-
return False
116+
return check_result.fail("position_ids is not a 1D or 2D tensor.", position_ids)
111117
if not _ir_utils.has_rank(inv_freq, 3):
112-
return False
118+
return check_result.fail("inv_freq is not 3D.", inv_freq)
113119
inv_freq_shape = inv_freq.shape
114120
if inv_freq.const_value is None: # TODO: should this be inv_freq_shape?
115-
return False
116-
return inv_freq_shape[0] == 1 and inv_freq_shape[2] == 1
121+
return check_result.fail("inv_freq is not a constant.", inv_freq)
122+
if inv_freq_shape[0] != 1 or inv_freq_shape[2] != 1:
123+
return check_result.fail("inv_freq is not of shape [1, ., 1].", inv_freq)
124+
return check_result
117125

118126
def rewrite(
119127
self, op, x, inv_freq, position_ids, interleaved, num_heads, freqs, dtype, **_

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@ def pattern(cls, op, x, y, cst):
1515
return op.Div(op.MatMul(x, y), cst)
1616

1717
@classmethod
18-
def check(cls, context, x, y, cst) -> bool:
18+
def check(cls, context, x, y, cst) -> orp.MatchResult:
19+
check_result = orp.MatchResult()
1920
if cst.const_value is None:
20-
return False
21+
return check_result.fail("Divisor is not a constant value.")
2122
value = cst.const_value.numpy()
2223
if value.size > 1:
23-
return False
24-
return True
24+
return check_result.fail("Divisor is not a scalar value.")
25+
return check_result
2526

2627
@classmethod
2728
def rewrite(cls, op, x, y, cst):
@@ -38,12 +39,13 @@ def pattern(cls, op, x, y, cst):
3839
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
3940

4041
@classmethod
41-
def check(cls, context, x, y, cst) -> bool:
42+
def check(cls, context, x, y, cst) -> orp.MatchResult:
43+
check_result = orp.MatchResult()
4244
if cst.const_value is None:
43-
return False
45+
return check_result.fail("Divisor is not a constant value.")
4446
if cst.const_value.numpy().size > 1:
45-
return False
46-
return True
47+
return check_result.fail("Divisor is not a scalar value.")
48+
return check_result
4749

4850
@classmethod
4951
def rewrite(cls, op, x, y, cst):
@@ -65,11 +67,14 @@ class _TransposeMatMulBase(orp.RewriteRuleAsClass):
6567
_pos: ClassVar = 1
6668

6769
@classmethod
68-
def check(cls, context, x, y) -> bool:
70+
def check(cls, context, x, y) -> orp.MatchResult:
71+
check_result = orp.MatchResult()
6972
perm = list((x if cls._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
7073
expected_perm = list(range(len(perm)))
7174
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
72-
return perm == expected_perm
75+
if perm != expected_perm:
76+
return check_result.fail("Permutation values for Transpose are not correct.")
77+
return check_result
7378

7479
@classmethod
7580
def rewrite(cls, op, x, y):
@@ -126,13 +131,16 @@ def pattern(cls, op, x, y):
126131
return op.Transpose(op.MatMul(x, y))
127132

128133
@classmethod
129-
def check(cls, context, x, y) -> bool:
134+
def check(cls, context, x, y) -> orp.MatchResult:
135+
check_result = orp.MatchResult()
130136
matmul = list(x.uses())[0][0] # noqa: RUF015
131137
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
132138
perm = transpose.attributes["perm"].value
133139
expected_perm = list(range(len(perm)))
134140
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
135-
return perm == expected_perm
141+
if perm != expected_perm:
142+
return check_result.fail("Permutation values for Transpose are not correct.")
143+
return check_result
136144

137145
@classmethod
138146
def rewrite(cls, op, x, y):

onnxscript/rewriter/ort_fusions/gqa.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ def check(
9494
# key_transposed,
9595
# attention_reshaped,
9696
**_,
97-
):
97+
) -> pattern.MatchResult: # type: ignore[name-defined]
98+
check_result = pattern.MatchResult()
9899
# bindings: dict[str, int] = {}
99100
# status = (
100101
# _check_shape(bindings, query_mm_reshaped, ["B", "S", "H", "d_h"])
@@ -110,7 +111,7 @@ def check(
110111
# return False
111112
# if bindings["H"] * bindings["d_h"] != bindings["H*d_h"]:
112113
# return False
113-
return True
114+
return check_result
114115

115116
def rewrite(
116117
self,

0 commit comments

Comments
 (0)