Skip to content

Commit 621494b

Browse files
authored
ENH: torch.asarray device propagation (#299)
1 parent b6900df commit 621494b

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

array_api_compat/torch/_aliases.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
from functools import reduce as _reduce, wraps as _wraps
44
from builtins import all as _builtin_all, any as _builtin_any
5-
from typing import List, Optional, Sequence, Tuple, Union
5+
from typing import Any, List, Optional, Sequence, Tuple, Union
66

77
import torch
88

99
from .._internal import get_xp
1010
from ..common import _aliases
11+
from ..common._typing import NestedSequence, SupportsBufferProtocol
1112
from ._info import __array_namespace_info__
1213
from ._typing import Array, Device, DType
1314

@@ -207,6 +208,28 @@ def can_cast(from_: Union[DType, Array], to: DType, /) -> bool:
207208
remainder = _two_arg(torch.remainder)
208209
subtract = _two_arg(torch.subtract)
209210

211+
212+
def asarray(
213+
obj: (
214+
Array
215+
| bool | int | float | complex
216+
| NestedSequence[bool | int | float | complex]
217+
| SupportsBufferProtocol
218+
),
219+
/,
220+
*,
221+
dtype: DType | None = None,
222+
device: Device | None = None,
223+
copy: bool | None = None,
224+
**kwargs: Any,
225+
) -> Array:
226+
# torch.asarray does not respect input->output device propagation
227+
# https://github.com/pytorch/pytorch/issues/150199
228+
if device is None and isinstance(obj, torch.Tensor):
229+
device = obj.device
230+
return torch.asarray(obj, dtype=dtype, device=device, copy=copy, **kwargs)
231+
232+
210233
# These wrappers are mostly based on the fact that pytorch uses 'dim' instead
211234
# of 'axis'.
212235

@@ -282,7 +305,6 @@ def prod(x: Array,
282305
dtype: Optional[DType] = None,
283306
keepdims: bool = False,
284307
**kwargs) -> Array:
285-
x = torch.asarray(x)
286308
ndim = x.ndim
287309

288310
# https://github.com/pytorch/pytorch/issues/29137. Separate from the logic
@@ -318,7 +340,6 @@ def sum(x: Array,
318340
dtype: Optional[DType] = None,
319341
keepdims: bool = False,
320342
**kwargs) -> Array:
321-
x = torch.asarray(x)
322343
ndim = x.ndim
323344

324345
# https://github.com/pytorch/pytorch/issues/29137.
@@ -348,7 +369,6 @@ def any(x: Array,
348369
axis: Optional[Union[int, Tuple[int, ...]]] = None,
349370
keepdims: bool = False,
350371
**kwargs) -> Array:
351-
x = torch.asarray(x)
352372
ndim = x.ndim
353373
if axis == ():
354374
return x.to(torch.bool)
@@ -373,7 +393,6 @@ def all(x: Array,
373393
axis: Optional[Union[int, Tuple[int, ...]]] = None,
374394
keepdims: bool = False,
375395
**kwargs) -> Array:
376-
x = torch.asarray(x)
377396
ndim = x.ndim
378397
if axis == ():
379398
return x.to(torch.bool)
@@ -816,7 +835,7 @@ def sign(x: Array, /) -> Array:
816835
return out
817836

818837

819-
__all__ = ['__array_namespace_info__', 'result_type', 'can_cast',
838+
__all__ = ['__array_namespace_info__', 'asarray', 'result_type', 'can_cast',
820839
'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add',
821840
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
822841
'bitwise_right_shift', 'bitwise_xor', 'copysign', 'count_nonzero',

array_api_compat/torch/_typing.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
__all__ = ["Array", "DType", "Device"]
1+
__all__ = ["Array", "Device", "DType"]
22

3-
from torch import dtype as DType, Tensor as Array
4-
from ..common._typing import Device
3+
from torch import device as Device, dtype as DType, Tensor as Array

0 commit comments

Comments
 (0)