From 9db951cc7ca31500101f31d0a83e440158e6154b Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Wed, 26 Feb 2025 18:49:23 -0800 Subject: [PATCH 1/6] [WIP]: Force Zarr coordinate reads to be on the host zarr-python 3.x supports reading data to host (CPU) memory or device (GPU) memory. Because coordinates are small and really do need to be on the host (IIUC, because they are put in an Index), there's no benefit to reading them to device. zarr-python includes a global config for whether to use host or device memory for reads, with `zarr.config.enable_gpu()`. But you can override that on a per-read basis by passing `prototype` to the getitem call. This does that for arrays that are coordinates. --- xarray/backends/zarr.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index e83f5556369..f394a2e8b8c 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -180,13 +180,14 @@ def encode_zarr_attr_value(value): class ZarrArrayWrapper(BackendArray): - __slots__ = ("_array", "dtype", "shape") + __slots__ = ("_array", "dtype", "shape", "is_coordinate") - def __init__(self, zarr_array): + def __init__(self, zarr_array, is_coordinate: bool): # some callers attempt to evaluate an array if an `array` property exists on the object. # we prefix with _ to avoid this inference. 
self._array = zarr_array self.shape = self._array.shape + self.is_coordinate = is_coordinate # preserve vlen string object dtype (GH 7328) if ( @@ -210,7 +211,12 @@ def _vindex(self, key): return self._array.vindex[key] def _getitem(self, key): - return self._array[key] + from zarr.core.buffer.cpu import buffer_prototype + if self.is_coordinate: + prototype = buffer_prototype + else: + prototype = None + return self._array.get_basic_selection(key, prototype=prototype) def __getitem__(self, key): array = self._array @@ -809,7 +815,8 @@ def ds(self): def open_store_variable(self, name): zarr_array = self.members[name] - data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array)) + is_coordinate = name in zarr_array.metadata.dimension_names + data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array, is_coordinate=is_coordinate)) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( zarr_array, DIMENSION_KEY, try_nczarr From d620357efa00dfd927f6acc0a39b66a31f4a9adf Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 27 Feb 2025 07:29:09 -0800 Subject: [PATCH 2/6] fixups (cherry picked from commit 48f85ed4e452709607a11a7b526e844fd3e41df3) --- xarray/backends/zarr.py | 73 ++++++++++++++++++++++++++++++----- xarray/tests/test_backends.py | 31 +++++++++++++++ 2 files changed, 94 insertions(+), 10 deletions(-) diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index f394a2e8b8c..2dcfcdb3860 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -179,15 +179,31 @@ def encode_zarr_attr_value(value): return encoded +def _is_coordinate_variable(zarr_array, name): + if _zarr_v3(): + if zarr_array.metadata.zarr_format == 2: + is_coordinate = name in zarr_array.metadata.attributes.get( + "_ARRAY_DIMENSIONS", [] + ) + else: + is_coordinate = name in (zarr_array.metadata.dimension_names or []) + else: + is_coordinate = name in zarr_array.attrs.get("_ARRAY_DIMENSIONS", []) + return is_coordinate + + class 
ZarrArrayWrapper(BackendArray): - __slots__ = ("_array", "dtype", "shape", "is_coordinate") + __slots__ = ("_array", "coords_buffer_prototype", "dtype", "is_coordinate", "shape") - def __init__(self, zarr_array, is_coordinate: bool): + def __init__( + self, zarr_array, is_coordinate: bool, coords_buffer_prototype: Any | None + ): # some callers attempt to evaluate an array if an `array` property exists on the object. # we prefix with _ to avoid this inference. self._array = zarr_array self.shape = self._array.shape self.is_coordinate = is_coordinate + self.coords_buffer_prototype = coords_buffer_prototype # preserve vlen string object dtype (GH 7328) if ( @@ -211,12 +227,14 @@ def _vindex(self, key): return self._array.vindex[key] def _getitem(self, key): - from zarr.core.buffer.cpu import buffer_prototype - if self.is_coordinate: - prototype = buffer_prototype - else: - prototype = None - return self._array.get_basic_selection(key, prototype=prototype) + kwargs = {} + if _zarr_v3(): + if self.is_coordinate: + prototype = self.coords_buffer_prototype + else: + prototype = None + kwargs["prototype"] = prototype + return self._array.get_basic_selection(key, **kwargs) def __getitem__(self, key): array = self._array @@ -611,6 +629,7 @@ class ZarrStore(AbstractWritableDataStore): "_cache_members", "_close_store_on_close", "_consolidate_on_close", + "_coords_buffer_prototype", "_group", "_members", "_mode", @@ -642,6 +661,7 @@ def open_store( use_zarr_fill_value_as_mask=None, write_empty: bool | None = None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ): ( zarr_group, @@ -674,6 +694,7 @@ def open_store( close_store_on_close, use_zarr_fill_value_as_mask, cache_members=cache_members, + coords_buffer_prototype=coords_buffer_prototype, ) for group in group_paths } @@ -697,6 +718,7 @@ def open_group( use_zarr_fill_value_as_mask=None, write_empty: bool | None = None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ): ( 
zarr_group, @@ -728,6 +750,7 @@ def open_group( close_store_on_close, use_zarr_fill_value_as_mask, cache_members, + coords_buffer_prototype, ) def __init__( @@ -742,6 +765,7 @@ def __init__( close_store_on_close: bool = False, use_zarr_fill_value_as_mask=None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ): self.zarr_group = zarr_group self._read_only = self.zarr_group.read_only @@ -757,6 +781,14 @@ def __init__( self._use_zarr_fill_value_as_mask = use_zarr_fill_value_as_mask self._cache_members: bool = cache_members self._members: dict[str, ZarrArray | ZarrGroup] = {} + if _zarr_v3() and coords_buffer_prototype is None: + # Once zarr-v3 is required we can just have this as the default + # https://github.com/zarr-developers/zarr-python/issues/2871 + # Use the public API once available + from zarr.core.buffer.cpu import buffer_prototype + + coords_buffer_prototype = buffer_prototype + self._coords_buffer_prototype = coords_buffer_prototype if self._cache_members: # initialize the cache @@ -815,8 +847,15 @@ def ds(self): def open_store_variable(self, name): zarr_array = self.members[name] - is_coordinate = name in zarr_array.metadata.dimension_names - data = indexing.LazilyIndexedArray(ZarrArrayWrapper(zarr_array, is_coordinate=is_coordinate)) + is_coordinate = _is_coordinate_variable(zarr_array, name) + + data = indexing.LazilyIndexedArray( + ZarrArrayWrapper( + zarr_array, + is_coordinate=is_coordinate, + coords_buffer_prototype=self._coords_buffer_prototype, + ) + ) try_nczarr = self._mode == "r" dimensions, attributes = _get_zarr_dims_and_attrs( zarr_array, DIMENSION_KEY, try_nczarr @@ -1339,6 +1378,7 @@ def open_zarr( use_zarr_fill_value_as_mask=None, chunked_array_type: str | None = None, from_array_kwargs: dict[str, Any] | None = None, + coords_buffer_prototype: Any | None = None, **kwargs, ): """Load and decode a dataset from a Zarr store. 
@@ -1449,6 +1489,12 @@ def open_zarr( chunked arrays, via whichever chunk manager is specified through the ``chunked_array_type`` kwarg. Defaults to ``{'manager': 'dask'}``, meaning additional kwargs will be passed eventually to :py:func:`dask.array.from_array`. Experimental API that should not be relied upon. + coords_buffer_prototype : zarr.buffer.BufferPrototype, optional + The buffer prototype to use for loading coordinate arrays. Zarr offers control over + which device's memory buffers are read into. By default, xarray will always load + *coordinate* buffers into host (CPU) memory, regardless of the global zarr + configuration. To override this behavior, explicitly pass the buffer prototype + to use for coordinates here. Returns ------- @@ -1492,6 +1538,7 @@ def open_zarr( "storage_options": storage_options, "zarr_version": zarr_version, "zarr_format": zarr_format, + "coords_buffer_prototype": coords_buffer_prototype, } ds = open_dataset( @@ -1564,6 +1611,7 @@ def open_dataset( engine=None, use_zarr_fill_value_as_mask=None, cache_members: bool = True, + coords_buffer_prototype: Any | None = None, ) -> Dataset: filename_or_obj = _normalize_path(filename_or_obj) if not store: @@ -1580,6 +1628,7 @@ def open_dataset( use_zarr_fill_value_as_mask=None, zarr_format=zarr_format, cache_members=cache_members, + coords_buffer_prototype=coords_buffer_prototype, ) store_entrypoint = StoreBackendEntrypoint() @@ -1615,6 +1664,7 @@ def open_datatree( storage_options=None, zarr_version=None, zarr_format=None, + coords_buffer_prototype: Any | None = None, ) -> DataTree: filename_or_obj = _normalize_path(filename_or_obj) groups_dict = self.open_groups_as_dict( @@ -1634,6 +1684,7 @@ def open_datatree( storage_options=storage_options, zarr_version=zarr_version, zarr_format=zarr_format, + coords_buffer_prototype=coords_buffer_prototype, ) return datatree_from_dict_with_io_cleanup(groups_dict) @@ -1657,6 +1708,7 @@ def open_groups_as_dict( storage_options=None, zarr_version=None, 
zarr_format=None, + coords_buffer_prototype: Any | None = None, ) -> dict[str, Dataset]: from xarray.core.treenode import NodePath @@ -1679,6 +1731,7 @@ def open_groups_as_dict( storage_options=storage_options, zarr_version=zarr_version, zarr_format=zarr_format, + coords_buffer_prototype=coords_buffer_prototype, ) groups_dict = {} diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 83d5afa6a09..a7560630d98 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3766,6 +3766,37 @@ def test_zarr_version_deprecated() -> None: xr.open_zarr(store=store, zarr_version=2, zarr_format=3) +@requires_zarr +def test_coords_buffer_prototype() -> None: + pytest.importorskip("zarr", minversion="3") + + from zarr.core.buffer import cpu + from zarr.core.buffer.core import BufferPrototype + + counter = 0 + + class Buffer(cpu.Buffer): + def __init__(self, *args, **kwargs): + nonlocal counter + counter += 1 + super().__init__(*args, **kwargs) + + class NDBuffer(cpu.NDBuffer): + def __init__(self, *args, **kwargs): + nonlocal counter + counter += 1 + super().__init__(*args, **kwargs) + + prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer) + + ds = create_test_data() + store = KVStore() + ds.to_zarr(store=store, zarr_format=3) + + xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) + assert counter > 0 + + @requires_scipy class TestScipyInMemoryData(CFEncodedBase, NetCDF3Only): engine: T_NetcdfEngine = "scipy" From 335fa154a154b5dfd19cbf39d3dc1db181f4aba7 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 27 Feb 2025 15:06:11 -0800 Subject: [PATCH 3/6] mypy --- xarray/tests/test_backends.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index a7560630d98..eb56974c664 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3791,9 +3791,10 @@ def __init__(self, *args, **kwargs): ds = 
create_test_data() store = KVStore() - ds.to_zarr(store=store, zarr_format=3) + # type-ignore for zarr v2/v3 compat, even though this test is skipped for v2 + ds.to_zarr(store=store, zarr_format=3) # type: ignore[call-overload] - xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) + xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore[arg-type] assert counter > 0 From 83baf49ba79b0b984ca52adec60060eca5d25bbe Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Thu, 27 Feb 2025 15:11:42 -0800 Subject: [PATCH 4/6] remove unused type --- xarray/tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index eb56974c664..86802883d59 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3794,7 +3794,7 @@ def __init__(self, *args, **kwargs): # type-ignore for zarr v2/v3 compat, even though this test is skipped for v2 ds.to_zarr(store=store, zarr_format=3) # type: ignore[call-overload] - xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore[arg-type] + xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) assert counter > 0 From 2788300434e05d4110f894670827739b5c950f4c Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 28 Feb 2025 14:37:15 -0800 Subject: [PATCH 5/6] mypy fixes --- xarray/tests/test_backends.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 86802883d59..ecada9de07f 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3792,9 +3792,10 @@ def __init__(self, *args, **kwargs): ds = create_test_data() store = KVStore() # type-ignore for zarr v2/v3 compat, even though this test is skipped for v2 - ds.to_zarr(store=store, zarr_format=3) # type: ignore[call-overload] - - xr.open_dataset(store, engine="zarr", 
coords_buffer_prototype=prototype) + # Because we run mypy with both zarr 2 and 3, we can't easily specify + # the specific codes to ignore. + ds.to_zarr(store=store, zarr_format=3) # type: ignore # noqa: PGH003 + xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore # noqa: PGH003 assert counter > 0 From ecf56f884f0b3e34c117438a59cf36ffdc81f2c7 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Fri, 28 Feb 2025 15:04:13 -0800 Subject: [PATCH 6/6] more mypy fixes --- xarray/tests/test_backends.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index ecada9de07f..9922530b453 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -3792,10 +3792,8 @@ def __init__(self, *args, **kwargs): ds = create_test_data() store = KVStore() # type-ignore for zarr v2/v3 compat, even though this test is skipped for v2 - # Because we run mypy with both zarr 2 and 3, we can't easily specify - # the specific codes to ignore. - ds.to_zarr(store=store, zarr_format=3) # type: ignore # noqa: PGH003 - xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore # noqa: PGH003 + ds.to_zarr(store=store, zarr_format=3) # type: ignore[call-overload, unused-ignore] + xr.open_dataset(store, engine="zarr", coords_buffer_prototype=prototype) # type: ignore[arg-type, unused-ignore] assert counter > 0