Skip to content

[DNM] ENH: CuPy multi-device support #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import inspect
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Sequence, cast

from ._helpers import _check_device, array_namespace
from ._helpers import _device_ctx, array_namespace
from ._helpers import device as _get_device
from ._helpers import is_cupy_namespace as _is_cupy_namespace
from ._typing import Array, Device, DType, Namespace
Expand All @@ -32,8 +32,8 @@ def arange(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)


def empty(
Expand All @@ -44,8 +44,8 @@ def empty(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.empty(shape, dtype=dtype, **kwargs)


def empty_like(
Expand All @@ -57,8 +57,8 @@ def empty_like(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my education: why can the user pass a device= argument to a *_like function? Naively I'd have expected that the _like implies that the device matches that of x. But then you can also pass a dtype= which overrides the dtype of x, so by symmetry allowing a device= makes sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matching device of the input is what the user wants 99% of the time.
Using empty_like etc. on a different device can make some sense when preparing an output vessel that is filled from different sources. TBH, though, the main difference between empty and empty_like, besides convenience, is that the latter can easily duplicate a lazy (unknown) shape. Which frequently prevents masked updates, but that's a separate problem.

return xp.empty_like(x, dtype=dtype, **kwargs)


def eye(
Expand All @@ -72,8 +72,8 @@ def eye(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)


def full(
Expand All @@ -85,8 +85,8 @@ def full(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.full(shape, fill_value, dtype=dtype, **kwargs)


def full_like(
Expand All @@ -99,8 +99,8 @@ def full_like(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)


def linspace(
Expand All @@ -115,8 +115,8 @@ def linspace(
endpoint: bool = True,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
with _device_ctx(xp, device):
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)


def ones(
Expand All @@ -127,8 +127,8 @@ def ones(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.ones(shape, dtype=dtype, **kwargs)


def ones_like(
Expand All @@ -140,8 +140,8 @@ def ones_like(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
return xp.ones_like(x, dtype=dtype, **kwargs)


def zeros(
Expand All @@ -152,8 +152,8 @@ def zeros(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.zeros(shape, dtype=dtype, **kwargs)


def zeros_like(
Expand All @@ -165,8 +165,8 @@ def zeros_like(
device: Device | None = None,
**kwargs: object,
) -> Array:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
return xp.zeros_like(x, dtype=dtype, **kwargs)


# np.unique() is split into four functions in the array API:
Expand Down
47 changes: 32 additions & 15 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import contextlib
import inspect
import math
import sys
Expand Down Expand Up @@ -657,26 +658,42 @@
get_namespace = array_namespace


def _check_device(bare_xp: Namespace, device: Device) -> None: # pyright: ignore[reportUnusedFunction]
"""
Validate dummy device on device-less array backends.
def _device_ctx(
bare_xp: Namespace, device: Device, like: Array | None = None
) -> Generator[None]:

Check failure on line 663 in array_api_compat/common/_helpers.py

View workflow job for this annotation

GitHub Actions / check-ruff

Ruff (F821)

array_api_compat/common/_helpers.py:663:6: F821 Undefined name `Generator`
"""Context manager which changes the current device in CuPy.

Notes
-----
This function is also invoked by CuPy, which does have multiple devices
if there are multiple GPUs available.
However, CuPy multi-device support is currently impossible
without using the global device or a context manager:

https://github.com/data-apis/array-api-compat/pull/293
Used internally by array creation functions in common._aliases.
"""
if bare_xp is sys.modules.get("numpy"):
if device not in ("cpu", None):
if device is None:
if like is None:
return contextlib.nullcontext()
device = _device(like)

if bare_xp is sys.modules.get('numpy'):
if device != "cpu":
raise ValueError(f"Unsupported device for NumPy: {device!r}")
return contextlib.nullcontext()

elif bare_xp is sys.modules.get("dask.array"):
if device not in ("cpu", _DASK_DEVICE, None):
if bare_xp is sys.modules.get('dask.array'):
if device not in ("cpu", _DASK_DEVICE):
raise ValueError(f"Unsupported device for Dask: {device!r}")
return contextlib.nullcontext()

if bare_xp is sys.modules.get('cupy'):
if not isinstance(device, bare_xp.cuda.Device):
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
return device

# PyTorch doesn't have a "current device" context manager and you
# can't use array creation functions from common._aliases.
raise AssertionError("unreachable") # pragma: nocover


def _check_device(bare_xp: Namespace, device: Device) -> None:
"""Validate dummy device on device-less array backends."""
with _device_ctx(bare_xp, device):
pass


# Placeholder object to represent the dask device
Expand Down
3 changes: 2 additions & 1 deletion array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
with cp.cuda.Device(device):
like = obj if isinstance(obj, cp.ndarray) else None
with _helpers._device_ctx(cp, device, like=like):
if copy is None:
return cp.asarray(obj, dtype=dtype, **kwargs)
else:
Expand Down
Loading