|
4 | 4 | from types import ModuleType
|
5 | 5 | from typing import Literal
|
6 | 6 |
|
7 |
| -from ._lib import Backend, _funcs |
8 |
| -from ._lib._utils._compat import array_namespace |
| 7 | +from ._lib import _funcs |
| 8 | +from ._lib._utils._compat import ( |
| 9 | + array_namespace, |
| 10 | + is_cupy_namespace, |
| 11 | + is_dask_namespace, |
| 12 | + is_jax_namespace, |
| 13 | + is_numpy_namespace, |
| 14 | + is_pydata_sparse_namespace, |
| 15 | + is_torch_namespace, |
| 16 | +) |
9 | 17 | from ._lib._utils._helpers import asarrays
|
10 | 18 | from ._lib._utils._typing import Array
|
11 | 19 |
|
12 | 20 | __all__ = ["isclose", "pad"]
|
13 | 21 |
|
14 | 22 |
|
15 |
| -def _delegate(xp: ModuleType, *backends: Backend) -> bool: |
16 |
| - """ |
17 |
| - Check whether `xp` is one of the `backends` to delegate to. |
18 |
| -
|
19 |
| - Parameters |
20 |
| - ---------- |
21 |
| - xp : array_namespace |
22 |
| - Array namespace to check. |
23 |
| - *backends : IsNamespace |
24 |
| - Arbitrarily many backends (from the ``IsNamespace`` enum) to check. |
25 |
| -
|
26 |
| - Returns |
27 |
| - ------- |
28 |
| - bool |
29 |
| - ``True`` if `xp` matches one of the `backends`, ``False`` otherwise. |
30 |
| - """ |
31 |
| - return any(backend.is_namespace(xp) for backend in backends) |
32 |
| - |
33 |
| - |
34 | 23 | def isclose(
|
35 | 24 | a: Array | complex,
|
36 | 25 | b: Array | complex,
|
@@ -108,10 +97,15 @@ def isclose(
|
108 | 97 | """
|
109 | 98 | xp = array_namespace(a, b) if xp is None else xp
|
110 | 99 |
|
111 |
| - if _delegate(xp, Backend.NUMPY, Backend.CUPY, Backend.DASK, Backend.JAX): |
| 100 | + if ( |
| 101 | + is_numpy_namespace(xp) |
| 102 | + or is_cupy_namespace(xp) |
| 103 | + or is_dask_namespace(xp) |
| 104 | + or is_jax_namespace(xp) |
| 105 | + ): |
112 | 106 | return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
113 | 107 |
|
114 |
| - if _delegate(xp, Backend.TORCH): |
| 108 | + if is_torch_namespace(xp): |
115 | 109 | a, b = asarrays(a, b, xp=xp) # Array API 2024.12 support
|
116 | 110 | return xp.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
117 | 111 |
|
@@ -159,14 +153,19 @@ def pad(
|
159 | 153 | msg = "Only `'constant'` mode is currently supported"
|
160 | 154 | raise NotImplementedError(msg)
|
161 | 155 |
|
| 156 | + if ( |
| 157 | + is_numpy_namespace(xp) |
| 158 | + or is_cupy_namespace(xp) |
| 159 | + or is_jax_namespace(xp) |
| 160 | + or is_pydata_sparse_namespace(xp) |
| 161 | + ): |
| 162 | + return xp.pad(x, pad_width, mode, constant_values=constant_values) |
| 163 | + |
162 | 164 | # https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056
|
163 |
| - if _delegate(xp, Backend.TORCH): |
| 165 | + if is_torch_namespace(xp): |
164 | 166 | pad_width = xp.asarray(pad_width)
|
165 | 167 | pad_width = xp.broadcast_to(pad_width, (x.ndim, 2))
|
166 | 168 | pad_width = xp.flip(pad_width, axis=(0,)).flatten()
|
167 | 169 | return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
|
168 | 170 |
|
169 |
| - if _delegate(xp, Backend.NUMPY, Backend.JAX, Backend.CUPY, Backend.SPARSE): |
170 |
| - return xp.pad(x, pad_width, mode, constant_values=constant_values) |
171 |
| - |
172 | 171 | return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
|
0 commit comments