Skip to content

Commit 2fc1511

Browse files
committed
MAINT: don't use Backend for delegation
1 parent c1cd43e commit 2fc1511

File tree

9 files changed

+46
-75
lines changed

9 files changed

+46
-75
lines changed

src/array_api_extra/_delegation.py

+26-27
Original file line numberDiff line numberDiff line change
@@ -4,33 +4,22 @@
44
from types import ModuleType
55
from typing import Literal
66

7-
from ._lib import Backend, _funcs
8-
from ._lib._utils._compat import array_namespace
7+
from ._lib import _funcs
8+
from ._lib._utils._compat import (
9+
array_namespace,
10+
is_cupy_namespace,
11+
is_dask_namespace,
12+
is_jax_namespace,
13+
is_numpy_namespace,
14+
is_pydata_sparse_namespace,
15+
is_torch_namespace,
16+
)
917
from ._lib._utils._helpers import asarrays
1018
from ._lib._utils._typing import Array
1119

1220
__all__ = ["isclose", "pad"]
1321

1422

15-
def _delegate(xp: ModuleType, *backends: Backend) -> bool:
16-
"""
17-
Check whether `xp` is one of the `backends` to delegate to.
18-
19-
Parameters
20-
----------
21-
xp : array_namespace
22-
Array namespace to check.
23-
*backends : IsNamespace
24-
Arbitrarily many backends (from the ``IsNamespace`` enum) to check.
25-
26-
Returns
27-
-------
28-
bool
29-
``True`` if `xp` matches one of the `backends`, ``False`` otherwise.
30-
"""
31-
return any(backend.is_namespace(xp) for backend in backends)
32-
33-
3423
def isclose(
3524
a: Array | complex,
3625
b: Array | complex,
@@ -108,10 +97,15 @@ def isclose(
10897
"""
10998
xp = array_namespace(a, b) if xp is None else xp
11099

111-
if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX):
100+
if (
101+
is_numpy_namespace(xp)
102+
or is_cupy_namespace(xp)
103+
or is_dask_namespace(xp)
104+
or is_jax_namespace(xp)
105+
):
112106
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
113107

114-
if _delegate(xp, Backend.TORCH):
108+
if is_torch_namespace(xp):
115109
a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
116110
return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
117111

@@ -159,14 +153,19 @@ def pad(
159153
msg = "Only `'constant'` mode is currently supported"
160154
raise NotImplementedError(msg)
161155

156+
if (
157+
is_numpy_namespace(xp)
158+
or is_cupy_namespace(xp)
159+
or is_jax_namespace(xp)
160+
or is_pydata_sparse_namespace(xp)
161+
):
162+
return xp.pad(x, pad_width, mode, constant_values=constant_values)
163+
162164
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
163-
if _delegate(xp, Backend.TORCH):
165+
if is_torch_namespace(xp):
164166
pad_width = xp.asarray(pad_width)
165167
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
166168
pad_width = xp.flip(pad_width, axis=(0,)).flatten()
167169
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
168170

169-
if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE):
170-
return xp.pad(x, pad_width, mode, constant_values=constant_values)
171-
172171
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)

src/array_api_extra/_lib/__init__.py

-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1 @@
11
"""Internals of array-api-extra."""
2-
3-
from ._backends import Backend
4-
5-
__all__ = ["Backend"]

src/array_api_extra/_lib/_backends.py

+14-38
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,34 @@
1-
"""Backends with which array-api-extra interacts in delegation and testing."""
1+
"""Backends with which array-api-extra runs its tests against."""
22

33
from __future__ import annotations
44

5-
from collections.abc import Callable
65
from enum import Enum
7-
from types import ModuleType
8-
9-
from ._utils import _compat
106

117
__all__ = ["Backend"]
128

139

14-
class Backend(Enum): # numpydoc ignore=PR01,PR02 # type: ignore[no-subclass-any]
10+
class Backend(Enum): # numpydoc ignore=PR02
1511
"""
1612
All array library backends explicitly tested by array-api-extra.
1713
1814
Parameters
1915
----------
2016
value : str
21-
Name of the backend's module.
22-
is_namespace : Callable[[ModuleType], bool]
23-
Function to check whether an input module is the array namespace
24-
corresponding to the backend.
17+
Tag of the backend's module, in the format ``<namespace>[:<extra tag>]``.
2518
"""
2619

2720
# Use :<tag> to prevent Enum from deduplicating items with the same value
28-
ARRAY_API_STRICT = "array_api_strict", _compat.is_array_api_strict_namespace
29-
ARRAY_API_STRICTEST = (
30-
"array_api_strict:strictest",
31-
_compat.is_array_api_strict_namespace,
32-
)
33-
NUMPY = "numpy", _compat.is_numpy_namespace
34-
NUMPY_READONLY = "numpy:readonly", _compat.is_numpy_namespace
35-
CUPY = "cupy", _compat.is_cupy_namespace
36-
TORCH = "torch", _compat.is_torch_namespace
37-
TORCH_GPU = "torch:gpu", _compat.is_torch_namespace
38-
DASK = "dask.array", _compat.is_dask_namespace
39-
SPARSE = "sparse", _compat.is_pydata_sparse_namespace
40-
JAX = "jax.numpy", _compat.is_jax_namespace
41-
JAX_GPU = "jax.numpy:gpu", _compat.is_jax_namespace
42-
43-
def __new__(
44-
cls, value: str, _is_namespace: Callable[[ModuleType], bool]
45-
): # numpydoc ignore=GL08
46-
obj = object.__new__(cls)
47-
obj._value_ = value
48-
return obj
49-
50-
def __init__(
51-
self,
52-
value: str, # noqa: ARG002 # pylint: disable=unused-argument
53-
is_namespace: Callable[[ModuleType], bool],
54-
): # numpydoc ignore=GL08
55-
self.is_namespace = is_namespace
21+
ARRAY_API_STRICT = "array_api_strict"
22+
ARRAY_API_STRICTEST = "array_api_strict:strictest"
23+
NUMPY = "numpy"
24+
NUMPY_READONLY = "numpy:readonly"
25+
CUPY = "cupy"
26+
TORCH = "torch"
27+
TORCH_GPU = "torch:gpu"
28+
DASK = "dask.array"
29+
SPARSE = "sparse"
30+
JAX = "jax.numpy"
31+
JAX_GPU = "jax.numpy:gpu"
5632

5733
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01
5834
"""Pretty-print parameterized test names."""

tests/conftest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import numpy as np
1010
import pytest
1111

12-
from array_api_extra._lib import Backend
12+
from array_api_extra._lib._backends import Backend
1313
from array_api_extra._lib._testing import xfail
1414
from array_api_extra._lib._utils._compat import array_namespace
1515
from array_api_extra._lib._utils._compat import device as get_device

tests/test_at.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import pytest
1010

1111
from array_api_extra import at
12-
from array_api_extra._lib import Backend
1312
from array_api_extra._lib._at import _AtOp
13+
from array_api_extra._lib._backends import Backend
1414
from array_api_extra._lib._testing import xp_assert_equal
1515
from array_api_extra._lib._utils._compat import array_namespace, is_writeable_array
1616
from array_api_extra._lib._utils._compat import device as get_device

tests/test_funcs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
setdiff1d,
2626
sinc,
2727
)
28-
from array_api_extra._lib import Backend
28+
from array_api_extra._lib._backends import Backend
2929
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
3030
from array_api_extra._lib._utils._compat import device as get_device
3131
from array_api_extra._lib._utils._helpers import eager_shape, ndindex

tests/test_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66

7-
from array_api_extra._lib import Backend
7+
from array_api_extra._lib._backends import Backend
88
from array_api_extra._lib._testing import xp_assert_equal
99
from array_api_extra._lib._utils._compat import array_namespace
1010
from array_api_extra._lib._utils._compat import device as get_device

tests/test_lazy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import array_api_extra as xpx # Let some tests bypass lazy_xp_function
99
from array_api_extra import lazy_apply
10-
from array_api_extra._lib import Backend
10+
from array_api_extra._lib._backends import Backend
1111
from array_api_extra._lib._testing import xp_assert_equal
1212
from array_api_extra._lib._utils import _compat
1313
from array_api_extra._lib._utils._compat import array_namespace, is_dask_array

tests/test_testing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
import pytest
77

8-
from array_api_extra._lib import Backend
8+
from array_api_extra._lib._backends import Backend
99
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
1010
from array_api_extra._lib._utils._compat import (
1111
array_namespace,

0 commit comments

Comments
 (0)