@@ -26,9 +26,12 @@ def pattern(self, op, x):
26
26
def rewrite (self , op , x : ir .Value ):
27
27
return op .Identity (x )
28
28
29
- def check (self , context , x ) -> bool :
29
+ def check (self , context , x ) -> orp . MatchResult :
30
30
del context # Unused
31
- return ir_utils .has_rank (x , 1 )
31
+ check_result = orp .MatchResult ()
32
+ if not ir_utils .has_rank (x , 1 ):
33
+ return check_result .fail ("Input is not 1D" )
34
+ return check_result
32
35
33
36
34
37
class CastIdentity (orp .RewriteRuleAsClass ):
@@ -43,8 +46,11 @@ def rewrite(cls, op, x: ir.Value, to: ir.Attr):
43
46
return op .Identity (x )
44
47
45
48
@classmethod
46
- def check (cls , context , x , to ) -> bool :
47
- return x .dtype == to .value
49
+ def check (cls , context , x , to ) -> orp .MatchResult :
50
+ check_result = orp .MatchResult ()
51
+ if x .dtype != to .value :
52
+ return check_result .fail ("Input and output types are not the same" )
53
+ return check_result
48
54
49
55
50
56
class CastCast (orp .RewriteRuleAsClass ):
@@ -62,11 +68,13 @@ def pattern(cls, op, x, to, to_ignored):
62
68
return op .Cast (op .Cast (x , to = to_ignored ), to = to )
63
69
64
70
@classmethod
65
- def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> bool :
66
- return (
67
- to .value in cls ._allowed_tensor_types
68
- and to_ignored .value in cls ._allowed_tensor_types
69
- )
71
+ def check (cls , context , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ) -> orp .MatchResult :
72
+ check_result = orp .MatchResult ()
73
+ if to .value not in cls ._allowed_tensor_types :
74
+ return check_result .fail (f"Output type { to .value } is not allowed" )
75
+ if to_ignored .value not in cls ._allowed_tensor_types :
76
+ return check_result .fail (f"Ignored type { to_ignored .value } is not allowed" )
77
+ return check_result
70
78
71
79
@classmethod
72
80
def rewrite (cls , op , x : ir .Value , to : ir .Attr , to_ignored : ir .Attr ):
@@ -85,14 +93,19 @@ def rewrite(cls, op, x: ir.Value, shape: ir.Value):
85
93
return op .Identity (x )
86
94
87
95
@classmethod
88
- def check (cls , context , x , shape ) -> bool :
96
+ def check (cls , context , x , shape ) -> orp .MatchResult :
97
+ check_result = orp .MatchResult ()
89
98
if shape .const_value is None :
90
99
# Shape is not a constant and cannot be guessed.
91
- return False
100
+ return check_result . fail ( "Shape is not a constant and cannot be guessed." )
92
101
if (x_shape := x .shape ) is None :
93
102
# We don't know the shape of the input
94
- return False
95
- return x_shape .dims == tuple (shape .const_value .numpy ().tolist ())
103
+ return check_result .fail ("Input shape is not known." )
104
+ if x_shape .dims != tuple (shape .const_value .numpy ().tolist ()):
105
+ return check_result .fail (
106
+ f"Input shape { x_shape .dims } does not match the shape { shape .const_value .numpy ().tolist ()} ."
107
+ )
108
+ return check_result
96
109
97
110
98
111
class ReshapeReshape (orp .RewriteRuleAsClass ):
@@ -110,12 +123,15 @@ def rewrite(cls, op, x: ir.Value, shape_ignored: ir.Value, shape: ir.Value):
110
123
return op .Reshape (x , shape )
111
124
112
125
@classmethod
113
- def check (cls , context , x , shape_ignored , shape ) -> bool :
114
- if shape_ignored .const_value is None or shape .const_value is None :
115
- return False
126
+ def check (cls , context , x , shape_ignored , shape ) -> orp .MatchResult :
127
+ check_result = orp .MatchResult ()
128
+ if shape_ignored .const_value is None :
129
+ return check_result .fail ("Shape ignored is not a constant." )
130
+ if shape .const_value is None :
131
+ return check_result .fail ("Shape is not a constant." )
116
132
if shape .const_value .numpy ().min () <= 0 :
117
- return False
118
- return True
133
+ return check_result . fail ( "Shape has non-positive values." )
134
+ return check_result
119
135
120
136
121
137
class SlicesSplit (orp .RewriteRuleAsClass ):
@@ -128,49 +144,50 @@ def pattern(cls, op, x, begin0, end0, axes0, begin1, end1, axes1):
128
144
return op .Slice (x , begin0 , end0 , axes0 ), op .Slice (x , begin1 , end1 , axes1 )
129
145
130
146
@classmethod
131
- def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> bool :
147
+ def check (cls , context , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ) -> orp .MatchResult :
148
+ check_result = orp .MatchResult ()
132
149
if (
133
150
axes0 .const_value is None
134
151
or axes1 .const_value is None
135
152
or axes0 .const_value .numpy ().tolist () != axes1 .const_value .numpy ().tolist ()
136
153
):
137
- return False
154
+ return check_result . fail ( "Axes are not equal or not constant." )
138
155
axes = axes0 .const_value .numpy ().tolist ()
139
156
if len (axes ) != 1 :
140
- return False
157
+ return check_result . fail ( "Axes has more than one dimension." )
141
158
if x .shape :
142
159
rk = len (x .shape )
143
160
else :
144
161
rk = x .rank
145
162
if axes [0 ] != - 1 and axes [0 ] != rk - 1 :
146
- return False
163
+ return check_result . fail ( "Axes is not -1 or last dimension." )
147
164
if (
148
165
begin0 .const_value is None
149
166
or end0 .const_value is None
150
167
or begin1 .const_value is None
151
168
or end1 .const_value is None
152
169
):
153
- return False
170
+ return check_result . fail ( "Begin or end are not constant values." )
154
171
if begin0 .const_value .numpy ().tolist () != [0 ]:
155
- return False
172
+ return check_result . fail ( "First begin value is not 0." )
156
173
e0 , b1 , e1 = (
157
174
end0 .const_value .numpy ().tolist (),
158
175
begin1 .const_value .numpy ().tolist (),
159
176
end1 .const_value .numpy ().tolist (),
160
177
)
161
178
if e0 [0 ] != b1 [0 ]:
162
- return False
179
+ return check_result . fail ( "End0 is not equal to Begin1." )
163
180
shape = x .shape
164
181
if shape is None :
165
- return False
182
+ return check_result . fail ( "Shape is not known." )
166
183
last_dim = shape [- 1 ]
167
184
if not isinstance (last_dim , int ):
168
- return False
185
+ return check_result . fail ( "Last dimension is not known." )
169
186
if last_dim != e1 [0 ]:
170
- return False
187
+ return check_result . fail ( "Last dimension is not equal to End1." )
171
188
if last_dim // 2 != b1 [0 ]:
172
- return False
173
- return True
189
+ return check_result . fail ( "Last dimension is not equal to Begin1." )
190
+ return check_result
174
191
175
192
@classmethod
176
193
def rewrite (cls , op , x , begin0 , end0 , axes0 , begin1 , end1 , axes1 ):
@@ -187,13 +204,14 @@ def pattern(cls, op, x, perm):
187
204
return op .Transpose (x , perm = perm )
188
205
189
206
@classmethod
190
- def check (cls , context , x : ir .Value , perm : ir .Attr ) -> bool :
207
+ def check (cls , context , x : ir .Value , perm : ir .Attr ) -> orp .MatchResult :
208
+ check_result = orp .MatchResult ()
191
209
if isinstance (perm , ir .RefAttr ):
192
- return False
210
+ return check_result . fail ( "Permutation is a reference attribute." )
193
211
if perm .type == ir .AttributeType .INTS :
194
212
if perm .value == list (range (len (perm .value ))):
195
- return True
196
- return False
213
+ return check_result
214
+ return check_result . fail ( "Permutation is not identity." )
197
215
198
216
@classmethod
199
217
def rewrite (cls , op , x : ir .Value , perm : ir .Attr ):
@@ -210,10 +228,11 @@ def pattern(cls, op, x, perm1, perm2):
210
228
return op .Transpose (op .Transpose (x , perm = perm1 ), perm = perm2 )
211
229
212
230
@classmethod
213
- def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> bool :
231
+ def check (cls , context , x : ir .Value , perm1 : ir .Attr , perm2 : ir .Attr ) -> orp .MatchResult :
232
+ check_result = orp .MatchResult ()
214
233
if isinstance (perm1 , ir .RefAttr ) or isinstance (perm2 , ir .RefAttr ):
215
- return False
216
- return True
234
+ return check_result . fail ( "Permutation is a reference attribute." )
235
+ return check_result
217
236
218
237
@classmethod
219
238
def _apply_transpose (cls , perm : tuple [int , ...], on : list [int ]) -> list [int ]:
@@ -257,17 +276,18 @@ def rewrite(cls, op, x: ir.Value, axes1: ir.Value, axes2: ir.Value):
257
276
return op .Unsqueeze (x , op .Constant (value = ir .tensor (axes , dtype = ir .DataType .INT64 )))
258
277
259
278
@classmethod
260
- def check (cls , context , x , axes1 , axes2 ) -> bool :
279
+ def check (cls , context , x , axes1 , axes2 ) -> orp .MatchResult :
280
+ check_result = orp .MatchResult ()
261
281
del context # Unused
262
282
del x # Unused
263
283
# Currently restricted to single element positive axis
264
284
v1 = ir_utils .get_singleton_value (axes1 )
265
285
v2 = ir_utils .get_singleton_value (axes2 )
266
286
if v1 is None or v2 is None :
267
- return False
287
+ return check_result . fail ( "Axes are not constant." )
268
288
if (v1 < 0 ) or (v2 < 0 ):
269
- return False
270
- return True
289
+ return check_result . fail ( "Axes are negative." )
290
+ return check_result
271
291
272
292
273
293
cast_cast_rule = orp .make_rewrite_rule_from_class (CastCast )
0 commit comments