12
12
import math
13
13
import sys
14
14
import warnings
15
- from collections .abc import Collection
15
+ from collections .abc import Collection , Hashable
16
+ from functools import lru_cache
16
17
from typing import (
17
18
TYPE_CHECKING ,
18
19
Any ,
61
62
_API_VERSIONS : Final = _API_VERSIONS_OLD | frozenset ({"2024.12" })
62
63
63
64
65
+ @lru_cache (100 )
66
+ def _issubclass_fast (cls : type , modname : str , clsname : str ) -> bool :
67
+ try :
68
+ mod = sys .modules [modname ]
69
+ except KeyError :
70
+ return False
71
+ parent_cls = getattr (mod , clsname )
72
+ return issubclass (cls , parent_cls )
73
+
74
+
64
75
def _is_jax_zero_gradient_array (x : object ) -> TypeGuard [_ZeroGradientArray ]:
65
76
"""Return True if `x` is a zero-gradient array.
66
77
67
78
These arrays are a design quirk of Jax that may one day be removed.
68
79
See https://github.com/google/jax/issues/20620.
69
80
"""
70
- if "numpy" not in sys .modules or "jax" not in sys .modules :
81
+ # Fast exit
82
+ try :
83
+ dtype = x .dtype # type: ignore[attr-defined]
84
+ except AttributeError :
85
+ return False
86
+ cls = cast (Hashable , type (dtype ))
87
+ if not _issubclass_fast (cls , "numpy.dtypes" , "VoidDType" ):
71
88
return False
72
89
73
- import jax
74
- import numpy as np
90
+ if " jax" not in sys . modules :
91
+ return False
75
92
76
- jax_float0 = cast ("np.dtype[np.void]" , jax .float0 )
77
- return (
78
- isinstance (x , np .ndarray )
79
- and cast ("npt.NDArray[np.void]" , x ).dtype == jax_float0
80
- )
93
+ import jax
94
+ # jax.float0 is a np.dtype([('float0', 'V')])
95
+ return dtype == jax .float0
81
96
82
97
83
98
def is_numpy_array (x : object ) -> TypeGuard [npt .NDArray [Any ]]:
@@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]:
101
116
is_jax_array
102
117
is_pydata_sparse_array
103
118
"""
104
- # Avoid importing NumPy if it isn't already
105
- if "numpy" not in sys .modules :
106
- return False
107
-
108
- import numpy as np
109
-
110
119
# TODO: Should we reject ndarray subclasses?
111
- return (isinstance (x , (np .ndarray , np .generic ))
112
- and not _is_jax_zero_gradient_array (x )) # pyright: ignore[reportUnknownArgumentType] # fmt: skip
120
+ cls = cast (Hashable , type (x ))
121
+ return (
122
+ _issubclass_fast (cls , "numpy" , "ndarray" )
123
+ or _issubclass_fast (cls , "numpy" , "generic" )
124
+ ) and not _is_jax_zero_gradient_array (x )
113
125
114
126
115
127
def is_cupy_array (x : object ) -> bool :
@@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool:
133
145
is_jax_array
134
146
is_pydata_sparse_array
135
147
"""
136
- # Avoid importing CuPy if it isn't already
137
- if "cupy" not in sys .modules :
138
- return False
139
-
140
- import cupy as cp # pyright: ignore[reportMissingTypeStubs]
141
-
142
- # TODO: Should we reject ndarray subclasses?
143
- return isinstance (x , cp .ndarray ) # pyright: ignore[reportUnknownMemberType]
148
+ cls = cast (Hashable , type (x ))
149
+ return _issubclass_fast (cls , "cupy" , "ndarray" )
144
150
145
151
146
152
def is_torch_array (x : object ) -> TypeIs [torch .Tensor ]:
@@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]:
161
167
is_jax_array
162
168
is_pydata_sparse_array
163
169
"""
164
- # Avoid importing torch if it isn't already
165
- if "torch" not in sys .modules :
166
- return False
167
-
168
- import torch
169
-
170
- # TODO: Should we reject ndarray subclasses?
171
- return isinstance (x , torch .Tensor )
170
+ cls = cast (Hashable , type (x ))
171
+ return _issubclass_fast (cls , "torch" , "Tensor" )
172
172
173
173
174
174
def is_ndonnx_array (x : object ) -> TypeIs [ndx .Array ]:
@@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]:
190
190
is_jax_array
191
191
is_pydata_sparse_array
192
192
"""
193
- # Avoid importing torch if it isn't already
194
- if "ndonnx" not in sys .modules :
195
- return False
196
-
197
- import ndonnx as ndx
198
-
199
- return isinstance (x , ndx .Array )
193
+ cls = cast (Hashable , type (x ))
194
+ return _issubclass_fast (cls , "ndonnx" , "Array" )
200
195
201
196
202
197
def is_dask_array (x : object ) -> TypeIs [da .Array ]:
@@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]:
218
213
is_jax_array
219
214
is_pydata_sparse_array
220
215
"""
221
- # Avoid importing dask if it isn't already
222
- if "dask.array" not in sys .modules :
223
- return False
224
-
225
- import dask .array
226
-
227
- return isinstance (x , dask .array .Array )
216
+ cls = cast (Hashable , type (x ))
217
+ return _issubclass_fast (cls , "dask.array" , "Array" )
228
218
229
219
230
220
def is_jax_array (x : object ) -> TypeIs [jax .Array ]:
@@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
247
237
is_dask_array
248
238
is_pydata_sparse_array
249
239
"""
250
- # Avoid importing jax if it isn't already
251
- if "jax" not in sys .modules :
252
- return False
253
-
254
- import jax
255
-
256
- return isinstance (x , jax .Array ) or _is_jax_zero_gradient_array (x )
240
+ cls = cast (Hashable , type (x ))
241
+ return _issubclass_fast (cls , "jax" , "Array" ) or _is_jax_zero_gradient_array (x )
257
242
258
243
259
244
def is_pydata_sparse_array (x : object ) -> TypeIs [sparse .SparseArray ]:
@@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
276
261
is_dask_array
277
262
is_jax_array
278
263
"""
279
- # Avoid importing jax if it isn't already
280
- if "sparse" not in sys .modules :
281
- return False
282
-
283
- import sparse # pyright: ignore[reportMissingTypeStubs]
284
-
285
264
# TODO: Account for other backends.
286
- return isinstance (x , sparse .SparseArray )
265
+ cls = cast (Hashable , type (x ))
266
+ return _issubclass_fast (cls , "sparse" , "SparseArray" )
287
267
288
268
289
269
def is_array_api_obj (x : object ) -> TypeIs [_ArrayApiObj ]: # pyright: ignore[reportUnknownParameterType]
@@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo
302
282
is_jax_array
303
283
"""
304
284
return (
305
- is_numpy_array (x )
306
- or is_cupy_array (x )
307
- or is_torch_array (x )
308
- or is_dask_array (x )
309
- or is_jax_array (x )
310
- or is_pydata_sparse_array (x )
311
- or hasattr (x , "__array_namespace__" )
285
+ hasattr (x , '__array_namespace__' )
286
+ or _is_array_api_cls (cast (Hashable , type (x )))
287
+ )
288
+
289
+
290
+ @lru_cache (100 )
291
+ def _is_array_api_cls (cls : type ) -> bool :
292
+ return (
293
+ # TODO: drop support for numpy<2 which didn't have __array_namespace__
294
+ _issubclass_fast (cls , "numpy" , "ndarray" )
295
+ or _issubclass_fast (cls , "numpy" , "generic" )
296
+ or _issubclass_fast (cls , "cupy" , "ndarray" )
297
+ or _issubclass_fast (cls , "torch" , "Tensor" )
298
+ or _issubclass_fast (cls , "dask.array" , "Array" )
299
+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
300
+ # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__
301
+ or _issubclass_fast (cls , "jax" , "Array" )
312
302
)
313
303
314
304
@@ -317,6 +307,7 @@ def _compat_module_name() -> str:
317
307
return __name__ .removesuffix (".common._helpers" )
318
308
319
309
310
+ @lru_cache (100 )
320
311
def is_numpy_namespace (xp : Namespace ) -> bool :
321
312
"""
322
313
Returns True if `xp` is a NumPy namespace.
@@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool:
338
329
return xp .__name__ in {"numpy" , _compat_module_name () + ".numpy" }
339
330
340
331
332
+ @lru_cache (100 )
341
333
def is_cupy_namespace (xp : Namespace ) -> bool :
342
334
"""
343
335
Returns True if `xp` is a CuPy namespace.
@@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool:
359
351
return xp .__name__ in {"cupy" , _compat_module_name () + ".cupy" }
360
352
361
353
354
+ @lru_cache (100 )
362
355
def is_torch_namespace (xp : Namespace ) -> bool :
363
356
"""
364
357
Returns True if `xp` is a PyTorch namespace.
@@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool:
399
392
return xp .__name__ == "ndonnx"
400
393
401
394
395
+ @lru_cache (100 )
402
396
def is_dask_namespace (xp : Namespace ) -> bool :
403
397
"""
404
398
Returns True if `xp` is a Dask namespace.
@@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None:
939
933
return None if math .isnan (out ) else out
940
934
941
935
936
+ @lru_cache (100 )
937
+ def _is_writeable_cls (cls : type ) -> bool | None :
938
+ if (
939
+ _issubclass_fast (cls , "numpy" , "generic" )
940
+ or _issubclass_fast (cls , "jax" , "Array" )
941
+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
942
+ ):
943
+ return False
944
+ if _is_array_api_cls (cls ):
945
+ return True
946
+ return None
947
+
948
+
942
949
def is_writeable_array (x : object ) -> bool :
943
950
"""
944
951
Return False if ``x.__setitem__`` is expected to raise; True otherwise.
@@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool:
949
956
As there is no standard way to check if an array is writeable without actually
950
957
writing to it, this function blindly returns True for all unknown array types.
951
958
"""
952
- if is_numpy_array (x ):
953
- return x .flags .writeable
954
- if is_jax_array (x ) or is_pydata_sparse_array (x ):
959
+ cls = cast (Hashable , type (x ))
960
+ if _issubclass_fast (cls , "numpy" , "ndarray" ):
961
+ return cast ("npt.NDArray" , x ).flags .writeable
962
+ res = _is_writeable_cls (cls )
963
+ if res is not None :
964
+ return res
965
+ return hasattr (x , '__array_namespace__' )
966
+
967
+
968
+ @lru_cache (100 )
969
+ def _is_lazy_cls (cls : type ) -> bool | None :
970
+ if (
971
+ _issubclass_fast (cls , "numpy" , "ndarray" )
972
+ or _issubclass_fast (cls , "numpy" , "generic" )
973
+ or _issubclass_fast (cls , "cupy" , "ndarray" )
974
+ or _issubclass_fast (cls , "torch" , "Tensor" )
975
+ or _issubclass_fast (cls , "sparse" , "SparseArray" )
976
+ ):
955
977
return False
956
- return is_array_api_obj (x )
978
+ if (
979
+ _issubclass_fast (cls , "jax" , "Array" )
980
+ or _issubclass_fast (cls , "dask.array" , "Array" )
981
+ or _issubclass_fast (cls , "ndonnx" , "Array" )
982
+ ):
983
+ return True
984
+ return None
957
985
958
986
959
987
def is_lazy_array (x : object ) -> bool :
@@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool:
969
997
This function errs on the side of caution for array types that may or may not be
970
998
lazy, e.g. JAX arrays, by always returning True for them.
971
999
"""
972
- if (
973
- is_numpy_array (x )
974
- or is_cupy_array (x )
975
- or is_torch_array (x )
976
- or is_pydata_sparse_array (x )
977
- ):
978
- return False
979
-
980
1000
# **JAX note:** while it is possible to determine if you're inside or outside
981
1001
# jax.jit by testing the subclass of a jax.Array object, as well as testing bool()
982
1002
# as we do below for unknown arrays, this is not recommended by JAX best practices.
@@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool:
986
1006
# compatibility, is highly detrimental to performance as the whole graph will end
987
1007
# up being computed multiple times.
988
1008
989
- if is_jax_array (x ) or is_dask_array (x ) or is_ndonnx_array (x ):
990
- return True
1009
+ # Note: skipping reclassification of JAX zero gradient arrays, as one will
1010
+ # exclusively get them once they leave a jax.grad JIT context.
1011
+ cls = cast (Hashable , type (x ))
1012
+ res = _is_lazy_cls (cls )
1013
+ if res is not None :
1014
+ return res
991
1015
992
- if not is_array_api_obj ( x ):
1016
+ if not hasattr ( x , "__array_namespace__" ):
993
1017
return False
994
1018
995
1019
# Unknown Array API compatible object. Note that this test may have dire consequences
@@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool:
1042
1066
"to_device" ,
1043
1067
]
1044
1068
1045
- _all_ignore = [" sys" , " math" , " inspect" , " warnings" ]
1069
+ _all_ignore = ['lru_cache' , ' sys' , ' math' , ' inspect' , ' warnings' ]
1046
1070
1047
1071
def __dir__ () -> list [str ]:
1048
1072
return __all__
0 commit comments