Skip to content

Commit 52e01be

Browse files
authored
ENH: cache helper functions (#308)
* ENH: cache helper functions
1 parent 5e14b53 commit 52e01be

File tree

1 file changed

+108
-84
lines changed

1 file changed

+108
-84
lines changed

array_api_compat/common/_helpers.py

+108-84
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import math
1313
import sys
1414
import warnings
15-
from collections.abc import Collection
15+
from collections.abc import Collection, Hashable
16+
from functools import lru_cache
1617
from typing import (
1718
TYPE_CHECKING,
1819
Any,
@@ -61,23 +62,37 @@
6162
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"})
6263

6364

65+
@lru_cache(100)
66+
def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool:
67+
try:
68+
mod = sys.modules[modname]
69+
except KeyError:
70+
return False
71+
parent_cls = getattr(mod, clsname)
72+
return issubclass(cls, parent_cls)
73+
74+
6475
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]:
6576
"""Return True if `x` is a zero-gradient array.
6677
6778
These arrays are a design quirk of Jax that may one day be removed.
6879
See https://github.com/google/jax/issues/20620.
6980
"""
70-
if "numpy" not in sys.modules or "jax" not in sys.modules:
81+
# Fast exit
82+
try:
83+
dtype = x.dtype # type: ignore[attr-defined]
84+
except AttributeError:
85+
return False
86+
cls = cast(Hashable, type(dtype))
87+
if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"):
7188
return False
7289

73-
import jax
74-
import numpy as np
90+
if "jax" not in sys.modules:
91+
return False
7592

76-
jax_float0 = cast("np.dtype[np.void]", jax.float0)
77-
return (
78-
isinstance(x, np.ndarray)
79-
and cast("npt.NDArray[np.void]", x).dtype == jax_float0
80-
)
93+
import jax
94+
# jax.float0 is a np.dtype([('float0', 'V')])
95+
return dtype == jax.float0
8196

8297

8398
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
@@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
101116
is_jax_array
102117
is_pydata_sparse_array
103118
"""
104-
# Avoid importing NumPy if it isn't already
105-
if "numpy" not in sys.modules:
106-
return False
107-
108-
import numpy as np
109-
110119
# TODO: Should we reject ndarray subclasses?
111-
return (isinstance(x, (np.ndarray, np.generic))
112-
and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip
120+
cls = cast(Hashable, type(x))
121+
return (
122+
_issubclass_fast(cls, "numpy", "ndarray")
123+
or _issubclass_fast(cls, "numpy", "generic")
124+
) and not _is_jax_zero_gradient_array(x)
113125

114126

115127
def is_cupy_array(x: object) -> bool:
@@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool:
133145
is_jax_array
134146
is_pydata_sparse_array
135147
"""
136-
# Avoid importing CuPy if it isn't already
137-
if "cupy" not in sys.modules:
138-
return False
139-
140-
import cupy as cp # pyright: ignore[reportMissingTypeStubs]
141-
142-
# TODO: Should we reject ndarray subclasses?
143-
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType]
148+
cls = cast(Hashable, type(x))
149+
return _issubclass_fast(cls, "cupy", "ndarray")
144150

145151

146152
def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
@@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
161167
is_jax_array
162168
is_pydata_sparse_array
163169
"""
164-
# Avoid importing torch if it isn't already
165-
if "torch" not in sys.modules:
166-
return False
167-
168-
import torch
169-
170-
# TODO: Should we reject ndarray subclasses?
171-
return isinstance(x, torch.Tensor)
170+
cls = cast(Hashable, type(x))
171+
return _issubclass_fast(cls, "torch", "Tensor")
172172

173173

174174
def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
@@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
190190
is_jax_array
191191
is_pydata_sparse_array
192192
"""
193-
# Avoid importing torch if it isn't already
194-
if "ndonnx" not in sys.modules:
195-
return False
196-
197-
import ndonnx as ndx
198-
199-
return isinstance(x, ndx.Array)
193+
cls = cast(Hashable, type(x))
194+
return _issubclass_fast(cls, "ndonnx", "Array")
200195

201196

202197
def is_dask_array(x: object) -> TypeIs[da.Array]:
@@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
218213
is_jax_array
219214
is_pydata_sparse_array
220215
"""
221-
# Avoid importing dask if it isn't already
222-
if "dask.array" not in sys.modules:
223-
return False
224-
225-
import dask.array
226-
227-
return isinstance(x, dask.array.Array)
216+
cls = cast(Hashable, type(x))
217+
return _issubclass_fast(cls, "dask.array", "Array")
228218

229219

230220
def is_jax_array(x: object) -> TypeIs[jax.Array]:
@@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
247237
is_dask_array
248238
is_pydata_sparse_array
249239
"""
250-
# Avoid importing jax if it isn't already
251-
if "jax" not in sys.modules:
252-
return False
253-
254-
import jax
255-
256-
return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x)
240+
cls = cast(Hashable, type(x))
241+
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
257242

258243

