diff --git a/src/array_api_extra/_delegation.py b/src/array_api_extra/_delegation.py index bb11b7ee..b52c23ae 100644 --- a/src/array_api_extra/_delegation.py +++ b/src/array_api_extra/_delegation.py @@ -4,33 +4,22 @@ from types import ModuleType from typing import Literal -from ._lib import Backend, _funcs -from ._lib._utils._compat import array_namespace +from ._lib import _funcs +from ._lib._utils._compat import ( + array_namespace, + is_cupy_namespace, + is_dask_namespace, + is_jax_namespace, + is_numpy_namespace, + is_pydata_sparse_namespace, + is_torch_namespace, +) from ._lib._utils._helpers import asarrays from ._lib._utils._typing import Array __all__ = ["isclose", "pad"] -def _delegate(xp: ModuleType, *backends: Backend) -> bool: - """ - Check whether `xp` is one of the `backends` to delegate to. - - Parameters - ---------- - xp : array_namespace - Array namespace to check. - *backends : IsNamespace - Arbitrarily many backends (from the ``IsNamespace`` enum) to check. - - Returns - ------- - bool - ``True`` if `xp` matches one of the `backends`, ``False`` otherwise. - """ - return any(backend.is_namespace(xp) for backend in backends) - - def isclose( a: Array | complex, b: Array | complex, @@ -108,10 +97,15 @@ def isclose( """ xp = array_namespace(a, b) if xp is None else xp - if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX): + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_dask_namespace(xp) + or is_jax_namespace(xp) + ): return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) - if _delegate(xp, Backend.TORCH): + if is_torch_namespace(xp): a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) @@ -159,14 +153,19 @@ def pad( msg = "Only `'constant'` mode is currently supported" raise NotImplementedError(msg) + if ( + is_numpy_namespace(xp) + or is_cupy_namespace(xp) + or is_jax_namespace(xp) + or is_pydata_sparse_namespace(xp) + ): + return xp.pad(x, pad_width, mode, constant_values=constant_values) + # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 - if _delegate(xp, Backend.TORCH): + if is_torch_namespace(xp): pad_width = xp.asarray(pad_width) pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) pad_width = xp.flip(pad_width, axis=(0,)).flatten() return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE): - return xp.pad(x, pad_width, mode, constant_values=constant_values) - return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) diff --git a/src/array_api_extra/_lib/__init__.py b/src/array_api_extra/_lib/__init__.py index b83d7e8c..d7b32033 100644 --- a/src/array_api_extra/_lib/__init__.py +++ b/src/array_api_extra/_lib/__init__.py @@ -1,5 +1 @@ """Internals of array-api-extra.""" - -from ._backends import Backend - -__all__ = ["Backend"] diff --git a/src/array_api_extra/_lib/_backends.py b/src/array_api_extra/_lib/_backends.py index e30afd55..f64e1479 100644 --- a/src/array_api_extra/_lib/_backends.py +++ b/src/array_api_extra/_lib/_backends.py @@ -1,58 +1,34 @@ -"""Backends with which array-api-extra interacts in delegation and testing.""" +"""Backends against which array-api-extra runs its tests.""" from __future__ import annotations -from collections.abc import Callable from enum import Enum -from types import ModuleType - -from ._utils import _compat __all__ = ["Backend"] -class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-any] +class Backend(Enum): # numpydoc ignore=PR02 """ All array library backends explicitly tested by array-api-extra. Parameters ---------- value : str - Name of the backend's module. - is_namespace : Callable[[ModuleType], bool] - Function to check whether an input module is the array namespace - corresponding to the backend. + Tag of the backend's module, in the format ``[:]``. """ # Use : to prevent Enum from deduplicating items with the same value - ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace - ARRAY_API_STRICTEST = ( - "array_api_strict:strictest", - _compat.is_array_api_strict_namespace, - ) - NUMPY = "numpy", _compat.is_numpy_namespace - NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace - CUPY = "cupy", _compat.is_cupy_namespace - TORCH = "torch", _compat.is_torch_namespace - TORCH_GPU = "torch:gpu", _compat.is_torch_namespace - DASK = "dask.array", _compat.is_dask_namespace - SPARSE = "sparse", _compat.is_pydata_sparse_namespace - JAX = "jax.numpy", _compat.is_jax_namespace - JAX_GPU = "jax.numpy:gpu", _compat.is_jax_namespace - - def __new__( - cls, value: str, _is_namespace: Callable[[ModuleType], bool] - ): # numpydoc ignore=GL08 - obj = object.__new__(cls) - obj._value_ = value - return obj - - def __init__( - self, - value: str, # noqa: ARG002 # pylint: disable=unused-argument - is_namespace: Callable[[ModuleType], bool], - ): # numpydoc ignore=GL08 - self.is_namespace = is_namespace + ARRAY_API_STRICT = "array_api_strict" + ARRAY_API_STRICTEST = "array_api_strict:strictest" + NUMPY = "numpy" + NUMPY_READONLY = "numpy:readonly" + CUPY = "cupy" + TORCH = "torch" + TORCH_GPU = "torch:gpu" + DASK = "dask.array" + SPARSE = "sparse" + JAX = "jax.numpy" + JAX_GPU = "jax.numpy:gpu" def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01 """Pretty-print parameterized test names.""" diff --git a/tests/conftest.py b/tests/conftest.py index 70854249..410a87ff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,7 +9,7 @@ import numpy as np import pytest -from array_api_extra._lib import Backend +from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xfail from array_api_extra._lib._utils._compat import array_namespace from array_api_extra._lib._utils._compat import device as get_device diff --git a/tests/test_at.py b/tests/test_at.py index 926685cb..4ccf584e 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -9,8 +9,8 @@ import pytest from array_api_extra import at -from array_api_extra._lib import Backend from array_api_extra._lib._at import _AtOp +from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array from array_api_extra._lib._utils._compat import device as get_device diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 553df5dc..4e40f09b 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -25,7 +25,7 @@ setdiff1d, sinc, ) -from array_api_extra._lib import Backend +from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import device as get_device from array_api_extra._lib._utils._helpers import eager_shape, ndindex diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c7d271ca..ebd4811f 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -4,7 +4,7 @@ import numpy as np import pytest -from array_api_extra._lib import Backend +from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils._compat import array_namespace from array_api_extra._lib._utils._compat import device as get_device diff --git a/tests/test_lazy.py b/tests/test_lazy.py index 8690c33e..f40df277 100644 --- a/tests/test_lazy.py +++ b/tests/test_lazy.py @@ -7,7 +7,7 @@ import array_api_extra as xpx # Let some tests bypass lazy_xp_function from array_api_extra import lazy_apply -from array_api_extra._lib import Backend +from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xp_assert_equal from array_api_extra._lib._utils import _compat from array_api_extra._lib._utils._compat import array_namespace, is_dask_array diff --git a/tests/test_testing.py b/tests/test_testing.py index 9976e6fd..ff67121b 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from array_api_extra._lib import Backend +from array_api_extra._lib._backends import Backend from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal from array_api_extra._lib._utils._compat import ( array_namespace,