Commit 9194c5c

MAINT: simplify torch dtype promotion (#303)

Reviewed at #303. 1 parent: d743dc1.
File tree: 1 file changed, +40 -59 lines

array_api_compat/torch/_aliases.py (+40 -59)
@@ -35,54 +35,33 @@
     torch.complex128,
 }
 
-_promotion_table = {
-    # bool
-    (torch.bool, torch.bool): torch.bool,
+_promotion_table = {
     # ints
-    (torch.int8, torch.int8): torch.int8,
     (torch.int8, torch.int16): torch.int16,
     (torch.int8, torch.int32): torch.int32,
     (torch.int8, torch.int64): torch.int64,
-    (torch.int16, torch.int8): torch.int16,
-    (torch.int16, torch.int16): torch.int16,
     (torch.int16, torch.int32): torch.int32,
     (torch.int16, torch.int64): torch.int64,
-    (torch.int32, torch.int8): torch.int32,
-    (torch.int32, torch.int16): torch.int32,
-    (torch.int32, torch.int32): torch.int32,
     (torch.int32, torch.int64): torch.int64,
-    (torch.int64, torch.int8): torch.int64,
-    (torch.int64, torch.int16): torch.int64,
-    (torch.int64, torch.int32): torch.int64,
-    (torch.int64, torch.int64): torch.int64,
-    # uints
-    (torch.uint8, torch.uint8): torch.uint8,
     # ints and uints (mixed sign)
-    (torch.int8, torch.uint8): torch.int16,
-    (torch.int16, torch.uint8): torch.int16,
-    (torch.int32, torch.uint8): torch.int32,
-    (torch.int64, torch.uint8): torch.int64,
     (torch.uint8, torch.int8): torch.int16,
     (torch.uint8, torch.int16): torch.int16,
     (torch.uint8, torch.int32): torch.int32,
     (torch.uint8, torch.int64): torch.int64,
     # floats
-    (torch.float32, torch.float32): torch.float32,
     (torch.float32, torch.float64): torch.float64,
-    (torch.float64, torch.float32): torch.float64,
-    (torch.float64, torch.float64): torch.float64,
     # complexes
-    (torch.complex64, torch.complex64): torch.complex64,
     (torch.complex64, torch.complex128): torch.complex128,
-    (torch.complex128, torch.complex64): torch.complex128,
-    (torch.complex128, torch.complex128): torch.complex128,
     # Mixed float and complex
     (torch.float32, torch.complex64): torch.complex64,
     (torch.float32, torch.complex128): torch.complex128,
     (torch.float64, torch.complex64): torch.complex128,
     (torch.float64, torch.complex128): torch.complex128,
 }
 
+_promotion_table.update({(b, a): c for (a, b), c in _promotion_table.items()})
+_promotion_table.update({(a, a): a for a in _array_api_dtypes})
+
 
 def _two_arg(f):
     @_wraps(f)
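
Note: the two update calls above are what make the deleted entries
redundant: promotion is symmetric, and every dtype promotes with itself
to itself. A minimal standalone sketch of the trick (the three-dtype
table is illustrative, not the full table from the file):

    import torch

    # Illustrative subset; the real table spells out one triangle by hand.
    table = {
        (torch.int8, torch.int16): torch.int16,
        (torch.int8, torch.int32): torch.int32,
        (torch.int16, torch.int32): torch.int32,
    }
    dtypes = {torch.int8, torch.int16, torch.int32}

    # Mirror every entry: result_type(a, b) == result_type(b, a).
    table.update({(b, a): c for (a, b), c in table.items()})
    # Add the diagonal: each dtype promotes with itself to itself.
    table.update({(a, a): a for a in dtypes})

    assert table[torch.int16, torch.int8] is torch.int16
    assert table[torch.int32, torch.int32] is torch.int32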
@@ -150,13 +129,18 @@ def result_type(
     return _reduce(_result_type, others + scalars)
 
 
-def _result_type(x, y):
+def _result_type(
+    x: Array | DType | bool | int | float | complex,
+    y: Array | DType | bool | int | float | complex,
+) -> DType:
     if not (isinstance(x, _py_scalars) or isinstance(y, _py_scalars)):
-        xdt = x.dtype if not isinstance(x, torch.dtype) else x
-        ydt = y.dtype if not isinstance(y, torch.dtype) else y
+        xdt = x if isinstance(x, torch.dtype) else x.dtype
+        ydt = y if isinstance(y, torch.dtype) else y.dtype
 
-        if (xdt, ydt) in _promotion_table:
+        try:
             return _promotion_table[xdt, ydt]
+        except KeyError:
+            pass
 
     # This doesn't result_type(dtype, dtype) for non-array API dtypes
     # because torch.result_type only accepts tensors. This does however, allow
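
Note: swapping the membership test for try/except does one dict lookup
instead of two, and a KeyError falls through to the torch.result_type
fallback below. A hypothetical usage sketch of the public result_type
built on this helper (assuming the package's documented torch namespace;
results follow the promotion table above):

    import torch
    from array_api_compat import torch as xp

    # Mixed signed/unsigned ints promote per the table: int8 + uint8 -> int16.
    assert xp.result_type(torch.int8, torch.uint8) == torch.int16

    # Arrays and dtypes can be mixed: float32 + complex64 -> complex64.
    x = torch.ones(3, dtype=torch.float32)
    assert xp.result_type(x, torch.complex64) == torch.complex64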
@@ -301,27 +285,35 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
         out = torch.unsqueeze(out, a)
     return out
 
+
+def _sum_prod_no_axis(x: Array, dtype: DType | None) -> Array:
+    """
+    Implements `sum(..., axis=())` and `prod(..., axis=())`.
+
+    Works around https://github.com/pytorch/pytorch/issues/29137
+    """
+    if dtype is not None:
+        return x.clone() if dtype == x.dtype else x.to(dtype)
+
+    # We can't upcast uint8 according to the spec because there is no
+    # torch.uint64, so at least upcast to int64 which is what prod does
+    # when axis=None.
+    if x.dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
+        return x.to(torch.int64)
+
+    return x.clone()
+
+
 def prod(x: Array,
          /,
          *,
          axis: Optional[Union[int, Tuple[int, ...]]] = None,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
-    # below because it still needs to upcast.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     # torch.prod doesn't support multiple axes
     # (https://github.com/pytorch/pytorch/issues/56586).
     if isinstance(axis, tuple):
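
Note: axis=() means "reduce over no axes", so the new helper returns a
same-shaped tensor that is only cloned, cast, or upcast. A small sketch
of the expected behaviour (assuming the package's torch namespace; the
assertions restate the helper's logic above):

    import torch
    from array_api_compat import torch as xp

    x = torch.ones((2, 3), dtype=torch.int16)

    out = xp.prod(x, axis=())
    assert out.shape == (2, 3)         # nothing was reduced
    assert out.dtype == torch.int64    # small ints upcast, as with axis=None

    out = xp.prod(x, axis=(), dtype=torch.float32)
    assert out.dtype == torch.float32  # explicit dtype wins; no upcast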
@@ -330,7 +322,7 @@ def prod(x: Array,
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.prod(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.prod(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
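
Note: _axis_none_keepdims itself is not shown in this diff; the change
only inlines x.ndim at the call sites instead of caching it in a local.
Judging from those call sites, the helper re-inserts the reduced-away
dimensions when keepdims=True is combined with axis=None, which torch
rejects natively (pytorch/pytorch#71209). A plausible sketch, not the
file's actual definition:

    import torch

    def _axis_none_keepdims(x, ndim, keepdims):
        # A full reduction (axis=None) yields a 0-d tensor; with
        # keepdims=True the spec wants shape (1,) * ndim instead.
        if keepdims:
            for _ in range(ndim):
                x = torch.unsqueeze(x, 0)
        return x

    res = _axis_none_keepdims(torch.tensor(6), ndim=2, keepdims=True)
    assert res.shape == (1, 1)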
@@ -343,25 +335,14 @@ def sum(x: Array,
         dtype: Optional[DType] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
 
-    # https://github.com/pytorch/pytorch/issues/29137.
-    # Make sure it upcasts.
     if axis == ():
-        if dtype is None:
-            # We can't upcast uint8 according to the spec because there is no
-            # torch.uint64, so at least upcast to int64 which is what sum does
-            # when axis=None.
-            if x.dtype in [torch.int8, torch.int16, torch.int32, torch.uint8]:
-                return x.to(torch.int64)
-            return x.clone()
-        return x.to(dtype)
-
+        return _sum_prod_no_axis(x, dtype)
     if axis is None:
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.sum(x, dtype=dtype, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res
 
     return torch.sum(x, axis, dtype=dtype, keepdims=keepdims, **kwargs)
@@ -372,7 +353,7 @@ def any(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
     if axis == ():
         return x.to(torch.bool)
     # torch.any doesn't support multiple axes
@@ -384,7 +365,7 @@
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.any(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.any doesn't return bool for uint8
@@ -396,7 +377,7 @@ def all(x: Array,
         axis: Optional[Union[int, Tuple[int, ...]]] = None,
         keepdims: bool = False,
         **kwargs) -> Array:
-    ndim = x.ndim
+
    if axis == ():
         return x.to(torch.bool)
     # torch.all doesn't support multiple axes
@@ -408,7 +389,7 @@
         # torch doesn't support keepdims with axis=None
         # (https://github.com/pytorch/pytorch/issues/71209)
         res = torch.all(x, **kwargs)
-        res = _axis_none_keepdims(res, ndim, keepdims)
+        res = _axis_none_keepdims(res, x.ndim, keepdims)
         return res.to(torch.bool)
 
     # torch.all doesn't return bool for uint8
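
Note: the trailing .to(torch.bool) in any and all exists because, per the
comments above, torch.any and torch.all keep uint8 rather than returning
bool for uint8 input. A quick illustration (assumed behaviour on the
torch versions this shim targets, not taken from the diff):

    import torch

    x = torch.tensor([0, 2], dtype=torch.uint8)
    print(torch.any(x).dtype)           # uint8, not bool, for uint8 input
    print(torch.any(x).to(torch.bool))  # tensor(True): what the wrapper returns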
