42
42
paddle .complex128 ,
43
43
}
44
44
45
+ # NOTE: Implicit promotion rules of Paddle is a bit strict than other frameworks,
46
+ # see details: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/advanced/auto_type_promotion_cn.html
45
47
_promotion_table = {
46
48
# bool
47
49
(paddle .bool , paddle .bool ): paddle .bool ,
48
50
# ints
49
51
(paddle .int8 , paddle .int8 ): paddle .int8 ,
50
- (paddle .int8 , paddle .int16 ): paddle .int16 ,
51
- (paddle .int8 , paddle .int32 ): paddle .int32 ,
52
- (paddle .int8 , paddle .int64 ): paddle .int64 ,
53
- (paddle .int16 , paddle .int8 ): paddle .int16 ,
54
52
(paddle .int16 , paddle .int16 ): paddle .int16 ,
55
- (paddle .int16 , paddle .int32 ): paddle .int32 ,
56
- (paddle .int16 , paddle .int64 ): paddle .int64 ,
57
- (paddle .int32 , paddle .int8 ): paddle .int32 ,
58
- (paddle .int32 , paddle .int16 ): paddle .int32 ,
59
53
(paddle .int32 , paddle .int32 ): paddle .int32 ,
60
- (paddle .int32 , paddle .int64 ): paddle .int64 ,
61
- (paddle .int64 , paddle .int8 ): paddle .int64 ,
62
- (paddle .int64 , paddle .int16 ): paddle .int64 ,
63
- (paddle .int64 , paddle .int32 ): paddle .int64 ,
64
54
(paddle .int64 , paddle .int64 ): paddle .int64 ,
65
55
# uints
66
56
(paddle .uint8 , paddle .uint8 ): paddle .uint8 ,
67
- # ints and uints (mixed sign)
68
- (paddle .int8 , paddle .uint8 ): paddle .int16 ,
69
- (paddle .int16 , paddle .uint8 ): paddle .int16 ,
70
- (paddle .int32 , paddle .uint8 ): paddle .int32 ,
71
- (paddle .int64 , paddle .uint8 ): paddle .int64 ,
72
- (paddle .uint8 , paddle .int8 ): paddle .int16 ,
73
- (paddle .uint8 , paddle .int16 ): paddle .int16 ,
74
- (paddle .uint8 , paddle .int32 ): paddle .int32 ,
75
- (paddle .uint8 , paddle .int64 ): paddle .int64 ,
76
57
# floats
77
58
(paddle .float32 , paddle .float32 ): paddle .float32 ,
78
59
(paddle .float32 , paddle .float64 ): paddle .float64 ,
@@ -158,12 +139,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
158
139
paddle .float64 : True ,
159
140
paddle .complex64 : True ,
160
141
paddle .complex128 : True ,
161
- paddle .uint8 : False ,
162
- paddle .int8 : False ,
163
- paddle .int16 : False ,
164
- paddle .int32 : False ,
165
- paddle .int64 : False ,
166
- paddle .bool : False ,
142
+ paddle .uint8 : True ,
143
+ paddle .int8 : True ,
144
+ paddle .int16 : True ,
145
+ paddle .int32 : True ,
146
+ paddle .int64 : True ,
147
+ paddle .bool : True ,
167
148
},
168
149
paddle .float16 : {
169
150
paddle .bfloat16 : True ,
@@ -172,12 +153,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
172
153
paddle .float64 : True ,
173
154
paddle .complex64 : True ,
174
155
paddle .complex128 : True ,
175
- paddle .uint8 : False ,
176
- paddle .int8 : False ,
177
- paddle .int16 : False ,
178
- paddle .int32 : False ,
179
- paddle .int64 : False ,
180
- paddle .bool : False ,
156
+ paddle .uint8 : True ,
157
+ paddle .int8 : True ,
158
+ paddle .int16 : True ,
159
+ paddle .int32 : True ,
160
+ paddle .int64 : True ,
161
+ paddle .bool : True ,
181
162
},
182
163
paddle .float32 : {
183
164
paddle .bfloat16 : True ,
@@ -186,12 +167,12 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
186
167
paddle .float64 : True ,
187
168
paddle .complex64 : True ,
188
169
paddle .complex128 : True ,
189
- paddle .uint8 : False ,
190
- paddle .int8 : False ,
191
- paddle .int16 : False ,
192
- paddle .int32 : False ,
193
- paddle .int64 : False ,
194
- paddle .bool : False ,
170
+ paddle .uint8 : True ,
171
+ paddle .int8 : True ,
172
+ paddle .int16 : True ,
173
+ paddle .int32 : True ,
174
+ paddle .int64 : True ,
175
+ paddle .bool : True ,
195
176
},
196
177
paddle .float64 : {
197
178
paddle .bfloat16 : True ,
@@ -200,40 +181,40 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
200
181
paddle .float64 : True ,
201
182
paddle .complex64 : True ,
202
183
paddle .complex128 : True ,
203
- paddle .uint8 : False ,
204
- paddle .int8 : False ,
205
- paddle .int16 : False ,
206
- paddle .int32 : False ,
207
- paddle .int64 : False ,
208
- paddle .bool : False ,
184
+ paddle .uint8 : True ,
185
+ paddle .int8 : True ,
186
+ paddle .int16 : True ,
187
+ paddle .int32 : True ,
188
+ paddle .int64 : True ,
189
+ paddle .bool : True ,
209
190
},
210
191
paddle .complex64 : {
211
- paddle .bfloat16 : False ,
212
- paddle .float16 : False ,
213
- paddle .float32 : False ,
214
- paddle .float64 : False ,
192
+ paddle .bfloat16 : True ,
193
+ paddle .float16 : True ,
194
+ paddle .float32 : True ,
195
+ paddle .float64 : True ,
215
196
paddle .complex64 : True ,
216
197
paddle .complex128 : True ,
217
- paddle .uint8 : False ,
218
- paddle .int8 : False ,
219
- paddle .int16 : False ,
220
- paddle .int32 : False ,
221
- paddle .int64 : False ,
222
- paddle .bool : False ,
198
+ paddle .uint8 : True ,
199
+ paddle .int8 : True ,
200
+ paddle .int16 : True ,
201
+ paddle .int32 : True ,
202
+ paddle .int64 : True ,
203
+ paddle .bool : True ,
223
204
},
224
205
paddle .complex128 : {
225
- paddle .bfloat16 : False ,
226
- paddle .float16 : False ,
227
- paddle .float32 : False ,
228
- paddle .float64 : False ,
206
+ paddle .bfloat16 : True ,
207
+ paddle .float16 : True ,
208
+ paddle .float32 : True ,
209
+ paddle .float64 : True ,
229
210
paddle .complex64 : True ,
230
211
paddle .complex128 : True ,
231
- paddle .uint8 : False ,
232
- paddle .int8 : False ,
233
- paddle .int16 : False ,
234
- paddle .int32 : False ,
235
- paddle .int64 : False ,
236
- paddle .bool : False ,
212
+ paddle .uint8 : True ,
213
+ paddle .int8 : True ,
214
+ paddle .int16 : True ,
215
+ paddle .int32 : True ,
216
+ paddle .int64 : True ,
217
+ paddle .bool : True ,
237
218
},
238
219
paddle .uint8 : {
239
220
paddle .bfloat16 : True ,
@@ -247,7 +228,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
247
228
paddle .int16 : True ,
248
229
paddle .int32 : True ,
249
230
paddle .int64 : True ,
250
- paddle .bool : False ,
231
+ paddle .bool : True ,
251
232
},
252
233
paddle .int8 : {
253
234
paddle .bfloat16 : True ,
@@ -261,7 +242,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
261
242
paddle .int16 : True ,
262
243
paddle .int32 : True ,
263
244
paddle .int64 : True ,
264
- paddle .bool : False ,
245
+ paddle .bool : True ,
265
246
},
266
247
paddle .int16 : {
267
248
paddle .bfloat16 : True ,
@@ -275,7 +256,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
275
256
paddle .int16 : True ,
276
257
paddle .int32 : True ,
277
258
paddle .int64 : True ,
278
- paddle .bool : False ,
259
+ paddle .bool : True ,
279
260
},
280
261
paddle .int32 : {
281
262
paddle .bfloat16 : True ,
@@ -289,7 +270,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
289
270
paddle .int16 : True ,
290
271
paddle .int32 : True ,
291
272
paddle .int64 : True ,
292
- paddle .bool : False ,
273
+ paddle .bool : True ,
293
274
},
294
275
paddle .int64 : {
295
276
paddle .bfloat16 : True ,
@@ -303,7 +284,7 @@ def can_cast(from_: Union[Dtype, array], to: Dtype, /) -> bool:
303
284
paddle .int16 : True ,
304
285
paddle .int32 : True ,
305
286
paddle .int64 : True ,
306
- paddle .bool : False ,
287
+ paddle .bool : True ,
307
288
},
308
289
paddle .bool : {
309
290
paddle .bfloat16 : True ,
0 commit comments