array api-related upstream-dev failures #8854


Merged
79 commits merged on May 22, 2024

Commits (79)
e066a6c
replace the use of `numpy.array_api` with `array_api_strict`
keewis Mar 19, 2024
0c14425
replace `numpy.array_api` with `array_api_strict` in the tests
keewis Mar 19, 2024
c76adc9
replace the use of the removed `nxp` with just plain `numpy`
keewis Mar 22, 2024
57cd907
directly pass the `numpy` dtype
keewis Mar 22, 2024
3a7552f
replace `dtype.type` with `type(dtype)` for `isnull`
keewis Mar 22, 2024
4fa8767
use a new function to compare dtypes
keewis Mar 22, 2024
f708b50
Merge branch 'main' into numpy2-array-api
keewis Apr 10, 2024
4063e11
use `array_api_strict`'s version of `int64`
keewis Apr 10, 2024
fdd6c80
use `array_api_strict`'s dtypes when interacting with its `Array` class
keewis Apr 10, 2024
3ef997a
Revert the (unintentional) switch to `mamba` [skip-ci]
keewis Apr 10, 2024
2e0211c
use the array API in `result_type`
keewis Apr 12, 2024
73da2ba
skip modifying the casting rules if no numpy dtype is involved
keewis Apr 12, 2024
63c0f6c
don't use isdtype for modules that don't have it
keewis Apr 12, 2024
0326ae3
allow mixing numpy arrays with others
keewis Apr 12, 2024
b1e259d
use the array api to implement `nbytes`
keewis Apr 12, 2024
9d94dc1
refactor `isdtype`
keewis Apr 12, 2024
16e2403
refactor `isdtype` to be a more general dtype checking mechanism
keewis Apr 12, 2024
79bc7f5
replace all `dtype.kind` calls with `dtypes.isdtype`
keewis Apr 12, 2024
3cb20ce
use the proper dtype kind
keewis Apr 12, 2024
1f0d0f3
use `_get_data_namespace` to get the array api namespace
keewis Apr 12, 2024
40cd9c9
Merge branch 'main' into numpy2-array-api
keewis Apr 12, 2024
7e952fd
explicitly handle `bool` when determining the item size
keewis Apr 12, 2024
d0ab11d
prefer `itemsize` over the array API's version
keewis Apr 12, 2024
108d40f
add `array-api-strict` as a test dep to the bare-minimum environment
keewis Apr 13, 2024
da6fff6
ignore the redefinition of `nxp`
keewis Apr 13, 2024
423b7ea
move the array api duck array check into a separate test
keewis Apr 13, 2024
84f0c95
remove `extract_dtype`
keewis Apr 13, 2024
aad9386
Merge branch 'main' into numpy2-array-api
keewis Apr 14, 2024
1c424b9
Merge branch 'main' into numpy2-array-api
keewis Apr 19, 2024
7a929c1
try comparing working around extension dtypes
keewis Apr 21, 2024
45de4eb
change the `nbytes` test to more clearly communicate the intention
keewis Apr 22, 2024
aa3dea8
Merge branch 'main' into numpy2-array-api
keewis Apr 22, 2024
52c61ab
Merge branch 'main' into numpy2-array-api
keewis Apr 28, 2024
24f26c2
Merge branch 'main' into numpy2-array-api
keewis Apr 29, 2024
73372da
Merge branch 'main' into numpy2-array-api
keewis Apr 29, 2024
a82ec8b
remove the deprecated dtype alias `"a"`
keewis May 3, 2024
152b983
refactor to have different code paths for numpy dtypes and others
keewis May 3, 2024
1977ad5
Merge branch 'main' into numpy2-array-api
keewis May 3, 2024
d9426ec
use `isdtype` for all other dtype checks in `xarray.core.dtypes`
keewis May 3, 2024
a43c1bf
use the proper kinds
keewis May 3, 2024
0f4d7be
remove the now unused "always tuple" branch in `split_numpy_kinds`
keewis May 3, 2024
7e95622
raise an error on invalid / unknown kinds
keewis May 3, 2024
a088e11
add tests for `isdtype`
keewis May 3, 2024
26bd6a1
pass in the iterable version of `kind`
keewis May 3, 2024
833f54f
remove the array api check
keewis May 3, 2024
75b3b6d
remove the unused `requires_pandas_version_two`
keewis May 3, 2024
3c28cb7
Merge branch 'main' into numpy2-array-api
dcherian May 6, 2024
fb59d88
add `bool` to the dummy namespace
keewis May 7, 2024
d18e23c
Merge branch 'main' into numpy2-array-api
keewis May 7, 2024
5c34163
actually make the extension array dtype test check something
keewis May 7, 2024
d72a621
actually make the extension array dtype check work
keewis May 7, 2024
d9f2fb5
adapt the name of the wrapped array
keewis May 7, 2024
810cf61
remove the dtype for those examples that use the default dtype
keewis May 7, 2024
3e87ea9
filter out the warning raised by importing `numpy.array_api`
keewis May 7, 2024
846b1cb
move the `pandas` isdtype check to a different function
keewis May 8, 2024
a59edd3
mention that we can remove `numpy_isdtype` once we require `numpy>=2.0`
keewis May 8, 2024
911206b
use an enum instead
keewis May 8, 2024
14c5a56
make `isdtype` simpler
keewis May 16, 2024
58d6b8b
comment on the empty pandas_isdtype
keewis May 16, 2024
a9c7a21
drop `pandas_isdtype` in favor of a simple `return False`
keewis May 16, 2024
62eec48
move the dtype kind verification to `numpy_isdtype`
keewis May 16, 2024
007e6c9
fall back to `numpy.isdtype` if `xp` is not passed
keewis May 16, 2024
c5f4262
move `numpy_isdtype` to `npcompat`
keewis May 16, 2024
5641e06
Merge branch 'main' into numpy2-array-api
keewis May 16, 2024
0329951
typing
keewis May 16, 2024
63046d0
fix a type comment
keewis May 16, 2024
2e88691
additional code comments
keewis May 16, 2024
499e553
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2024
63bacb4
more typing
keewis May 16, 2024
fca4b3c
raise a `TypeError` as `numpy.isdtype` does
keewis May 16, 2024
7979d44
also allow tuples of strings as kind
keewis May 16, 2024
513104b
invert the condition
keewis May 16, 2024
6dea06e
final fix, hopefully
keewis May 16, 2024
48b2e2d
next attempt
keewis May 16, 2024
da27a1b
Merge branch 'main' into numpy2-array-api
keewis May 19, 2024
eb9dece
raise a `ValueError` for unknown dtype kinds
keewis May 21, 2024
7302060
split out the tests we expect to raise into a separate function
keewis May 21, 2024
8c798d4
Merge branch 'main' into numpy2-array-api
keewis May 21, 2024
c8ebdc7
add another expected failing test
keewis May 21, 2024
99 changes: 78 additions & 21 deletions xarray/core/dtypes.py
@@ -4,8 +4,9 @@
from typing import Any

import numpy as np
from pandas.api.types import is_extension_array_dtype

from xarray.core import utils
from xarray.core import npcompat, utils

# Use as a sentinel value to indicate a dtype appropriate NA value.
NA = utils.ReprObject("<NA>")
@@ -60,22 +61,22 @@ def maybe_promote(dtype: np.dtype) -> tuple[np.dtype, Any]:
# N.B. these casting rules should match pandas
dtype_: np.typing.DTypeLike
fill_value: Any
if np.issubdtype(dtype, np.floating):
if isdtype(dtype, "real floating"):
dtype_ = dtype
fill_value = np.nan
elif np.issubdtype(dtype, np.timedelta64):
elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.timedelta64):
# See https://github.com/numpy/numpy/issues/10685
# np.timedelta64 is a subclass of np.integer
# Check np.timedelta64 before np.integer
fill_value = np.timedelta64("NaT")
dtype_ = dtype
elif np.issubdtype(dtype, np.integer):
elif isdtype(dtype, "integral"):
dtype_ = np.float32 if dtype.itemsize <= 2 else np.float64
fill_value = np.nan
elif np.issubdtype(dtype, np.complexfloating):
elif isdtype(dtype, "complex floating"):
dtype_ = dtype
fill_value = np.nan + np.nan * 1j
elif np.issubdtype(dtype, np.datetime64):
elif isinstance(dtype, np.dtype) and np.issubdtype(dtype, np.datetime64):
dtype_ = dtype
fill_value = np.datetime64("NaT")
else:
@@ -118,16 +119,16 @@ def get_pos_infinity(dtype, max_for_int=False):
-------
fill_value : positive infinity value corresponding to this dtype.
"""
if issubclass(dtype.type, np.floating):
if isdtype(dtype, "real floating"):
return np.inf

if issubclass(dtype.type, np.integer):
if isdtype(dtype, "integral"):
if max_for_int:
return np.iinfo(dtype).max
else:
return np.inf

if issubclass(dtype.type, np.complexfloating):
if isdtype(dtype, "complex floating"):
return np.inf + 1j * np.inf

return INF
@@ -146,24 +147,66 @@ def get_neg_infinity(dtype, min_for_int=False):
-------
fill_value : positive infinity value corresponding to this dtype.
"""
if issubclass(dtype.type, np.floating):
if isdtype(dtype, "real floating"):
return -np.inf

if issubclass(dtype.type, np.integer):
if isdtype(dtype, "integral"):
if min_for_int:
return np.iinfo(dtype).min
else:
return -np.inf

if issubclass(dtype.type, np.complexfloating):
if isdtype(dtype, "complex floating"):
return -np.inf - 1j * np.inf

return NINF


def is_datetime_like(dtype):
def is_datetime_like(dtype) -> bool:
"""Check if a dtype is a subclass of the numpy datetime types"""
return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
return _is_numpy_subdtype(dtype, (np.datetime64, np.timedelta64))


def is_object(dtype) -> bool:
"""Check if a dtype is object"""
return _is_numpy_subdtype(dtype, object)


def is_string(dtype) -> bool:
"""Check if a dtype is a string dtype"""
return _is_numpy_subdtype(dtype, (np.str_, np.character))


def _is_numpy_subdtype(dtype, kind) -> bool:
if not isinstance(dtype, np.dtype):
return False

kinds = kind if isinstance(kind, tuple) else (kind,)
return any(np.issubdtype(dtype, kind) for kind in kinds)


def isdtype(dtype, kind: str | tuple[str, ...], xp=None) -> bool:
"""Compatibility wrapper for isdtype() from the array API standard.

Unlike xp.isdtype(), kind must be a string.
"""
# TODO(shoyer): remove this wrapper when Xarray requires
# numpy>=2 and pandas extensions arrays are implemented in
# Xarray via the array API
if not isinstance(kind, str) and not (
isinstance(kind, tuple) and all(isinstance(k, str) for k in kind)
):
raise TypeError(f"kind must be a string or a tuple of strings: {repr(kind)}")

if isinstance(dtype, np.dtype):
return npcompat.isdtype(dtype, kind)
elif is_extension_array_dtype(dtype):
# we never want to match pandas extension array dtypes
return False
else:
if xp is None:
xp = np
return xp.isdtype(dtype, kind)


def result_type(
@@ -184,12 +227,26 @@ def result_type(
-------
numpy.dtype for the result.
"""
types = {np.result_type(t).type for t in arrays_and_dtypes}
from xarray.core.duck_array_ops import get_array_namespace

# TODO(shoyer): consider moving this logic into get_array_namespace()
# or another helper function.
namespaces = {get_array_namespace(t) for t in arrays_and_dtypes}
non_numpy = namespaces - {np}
if non_numpy:
[xp] = non_numpy
else:
xp = np

types = {xp.result_type(t) for t in arrays_and_dtypes}

for left, right in PROMOTE_TO_OBJECT:
if any(issubclass(t, left) for t in types) and any(
issubclass(t, right) for t in types
):
return np.dtype(object)
if any(isinstance(t, np.dtype) for t in types):
# only check if there's numpy dtypes – the array API does not
# define the types we're checking for
for left, right in PROMOTE_TO_OBJECT:
if any(np.issubdtype(t, left) for t in types) and any(
np.issubdtype(t, right) for t in types
):
return xp.dtype(object)

return np.result_type(*arrays_and_dtypes)
return xp.result_type(*arrays_and_dtypes)
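
For reference, a minimal usage sketch of the reworked `isdtype` and `result_type` above, assuming the final state of this branch with `array_api_strict` installed; the asserted results describe the intended behaviour rather than tested output:

```python
import array_api_strict as xps
import numpy as np
import pandas as pd

from xarray.core import dtypes

# numpy dtypes are routed through npcompat.isdtype (numpy.isdtype on numpy>=2)
assert dtypes.isdtype(np.dtype("float64"), "real floating")
assert not dtypes.isdtype(np.dtype("int32"), ("real floating", "complex floating"))

# pandas extension array dtypes deliberately never match
assert not dtypes.isdtype(pd.CategoricalDtype(), "integral")

# non-numpy dtypes are checked via the array API namespace passed as `xp`
assert dtypes.isdtype(xps.int64, "integral", xp=xps)

# result_type defers to the single non-numpy namespace if one is involved
x = xps.asarray([1, 2, 3])
print(dtypes.result_type(x, xps.float64))  # an array_api_strict floating dtype
```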
51 changes: 36 additions & 15 deletions xarray/core/duck_array_ops.py
@@ -142,17 +142,25 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):

def isnull(data):
data = asarray(data)
scalar_type = data.dtype.type
if issubclass(scalar_type, (np.datetime64, np.timedelta64)):

xp = get_array_namespace(data)
scalar_type = data.dtype
if dtypes.is_datetime_like(scalar_type):
# datetime types use NaT for null
# note: must check timedelta64 before integers, because currently
# timedelta64 inherits from np.integer
return isnat(data)
elif issubclass(scalar_type, np.inexact):
elif dtypes.isdtype(scalar_type, ("real floating", "complex floating"), xp=xp):
# float types use NaN for null
xp = get_array_namespace(data)
return xp.isnan(data)
elif issubclass(scalar_type, (np.bool_, np.integer, np.character, np.void)):
elif dtypes.isdtype(scalar_type, ("bool", "integral"), xp=xp) or (
isinstance(scalar_type, np.dtype)
and (
np.issubdtype(scalar_type, np.character)
or np.issubdtype(scalar_type, np.void)
)
):
# these types cannot represent missing values
return full_like(data, dtype=bool, fill_value=False)
else:
@@ -406,13 +414,22 @@ def f(values, axis=None, skipna=None, **kwargs):
if invariant_0d and axis == ():
return values

values = asarray(values)
xp = get_array_namespace(values)
values = asarray(values, xp=xp)

if coerce_strings and values.dtype.kind in "SU":
if coerce_strings and dtypes.is_string(values.dtype):
values = astype(values, object)

func = None
if skipna or (skipna is None and values.dtype.kind in "cfO"):
if skipna or (
skipna is None
and (
dtypes.isdtype(
values.dtype, ("complex floating", "real floating"), xp=xp
)
or dtypes.is_object(values.dtype)
)
):
nanname = "nan" + name
func = getattr(nanops, nanname)
else:
@@ -477,8 +494,8 @@ def _datetime_nanmin(array):
- numpy nanmin() don't work on datetime64 (all versions at the moment of writing)
- dask min() does not work on datetime64 (all versions at the moment of writing)
"""
assert array.dtype.kind in "mM"
dtype = array.dtype
assert dtypes.is_datetime_like(dtype)
# (NaT).astype(float) does not produce NaN...
array = where(pandas_isnull(array), np.nan, array.astype(float))
array = min(array, skipna=True)
@@ -515,7 +532,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
"""
# Set offset to minimum if not given
if offset is None:
if array.dtype.kind in "Mm":
if dtypes.is_datetime_like(array.dtype):
offset = _datetime_nanmin(array)
else:
offset = min(array)
@@ -527,7 +544,7 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
# This map_blocks call is for backwards compatibility.
# dask == 2021.04.1 does not support subtracting object arrays
# which is required for cftime
if is_duck_dask_array(array) and np.issubdtype(array.dtype, object):
if is_duck_dask_array(array) and dtypes.is_object(array.dtype):
array = array.map_blocks(lambda a, b: a - b, offset, meta=array._meta)
else:
array = array - offset
@@ -537,11 +554,11 @@ def datetime_to_numeric(array, offset=None, datetime_unit=None, dtype=float):
array = np.array(array)

# Convert timedelta objects to float by first converting to microseconds.
if array.dtype.kind in "O":
if dtypes.is_object(array.dtype):
return py_timedelta_to_float(array, datetime_unit or "ns").astype(dtype)

# Convert np.NaT to np.nan
elif array.dtype.kind in "mM":
elif dtypes.is_datetime_like(array.dtype):
# Convert to specified timedelta units.
if datetime_unit:
array = array / np.timedelta64(1, datetime_unit)
@@ -641,7 +658,7 @@ def mean(array, axis=None, skipna=None, **kwargs):
from xarray.core.common import _contains_cftime_datetimes

array = asarray(array)
if array.dtype.kind in "Mm":
if dtypes.is_datetime_like(array.dtype):
offset = _datetime_nanmin(array)

# xarray always uses np.datetime64[ns] for np.datetime64 data
@@ -689,7 +706,9 @@ def cumsum(array, axis=None, **kwargs):

def first(values, axis, skipna=None):
"""Return the first non-NA elements in this array along the given axis"""
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
if (skipna or skipna is None) and not (
dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
):
# only bother for dtypes that can hold NaN
if is_chunked_array(values):
return chunked_nanfirst(values, axis)
@@ -700,7 +719,9 @@ def last(values, axis, skipna=None):

def last(values, axis, skipna=None):
"""Return the last non-NA elements in this array along the given axis"""
if (skipna or skipna is None) and values.dtype.kind not in "iSU":
if (skipna or skipna is None) and not (
dtypes.isdtype(values.dtype, "signed integer") or dtypes.is_string(values.dtype)
):
# only bother for dtypes that can hold NaN
if is_chunked_array(values):
return chunked_nanlast(values, axis)
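
Similarly, a rough sketch of what the dtype helpers used throughout `duck_array_ops` are expected to match; `_can_hold_nan` is a hypothetical name for the gate used in `first()`/`last()`, not a function defined in this PR:

```python
import numpy as np

from xarray.core import dtypes

# the helpers only ever match genuine numpy dtypes
assert dtypes.is_datetime_like(np.dtype("datetime64[ns]"))
assert dtypes.is_datetime_like(np.dtype("timedelta64[ns]"))
assert dtypes.is_object(np.dtype("O"))
assert dtypes.is_string(np.dtype("U3")) and dtypes.is_string(np.dtype("S3"))
assert not dtypes.is_string(np.dtype("int64"))


# the condition replacing `values.dtype.kind not in "iSU"` in first()/last():
# only dtypes that can actually hold NaN take the nan-skipping path
def _can_hold_nan(dtype) -> bool:
    return not (dtypes.isdtype(dtype, "signed integer") or dtypes.is_string(dtype))


assert _can_hold_nan(np.dtype("float32"))
assert not _can_hold_nan(np.dtype("int64"))
```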
30 changes: 30 additions & 0 deletions xarray/core/npcompat.py
@@ -28,3 +28,33 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

try:
# requires numpy>=2.0
from numpy import isdtype # type: ignore[attr-defined,unused-ignore]
except ImportError:
import numpy as np

dtype_kinds = {
"bool": np.bool_,
"signed integer": np.signedinteger,
"unsigned integer": np.unsignedinteger,
"integral": np.integer,
"real floating": np.floating,
"complex floating": np.complexfloating,
"numeric": np.number,
}

def isdtype(dtype, kind):
kinds = kind if isinstance(kind, tuple) else (kind,)

unknown_dtypes = [kind for kind in kinds if kind not in dtype_kinds]
if unknown_dtypes:
raise ValueError(f"unknown dtype kinds: {unknown_dtypes}")

# verified the dtypes already, no need to check again
translated_kinds = [dtype_kinds[kind] for kind in kinds]
if isinstance(dtype, np.generic):
return any(isinstance(dtype, kind) for kind in translated_kinds)
else:
return any(np.issubdtype(dtype, kind) for kind in translated_kinds)
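
And a quick sanity check of the `npcompat` shim: on `numpy>=2` it is simply `numpy.isdtype`, while the fallback above is meant to give the same answers on older numpy (only the error message for unknown kinds differs); again a sketch, not part of the diff:

```python
import numpy as np

from xarray.core import npcompat

assert npcompat.isdtype(np.dtype("int32"), "integral")
assert npcompat.isdtype(np.dtype("float64"), ("real floating", "complex floating"))
assert not npcompat.isdtype(np.dtype("int32"), "real floating")

# unrecognized kinds raise ValueError in both implementations
try:
    npcompat.isdtype(np.dtype("int32"), "not-a-kind")
except ValueError as err:
    print(err)
```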