diff --git a/doc/api.rst b/doc/api.rst index 84b272e847d..4d95b7425e7 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -117,6 +117,8 @@ Dataset contents Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index + Dataset.as_array_type + Dataset.is_array_type Comparisons ----------- @@ -315,6 +317,8 @@ DataArray contents DataArray.get_index DataArray.astype DataArray.item + DataArray.as_array_type + DataArray.is_array_type Indexing -------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 56d9a3d9bed..89fea4552f4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v2025.02.0 (unreleased) New Features ~~~~~~~~~~~~ +- Add convenience methods ``as_array_type`` and ``is_array_type`` for converting wrapped + data to other duck array types. (:issue:`7848`, :pull:`9823`). + By `Sam Levang `_. Breaking changes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index d287564cfe5..6c6ab04e25d 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -844,6 +844,46 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) + def as_array_type(self, asarray: Callable, **kwargs) -> Self: + """ + Converts wrapped data into a specific array type. + + If the data is a chunked array, the conversion is applied to each block. + + `asarray` should output an object that supports the Array API Standard. + This method does not convert index coordinates, which can't generally be + represented as arbitrary array types. + + Parameters + ---------- + asarray : Callable + Function that converts an array-like object to the desired array type. + For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, + or any `from_dlpack` method. + **kwargs : dict + Additional keyword arguments passed to the `asarray` function. + + Returns + ------- + DataArray + """ + return self._replace(self.variable.as_array_type(asarray, **kwargs)) + + def is_array_type(self, array_type: type) -> bool: + """ + Check if the wrapped data is of a specific array type. + + Parameters + ---------- + array_type : type + The array type to check for. + + Returns + ------- + bool + """ + return self.variable.is_array_type(array_type) + @property def _in_memory(self) -> bool: return self.variable._in_memory diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 36cc76cbf8d..af9ee7938cf 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1461,6 +1461,54 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) + def as_array_type(self, asarray: Callable, **kwargs) -> Self: + """ + Converts wrapped data into a specific array type. + + If the data is a chunked array, the conversion is applied to each block. + + `asarray` should output an object that supports the Array API Standard. + This method does not convert index coordinates, which can't generally be + represented as arbitrary array types. + + Parameters + ---------- + asarray : Callable + Function that converts an array-like object to the desired array type. + For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, + or any `from_dlpack` method. + **kwargs : dict + Additional keyword arguments passed to the `asarray` function. + + Returns + ------- + Dataset + """ + array_variables = { + k: v.as_array_type(asarray, **kwargs) if k not in self._indexes else v + for k, v in self.variables.items() + } + return self._replace(variables=array_variables) + + def is_array_type(self, array_type: type) -> bool: + """ + Check if all data variables and non-index coordinates are of a specific array type. + + Parameters + ---------- + array_type : type + The array type to check for. + + Returns + ------- + bool + """ + return all( + v.is_array_type(array_type) + for k, v in self.variables.items() + if k not in self._indexes + ) + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index cdf9eab5c8d..8800b39a151 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -40,8 +40,8 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.parallelcompat import guess_chunkmanager -from xarray.namedarray.pycompat import to_numpy +from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager +from xarray.namedarray.pycompat import is_chunked_array, to_numpy from xarray.namedarray.utils import ( either_dict_or_kwargs, infix_dims, @@ -863,6 +863,49 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) + def as_array_type( + self, + asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], + **kwargs: Any, + ) -> Self: + """Converts wrapped data into a specific array type. + + If the data is a chunked array, the conversion is applied to each block. + + Parameters + ---------- + asarray : callable + Function that converts the data into a specific array type. + **kwargs : dict + Additional keyword arguments passed on to `asarray`. + + Returns + ------- + array : NamedArray + Array with the same data, but converted into a specific array type + """ + if is_chunked_array(self._data): + chunkmanager = get_chunked_array_type(self._data) + new_data = chunkmanager.map_blocks(asarray, self._data, **kwargs) + else: + new_data = asarray(self._data, **kwargs) + + return self._replace(data=new_data) + + def is_array_type(self, array_type: type) -> bool: + """Check if the data is an instance of a specific array type. + + Parameters + ---------- + array_type : type + Array type to check against. + + Returns + ------- + is_array_type : bool + """ + return isinstance(self._data, array_type) + def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index bfcfda19df3..4a779ac1410 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -39,6 +39,7 @@ from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants from xarray.tests import ( + DuckArrayWrapper, InaccessibleArray, ReturnItem, assert_allclose, @@ -7167,6 +7168,32 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) +def test_as_array_type_is_array_type() -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) + + assert da.is_array_type(np.ndarray) + + result = da.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.data, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(DuckArrayWrapper) + + +@requires_dask +def test_as_array_type_dask() -> None: + import dask.array + + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}).chunk() + + result = da.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.data, dask.array.Array) + assert isinstance(result.data._meta, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(dask.array.Array) + + class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f3867bd67d2..1fed3d18f18 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7641,6 +7641,40 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) +def test_as_array_type_is_array_type() -> None: + ds = xr.Dataset( + {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} + ) + # lat is a PandasIndex here + assert ds.drop_vars("lat").is_array_type(np.ndarray) + + result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.a.data, DuckArrayWrapper) + assert isinstance(result.lat.data, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(DuckArrayWrapper) + + +@requires_dask +def test_as_array_type_dask() -> None: + import dask.array + + ds = xr.Dataset( + {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} + ).chunk() + + assert ds.is_array_type(dask.array.Array) + + result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.a.data, dask.array.Array) + assert isinstance(result.a.data._meta, DuckArrayWrapper) + assert isinstance(result.lat.data, dask.array.Array) + assert isinstance(result.lat.data._meta, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy"""