forked from data-apis/array-api-extra
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_backends.py
46 lines (36 loc) · 1.43 KB
/
_backends.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Backends against which array-api-extra runs its tests."""
from __future__ import annotations

from enum import Enum, unique
# Public API of this module: only the Backend enum is exported.
__all__ = ["Backend"]
@unique
class Backend(Enum):  # numpydoc ignore=PR02
    """
    All array library backends explicitly tested by array-api-extra.

    Parameters
    ----------
    value : str
        Tag of the backend's module, in the format ``<namespace>[:<extra tag>]``.
    """

    # Use :<tag> to prevent Enum from deduplicating items with the same value.
    # ``@unique`` enforces this: a duplicated value raises ValueError when the
    # class is created, instead of silently becoming an alias that would drop
    # a backend from iteration (and thus from test parametrization).
    ARRAY_API_STRICT = "array_api_strict"
    ARRAY_API_STRICTEST = "array_api_strict:strictest"
    NUMPY = "numpy"
    NUMPY_READONLY = "numpy:readonly"
    CUPY = "cupy"
    TORCH = "torch"
    TORCH_GPU = "torch:gpu"
    DASK = "dask.array"
    SPARSE = "sparse"
    JAX = "jax.numpy"
    JAX_GPU = "jax.numpy:gpu"

    def __str__(self) -> str:  # type: ignore[explicit-override]  # pyright: ignore[reportImplicitOverride]  # numpydoc ignore=RT01
        """
        Pretty-print parameterized test names.

        Derived from the member *name* (not its value), lower-cased, with the
        ``_gpu`` / ``_readonly`` suffixes rendered as ``:gpu`` / ``:readonly``.
        For example ``TORCH_GPU`` -> ``"torch:gpu"`` and ``DASK`` -> ``"dask"``.
        """
        return (
            self.name.lower().replace("_gpu", ":gpu").replace("_readonly", ":readonly")
        )

    @property
    def modname(self) -> str:  # numpydoc ignore=RT01
        """
        Module name to be imported.

        This is the part of the value before the first ``:``, so tagged variants
        share the module of their base backend (``"jax.numpy:gpu"`` -> ``"jax.numpy"``).
        """
        return self.value.split(":")[0]

    def like(self, *others: Backend) -> bool:  # numpydoc ignore=PR01,RT01
        """
        Check if this backend uses the same module as others.

        True if any of ``others`` has the same :attr:`modname` as this backend;
        False when called with no arguments.
        """
        return any(self.modname == other.modname for other in others)