259244
def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
@@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
276261
is_dask_array
277262
is_jax_array
278263
"""
279-
# Avoid importing jax if it isn't already
280-
if "sparse" not in sys.modules:
281-
return False
282-
283-
import sparse # pyright: ignore[reportMissingTypeStubs]
284-
285264
# TODO: Account for other backends.
286-
return isinstance(x, sparse.SparseArray)
265+
cls = cast(Hashable, type(x))
266+
return _issubclass_fast(cls, "sparse", "SparseArray")
287267

288268

289269
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType]
@@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
302282
is_jax_array
303283
"""
304284
return (
305-
is_numpy_array(x)
306-
or is_cupy_array(x)
307-
or is_torch_array(x)
308-
or is_dask_array(x)
309-
or is_jax_array(x)
310-
or is_pydata_sparse_array(x)
311-
or hasattr(x, "__array_namespace__")
285+
hasattr(x, '__array_namespace__')
286+
or _is_array_api_cls(cast(Hashable, type(x)))
287+
)
288+
289+
290+
@lru_cache(100)
291+
def _is_array_api_cls(cls: type) -> bool:
292+
return (
293+
# TODO: drop support for numpy<2 which didn't have __array_namespace__
294+
_issubclass_fast(cls, "numpy", "ndarray")
295+
or _issubclass_fast(cls, "numpy", "generic")
296+
or _issubclass_fast(cls, "cupy", "ndarray")
297+
or _issubclass_fast(cls, "torch", "Tensor")
298+
or _issubclass_fast(cls, "dask.array", "Array")
299+
or _issubclass_fast(cls, "sparse", "SparseArray")
300+
# TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
301+
or _issubclass_fast(cls, "jax", "Array")
312302
)
313303

314304

@@ -317,6 +307,7 @@ def _compat_module_name() -> str:
317307
return __name__.removesuffix(".common._helpers")
318308

319309

310+
@lru_cache(100)
320311
def is_numpy_namespace(xp: Namespace) -> bool:
321312
"""
322313
Returns True if `xp` is a NumPy namespace.
@@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
338329
return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"}
339330

340331

332+
@lru_cache(100)
341333
def is_cupy_namespace(xp: Namespace) -> bool:
342334
"""
343335
Returns True if `xp` is a CuPy namespace.
@@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
359351
return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"}
360352

361353

354+
@lru_cache(100)
362355
def is_torch_namespace(xp: Namespace) -> bool:
363356
"""
364357
Returns True if `xp` is a PyTorch namespace.
@@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
399392
return xp.__name__ == "ndonnx"
400393

401394

395+
@lru_cache(100)
402396
def is_dask_namespace(xp: Namespace) -> bool:
403397
"""
404398
Returns True if `xp` is a Dask namespace.
@@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
939933
return None if math.isnan(out) else out
940934

941935

936+
@lru_cache(100)
937+
def _is_writeable_cls(cls: type) -> bool | None:
938+
if (
939+
_issubclass_fast(cls, "numpy", "generic")
940+
or _issubclass_fast(cls, "jax", "Array")
941+
or _issubclass_fast(cls, "sparse", "SparseArray")
942+
):
943+
return False
944+
if _is_array_api_cls(cls):
945+
return True
946+
return None
947+
948+
942949
def is_writeable_array(x: object) -> bool:
943950
"""
944951
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool:
949956
As there is no standard way to check if an array is writeable without actually
950957
writing to it, this function blindly returns True for all unknown array types.
951958
"""
952-
if is_numpy_array(x):
953-
return x.flags.writeable
954-
if is_jax_array(x) or is_pydata_sparse_array(x):
959+
cls = cast(Hashable, type(x))
960+
if _issubclass_fast(cls, "numpy", "ndarray"):
961+
return cast("npt.NDArray", x).flags.writeable
962+
res = _is_writeable_cls(cls)
963+
if res is not None:
964+
return res
965+
return hasattr(x, '__array_namespace__')
966+
967+
968+
@lru_cache(100)
969+
def _is_lazy_cls(cls: type) -> bool | None:
970+
if (
971+
_issubclass_fast(cls, "numpy", "ndarray")
972+
or _issubclass_fast(cls, "numpy", "generic")
973+
or _issubclass_fast(cls, "cupy", "ndarray")
974+
or _issubclass_fast(cls, "torch", "Tensor")
975+
or _issubclass_fast(cls, "sparse", "SparseArray")
976+
):
955977
return False
956-
return is_array_api_obj(x)
978+
if (
979+
_issubclass_fast(cls, "jax", "Array")
980+
or _issubclass_fast(cls, "dask.array", "Array")
981+
or _issubclass_fast(cls, "ndonnx", "Array")
982+
):
983+
return True
984+
return None
957985

958986

959987
def is_lazy_array(x: object) -> bool:
@@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool:
969997
This function errs on the side of caution for array types that may or may not be
970998
lazy, e.g. JAX arrays, by always returning True for them.
971999
"""
972-
if (
973-
is_numpy_array(x)
974-
or is_cupy_array(x)
975-
or is_torch_array(x)
976-
or is_pydata_sparse_array(x)
977-
):
978-
return False
979-
9801000
# **JAX note:** while it is possible to determine if you're inside or outside
9811001
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
9821002
# as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool:
9861006
# compatibility, is highly detrimental to performance as the whole graph will end
9871007
# up being computed multiple times.
9881008

989-
if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x):
990-
return True
1009+
# Note: skipping reclassification of JAX zero gradient arrays, as one will
1010+
# exclusively get them once they leave a jax.grad JIT context.
1011+
cls = cast(Hashable, type(x))
1012+
res = _is_lazy_cls(cls)
1013+
if res is not None:
1014+
return res
9911015

992-
if not is_array_api_obj(x):
1016+
if not hasattr(x, "__array_namespace__"):
9931017
return False
9941018

9951019
# Unknown Array API compatible object. Note that this test may have dire consequences
@@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool:
10421066
"to_device",
10431067
]
10441068

1045-
_all_ignore = ["sys", "math", "inspect", "warnings"]
1069+
_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings']
10461070

10471071
def __dir__() -> list[str]:
10481072
return __all__

0 commit comments

Comments
 (0)