     torch.complex128,
 }
 
-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
     (torch.float64, torch.complex64): torch.complex128,
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
 
 def _two_arg(f):
     @_wraps(f)
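
The two update() calls replace the hand-written mirrored and same-dtype entries deleted above: the first adds the reversed key for every mixed pair, the second an identity entry for every array API dtype. A minimal standalone sketch of the trick (plain strings instead of torch dtypes, not part of this diff):

# Make a lookup table symmetric and reflexive from one-directional entries.
table = {("int8", "int16"): "int16", ("int8", "int32"): "int32"}
table.update({(b, a): c for (a, b), c in table.items()})        # mirrored pairs
table.update({(a, a): a for a in ("int8", "int16", "int32")})   # identity pairs
assert table[("int16", "int8")] == "int16"
assert table[("int32", "int32")] == "int32"
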
@@ -150,13 +129,18 @@ def result_type(
     return _reduce(_result_type, others + scalars)
 
 
-def _result_type(x, y):
+def _result_type(
+    x: Array | DType | bool | int | float | complex,
+    y: Array | DType | bool | int | float | complex,
+) -> DType:
     if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
-        xdt = x.dtype if not isinstance(x, torch.dtype) else x
-        ydt = y.dtype if not isinstance(y, torch.dtype) else y
+        xdt = x if isinstance(x, torch.dtype) else x.dtype
+        ydt = y if isinstance(y, torch.dtype) else y.dtype
 
-        if (xdt, ydt) in _promotion_table:
+        try:
             return _promotion_table[xdt, ydt]
+        except KeyError:
+            pass
 
     # This doesn't result_type(dtype, dtype) for non-array API dtypes
     # because torch.result_type only accepts tensors. This does however, allow
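
For illustration, assuming the package is installed as array_api_compat, the public result_type exercises this lookup; the int8/uint8 result comes straight from the mixed-sign entries in the table above:

import torch
from array_api_compat import torch as xp

# Both orderings hit the table thanks to the symmetrizing update() above.
assert xp.result_type(torch.int8, torch.uint8) == torch.int16
assert xp.result_type(torch.uint8, torch.int8) == torch.int16
assert xp.result_type(torch.float64, torch.complex64) == torch.complex128
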
@@ -301,27 +285,35 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
         out = torch.unsqueeze(out, a)
     return out
 
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    # We can't upcast uint8 according to the spec because there is no
+    # torch.uint64, so at least upcast to int64 which is what prod does
+    # when axis=None.
+    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
 def prod(x: Array,
          /,
          *,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
          dtype: Optional[DType] = None,
          keepdims: bool = False,
          **kwargs) -> Array:
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
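
A short sketch of what the new helper does for callers, again assuming the array_api_compat install: axis=() performs no reduction, but integer inputs narrower than 64 bits are still upcast to int64, mirroring torch's axis=None behavior:

import torch
from array_api_compat import torch as xp

x = torch.tensor([1, 2, 3], dtype=torch.int16)
out = xp.prod(x, axis=())                # no axes reduced: same shape as x
assert out.shape == x.shape
assert out.dtype == torch.int64          # upcast path taken when dtype=None
assert xp.prod(x, axis=(), dtype=torch.int16).dtype == torch.int16
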
@@ -330,7 +322,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -343,25 +335,14 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
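
The ndim local is dropped in favor of reading x.ndim at the call site. _axis_none_keepdims itself is not shown in this diff; a plausible sketch of such a helper (hypothetical, reconstructed only from how it is called here) re-inserts one size-1 dimension per original axis after a full reduction:

import torch

def _axis_none_keepdims(res: torch.Tensor, ndim: int, keepdims: bool) -> torch.Tensor:
    # After a full (axis=None) reduction `res` is 0-d, so with keepdims=True
    # we restore `ndim` singleton dimensions, e.g. shape () -> (1, 1, 1).
    if keepdims:
        for _ in range(ndim):
            res = torch.unsqueeze(res, 0)
    return res
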
@@ -372,7 +353,7 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.any doesn't support multiple axes
@@ -384,7 +365,7 @@ def any(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.any(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.any doesn't return bool for uint8
@@ -396,7 +377,7 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.all doesn't support multiple axes
@@ -408,7 +389,7 @@ def all(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.all(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.all doesn't return bool for uint8
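
For any and all, axis=() likewise reduces over no axes, which per the array API spec is just an elementwise cast to bool; a small usage sketch, under the same array_api_compat assumption:

import torch
from array_api_compat import torch as xp

x = torch.tensor([0, 2], dtype=torch.uint8)
assert xp.any(x, axis=()).dtype == torch.bool    # not uint8, unlike torch.any
assert xp.any(x, axis=()).tolist() == [False, True]
assert xp.all(x, axis=()).tolist() == [False, True]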