Commit 026aa7c

Authored by keewis, dcherian, shoyer, and pre-commit-ci[bot]
array api-related upstream-dev failures (#8854)
* replace the use of `numpy.array_api` with `array_api_strict`

  This would make it a dependency of `namedarray`, and not allow behavior that is allowed but not required by the array API standard. Otherwise we can:
  - use the main `numpy` namespace
  - use `array_api_compat` (would also be a new dependency) to allow optional behavior
* replace `numpy.array_api` with `array_api_strict` in the tests
* replace the use of the removed `nxp` with just plain `numpy`
* directly pass the `numpy` dtype
* replace `dtype.type` with `type(dtype)` for `isnull`
* use a new function to compare dtypes
* use `array_api_strict`'s version of `int64`
* use `array_api_strict`'s dtypes when interacting with its `Array` class
* Revert the (unintentional) switch to `mamba` [skip-ci]
* use the array API in `result_type`
* skip modifying the casting rules if no numpy dtype is involved
* don't use `isdtype` for modules that don't have it
* allow mixing numpy arrays with others

  This is not explicitly allowed by the array API specification (it was declared out of scope), so I'm not sure if this is the right way to do this.
* use the array API to implement `nbytes`
* refactor `isdtype`
* refactor `isdtype` to be a more general dtype checking mechanism
* replace all `dtype.kind` calls with `dtypes.isdtype`
* use the proper dtype kind
* use `_get_data_namespace` to get the array API namespace
* explicitly handle `bool` when determining the item size
* prefer `itemsize` over the array API's version
* add `array-api-strict` as a test dependency to the bare-minimum environment
* ignore the redefinition of `nxp`
* move the array API duck array check into a separate test

  This allows skipping it if the import fails (and we don't have to add it to the `bare-minimum` CI).
* remove `extract_dtype`
* try working around extension dtypes when comparing
* change the `nbytes` test to more clearly communicate the intention
* remove the deprecated dtype alias `"a"`
* refactor to have different code paths for numpy dtypes and others
* use `isdtype` for all other dtype checks in `xarray.core.dtypes`
* use the proper kinds
* remove the now unused "always tuple" branch in `split_numpy_kinds`
* raise an error on invalid / unknown kinds
* add tests for `isdtype`
* pass in the iterable version of `kind`
* remove the array API check
* remove the unused `requires_pandas_version_two`
* add `bool` to the dummy namespace
* actually make the extension array dtype test check something
* actually make the extension array dtype check work
* adapt the name of the wrapped array
* remove the dtype for those examples that use the default dtype
* filter out the warning raised by importing `numpy.array_api`
* move the `pandas` isdtype check to a different function
* mention that we can remove `numpy_isdtype` once we require `numpy>=2.0`
* use an enum instead
* make `isdtype` simpler
* comment on the empty `pandas_isdtype`
* drop `pandas_isdtype` in favor of a simple `return False`
* move the dtype kind verification to `numpy_isdtype`

  `xp.isdtype` should already check the same thing.
* fall back to `numpy.isdtype` if `xp` is not passed
* move `numpy_isdtype` to `npcompat`
* typing
* fix a type comment
* additional code comments

  Co-authored-by: Stephan Hoyer <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci
* more typing
* raise a `TypeError` as `numpy.isdtype` does
* also allow tuples of strings as kind
* invert the condition
* final fix, hopefully
* next attempt
* raise a `ValueError` for unknown dtype kinds
* split out the tests we expect to raise into a separate function
* add another expected failing test

---------

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: Stephan Hoyer <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
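As a rough illustration of the `array_api_strict` swap described in the first bullets, the tests now exercise a strict array-API namespace instead of the removed `numpy.array_api`. A minimal sketch (assuming the optional `array-api-strict` package is installed; the values are purely illustrative):

```python
import array_api_strict as xp  # previously: import numpy.array_api as nxp

# dtypes are taken from the same namespace as the Array objects, which is
# why the diffs below switch to array_api_strict's own int64
arr = xp.asarray([1, 2, 3], dtype=xp.int64)
print(arr.dtype == xp.int64, arr.shape)  # True (3,)
```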
1 parent 9e240c5 commit 026aa7c

10 files changed: +282 −113 lines changed

xarray/core/dtypes.py

+78 −21
@@ -4,8 +4,9 @@
 from typing import Any
 
 import numpy as np
+from pandas.api.types import is_extension_array_dtype
 
-from xarray.core import utils
+from xarray.core import npcompat, utils
 
 # Use as a sentinel value to indicate a dtype appropriate NA value.
 NA = utils.ReprObject("<NA>")
@@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
     # N.B. these casting rules should match pandas
     dtype_: np.typing.DTypeLike
     fill_value: Any
-    if np.issubdtype(dtype, np.floating):
+    if isdtype(dtype, "real floating"):
         dtype_ = dtype
         fill_value = np.nan
-    elif np.issubdtype(dtype, np.timedelta64):
+    elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64):
         # See https://github.com/numpy/numpy/issues/10685
         # np.timedelta64 is a subclass of np.integer
         # Check np.timedelta64 before np.integer
         fill_value = np.timedelta64("NaT")
         dtype_ = dtype
-    elif np.issubdtype(dtype, np.integer):
+    elif isdtype(dtype, "integral"):
         dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
         fill_value = np.nan
-    elif np.issubdtype(dtype, np.complexfloating):
+    elif isdtype(dtype, "complex floating"):
         dtype_ = dtype
         fill_value = np.nan + np.nan * 1j
-    elif np.issubdtype(dtype, np.datetime64):
+    elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64):
         dtype_ = dtype
         fill_value = np.datetime64("NaT")
     else:
@@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int=False):
     -------
     fill_value : positive infinity value corresponding to this dtype.
     """
-    if issubclass(dtype.type, np.floating):
+    if isdtype(dtype, "real floating"):
         return np.inf
 
-    if issubclass(dtype.type, np.integer):
+    if isdtype(dtype, "integral"):
         if max_for_int:
             return np.iinfo(dtype).max
         else:
             return np.inf
 
-    if issubclass(dtype.type, np.complexfloating):
+    if isdtype(dtype, "complex floating"):
         return np.inf + 1j * np.inf
 
     return INF
@@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int=False):
     -------
     fill_value : positive infinity value corresponding to this dtype.
     """
-    if issubclass(dtype.type, np.floating):
+    if isdtype(dtype, "real floating"):
         return -np.inf
 
-    if issubclass(dtype.type, np.integer):
+    if isdtype(dtype, "integral"):
         if min_for_int:
             return np.iinfo(dtype).min
         else:
             return -np.inf
 
-    if issubclass(dtype.type, np.complexfloating):
+    if isdtype(dtype, "complex floating"):
         return -np.inf - 1j * np.inf
 
     return NINF
 
 
-def is_datetime_like(dtype):
+def is_datetime_like(dtype) -> bool:
     """Check if a dtype is a subclass of the numpy datetime types"""
-    return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
+    return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64))
+
+
+def is_object(dtype) -> bool:
+    """Check if a dtype is object"""
+    return _is_numpy_subdtype(dtype, object)
+
+
+def is_string(dtype) -> bool:
+    """Check if a dtype is a string dtype"""
+    return _is_numpy_subdtype(dtype, (np.str_, np.character))
+
+
+def _is_numpy_subdtype(dtype, kind) -> bool:
+    if not isinstance(dtype, np.dtype):
+        return False
+
+    kinds = kind if isinstance(kind, tuple) else (kind,)
+    return any(np.issubdtype(dtype, kind) for kind in kinds)
+
+
+def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
+    """Compatibility wrapper for isdtype() from the array API standard.
+
+    Unlike xp.isdtype(), kind must be a string.
+    """
+    # TODO(shoyer): remove this wrapper when Xarray requires
+    # numpy>=2 and pandas extensions arrays are implemented in
+    # Xarray via the array API
+    if not isinstance(kind, str) and not (
+        isinstance(kind, tuple) and all(isinstance(k, str) for k in kind)
+    ):
+        raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")
+
+    if isinstance(dtype, np.dtype):
+        return npcompat.isdtype(dtype, kind)
+    elif is_extension_array_dtype(dtype):
+        # we never want to match pandas extension array dtypes
        return False
+    else:
+        if xp is None:
+            xp = np
+        return xp.isdtype(dtype, kind)
 
 
 def result_type(
@@ -184,12 +227,26 @@ def result_type(
     -------
     numpy.dtype for the result.
     """
-    types = {np.result_type(t).type for t in arrays_and_dtypes}
+    from xarray.core.duck_array_ops import get_array_namespace
+
+    # TODO(shoyer): consider moving this logic into get_array_namespace()
+    # or another helper function.
+    namespaces = {get_array_namespace(t) for t in arrays_and_dtypes}
+    non_numpy = namespaces - {np}
+    if non_numpy:
+        [xp] = non_numpy
+    else:
+        xp = np
+
+    types = {xp.result_type(t) for t in arrays_and_dtypes}
 
-    for left, right in PROMOTE_TO_OBJECT:
-        if any(issubclass(t, left) for t in types) and any(
-            issubclass(t, right) for t in types
-        ):
-            return np.dtype(object)
+    if any(isinstance(t, np.dtype) for t in types):
+        # only check if there's numpy dtypes – the array API does not
+        # define the types we're checking for
+        for left, right in PROMOTE_TO_OBJECT:
+            if any(np.issubdtype(t, left) for t in types) and any(
+                np.issubdtype(t, right) for t in types
+            ):
+                return xp.dtype(object)
 
-    return np.result_type(*arrays_and_dtypes)
+    return xp.result_type(*arrays_and_dtypes)
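For orientation, here is a small usage sketch of the `isdtype` wrapper and `result_type` changes above, using plain numpy and pandas inputs. The inputs and expected outputs are illustrative and not part of the diff:

```python
import numpy as np
import pandas as pd

from xarray.core import dtypes

# numpy dtypes go through npcompat.isdtype (numpy.isdtype on numpy >= 2)
print(dtypes.isdtype(np.dtype("int16"), "integral"))            # True
print(dtypes.isdtype(np.dtype("float64"), ("real floating",)))  # True

# pandas extension array dtypes intentionally never match
print(dtypes.isdtype(pd.CategoricalDtype(), "numeric"))         # False

# non-string kinds raise TypeError, mirroring numpy.isdtype
try:
    dtypes.isdtype(np.dtype("int64"), np.integer)
except TypeError as err:
    print(err)

# result_type dispatches on the inputs' array namespace; with numpy inputs
# the pandas-style promotion to object still applies
print(dtypes.result_type(np.array([1]), np.array(["a"])))       # object
```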

xarray/core/duck_array_ops.py

+36 −15
@@ -142,17 +142,25 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
 
 def isnull(data):
     data = asarray(data)
-    scalar_type = data.dtype.type
-    if issubclass(scalar_type, (np.datetime64, np.timedelta64)):
+
+    xp = get_array_namespace(data)
+    scalar_type = data.dtype
+    if dtypes.is_datetime_like(scalar_type):
         # datetime types use NaT for null
         # note: must check timedelta64 before integers, because currently
         # timedelta64 inherits from np.integer
         return isnat(data)
-    elif issubclass(scalar_type, np.inexact):
+    elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
         # float types use NaN for null
         xp = get_array_namespace(data)
         return xp.isnan(data)
-    elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
+    elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or (
+        isinstance(scalar_type, np.dtype)
+        and (
+            np.issubdtype(scalar_type, np.character)
+            or np.issubdtype(scalar_type, np.void)
+        )
+    ):
         # these types cannot represent missing values
         return full_like(data, dtype=bool, fill_value=False)
     else:
@@ -406,13 +414,22 @@ def f(values, axis=None, skipna=None, **kwargs):
         if invariant_0d and axis == ():
             return values
 
-        values = asarray(values)
+        xp = get_array_namespace(values)
+        values = asarray(values, xp=xp)
 
-        if coerce_strings and values.dtype.kind in "SU":
+        if coerce_strings and dtypes.is_string(values.dtype):
             values = astype(values, object)
 
         func = None
-        if skipna or (skipna is None and values.dtype.kind in "cfO"):
+        if skipna or (
+            skipna is None
+            and (
+                dtypes.isdtype(
+                    values.dtype, ("complex floating", "real floating"), xp=xp
+                )
+                or dtypes.is_object(values.dtype)
+            )
+        ):
             nanname = "nan" + name
             func = getattr(nanops, nanname)
         else:
@@ -477,8 +494,8 @@ def _datetime_nanmin(array):
     - numpy nanmin() don't work on datetime64 (all versions at the moment of writing)
    - dask min() does not work on datetime64 (all versions at the moment of writing)
     """
-    assert array.dtype.kind in "mM"
     dtype = array.dtype
+    assert dtypes.is_datetime_like(dtype)
     # (NaT).astype(float) does not produce NaN...
     array = where(pandas_isnull(array), np.nan, array.astype(float))
     array = min(array, skipna=True)
@@ -515,7 +532,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
     """
     # Set offset to minimum if not given
     if offset is None:
-        if array.dtype.kind in "Mm":
+        if dtypes.is_datetime_like(array.dtype):
            offset = _datetime_nanmin(array)
         else:
             offset = min(array)
@@ -527,7 +544,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
    # This map_blocks call is for backwards compatibility.
    # dask == 2021.04.1 does not support subtracting object arrays
    # which is required for cftime
-    if is_duck_dask_array(array) and np.issubdtype(array.dtype, object):
+    if is_duck_dask_array(array) and dtypes.is_object(array.dtype):
         array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
     else:
         array = array - offset
@@ -537,11 +554,11 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
         array = np.array(array)
 
     # Convert timedelta objects to float by first converting to microseconds.
-    if array.dtype.kind in "O":
+    if dtypes.is_object(array.dtype):
         return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype)
 
     # Convert np.NaT to np.nan
-    elif array.dtype.kind in "mM":
+    elif dtypes.is_datetime_like(array.dtype):
         # Convert to specified timedelta units.
         if datetime_unit:
             array = array / np.timedelta64(1, datetime_unit)
@@ -641,7 +658,7 @@ def mean(array, axis=None, skipna=None, **kwargs):
     from xarray.core.common import _contains_cftime_datetimes
 
     array = asarray(array)
-    if array.dtype.kind in "Mm":
+    if dtypes.is_datetime_like(array.dtype):
         offset = _datetime_nanmin(array)
 
         # xarray always uses np.datetime64[ns] for np.datetime64 data
@@ -689,7 +706,9 @@ def cumsum(array, axis=None, **kwargs):
 
 def first(values, axis, skipna=None):
     """Return the first non-NA elements in this array along the given axis"""
-    if (skipna or skipna is None) and values.dtype.kind not in "iSU":
+    if (skipna or skipna is None) and not (
+        dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
+    ):
         # only bother for dtypes that can hold NaN
         if is_chunked_array(values):
             return chunked_nanfirst(values, axis)
@@ -700,7 +719,9 @@ def first(values, axis, skipna=None):
 
 def last(values, axis, skipna=None):
     """Return the last non-NA elements in this array along the given axis"""
-    if (skipna or skipna is None) and values.dtype.kind not in "iSU":
+    if (skipna or skipna is None) and not (
+        dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
+    ):
         # only bother for dtypes that can hold NaN
         if is_chunked_array(values):
             return chunked_nanlast(values, axis)
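As a quick sanity check of the dtype-kind dispatch in the refactored `isnull`, here is a sketch with plain numpy inputs. The expected outputs in the comments are mine, not part of the diff:

```python
import numpy as np

from xarray.core import duck_array_ops

# real/complex floating dtypes use NaN as the missing-value sentinel
print(duck_array_ops.isnull(np.array([1.0, np.nan])))            # [False  True]

# datetime-like dtypes use NaT
print(duck_array_ops.isnull(np.array(["NaT"], dtype="M8[ns]")))  # [ True]

# bool/integral (and string/void) dtypes cannot hold missing values
print(duck_array_ops.isnull(np.array([1, 2], dtype="int32")))    # [False False]
```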

xarray/core/npcompat.py

+30
@@ -28,3 +28,33 @@
 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+try:
+    # requires numpy>=2.0
+    from numpy import isdtype  # type: ignore[attr-defined,unused-ignore]
+except ImportError:
+    import numpy as np
+
+    dtype_kinds = {
+        "bool": np.bool_,
+        "signed integer": np.signedinteger,
+        "unsigned integer": np.unsignedinteger,
+        "integral": np.integer,
+        "real floating": np.floating,
+        "complex floating": np.complexfloating,
+        "numeric": np.number,
+    }
+
+    def isdtype(dtype, kind):
+        kinds = kind if isinstance(kind, tuple) else (kind,)
+
+        unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds]
+        if unknown_dtypes:
+            raise ValueError(f"unknown dtype kinds: {unknown_dtypes}")
+
+        # verified the dtypes already, no need to check again
+        translated_kinds = [dtype_kinds[kind] for kind in kinds]
+        if isinstance(dtype, np.generic):
+            return any(isinstance(dtype, kind) for kind in translated_kinds)
+        else:
+            return any(np.issubdtype(dtype, kind) for kind in translated_kinds)
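And a short sketch of how this `npcompat.isdtype` fallback is expected to behave; on numpy >= 2.0 the same calls go straight to `numpy.isdtype`. The inputs are illustrative, and the error message wording differs between the two code paths:

```python
import numpy as np

from xarray.core import npcompat

print(npcompat.isdtype(np.dtype("uint8"), "unsigned integer"))               # True
print(npcompat.isdtype(np.dtype("float32"), ("integral", "real floating")))  # True

# unknown kind names raise ValueError on both code paths
try:
    npcompat.isdtype(np.dtype("int64"), "floating")
except ValueError as err:
    print(err)
```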
