Skip to content

Commit e42aa6f

Browse files
benbovydcherianpre-commit-ci[bot]
authored
Add Index.validate_dataarray_coord (#10137)
Co-authored-by: Deepak Cherian <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 649e830 commit e42aa6f

13 files changed

+209
-60
lines changed

doc/api-hidden.rst

+1
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,7 @@
520520
Index.stack
521521
Index.unstack
522522
Index.create_variables
523+
Index.should_add_coord_to_array
523524
Index.to_pandas_index
524525
Index.isel
525526
Index.sel

doc/api.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1645,6 +1645,7 @@ Exceptions
16451645
:toctree: generated/
16461646

16471647
AlignmentError
1648+
CoordinateValidationError
16481649
MergeError
16491650
SerializationWarning
16501651

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ New Features
3737
- Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This
3838
includes ``datatree`` support, and removing slashes from dimension names. By
3939
`Miguel Jimenez-Urias <https://github.com/Mikejmnez>`_.
40+
- Allow assigning index coordinates with non-array dimension(s) in a :py:class:`DataArray` by overriding
41+
:py:meth:`Index.should_add_coord_to_array`. For example, this enables support for CF boundary coordinates (e.g.,
42+
``time(time)`` and ``time_bnds(time, nbnd)``) in a DataArray (:pull:`10137`).
43+
By `Benoit Bovy <https://github.com/benbovy>`_.
4044
- Improved support for pandas categorical extension as indices (i.e., :py:class:`pandas.IntervalIndex`). (:issue:`9661`, :pull:`9671`)
4145
By `Ilan Gold <https://github.com/ilan-gold>`_.
4246
- Improved checks and errors raised when trying to align objects with conflicting indexes.

xarray/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929
from xarray.conventions import SerializationWarning, decode_cf
3030
from xarray.core.common import ALL_DIMS, full_like, ones_like, zeros_like
31-
from xarray.core.coordinates import Coordinates
31+
from xarray.core.coordinates import Coordinates, CoordinateValidationError
3232
from xarray.core.dataarray import DataArray
3333
from xarray.core.dataset import Dataset
3434
from xarray.core.datatree import DataTree
@@ -129,6 +129,7 @@
129129
"Variable",
130130
# Exceptions
131131
"AlignmentError",
132+
"CoordinateValidationError",
132133
"InvalidTreeError",
133134
"MergeError",
134135
"NotFoundInTreeError",

xarray/core/coordinates.py

+60-19
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
486486
return self.to_dataset().identical(other.to_dataset())
487487

488488
def _update_coords(
489-
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
489+
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
490490
) -> None:
491491
# redirect to DatasetCoordinates._update_coords
492492
self._data.coords._update_coords(coords, indexes)
@@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
780780
return self._data._copy_listed(names)
781781

782782
def _update_coords(
783-
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
783+
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
784784
) -> None:
785785
variables = self._data._variables.copy()
786786
variables.update(coords)
@@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
880880
return self._data.dataset._copy_listed(self._names)
881881

882882
def _update_coords(
883-
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
883+
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
884884
) -> None:
885885
from xarray.core.datatree import check_alignment
886886

@@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
964964
return self._data._getitem_coord(key)
965965

966966
def _update_coords(
967-
self, coords: dict[Hashable, Variable], indexes: Mapping[Any, Index]
967+
self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
968968
) -> None:
969-
coords_plus_data = coords.copy()
970-
coords_plus_data[_THIS_ARRAY] = self._data.variable
971-
dims = calculate_dimensions(coords_plus_data)
972-
if not set(dims) <= set(self.dims):
973-
raise ValueError(
974-
"cannot add coordinates with new dimensions to a DataArray"
975-
)
976-
self._data._coords = coords
969+
validate_dataarray_coords(
970+
self._data.shape, Coordinates._construct_direct(coords, indexes), self.dims
971+
)
977972

978-
# TODO(shoyer): once ._indexes is always populated by a dict, modify
979-
# it to update inplace instead.
980-
original_indexes = dict(self._data.xindexes)
981-
original_indexes.update(indexes)
982-
self._data._indexes = original_indexes
973+
self._data._coords = coords
974+
self._data._indexes = indexes
983975

984976
def _drop_coords(self, coord_names):
985977
# should drop indexed coordinates only
@@ -1154,9 +1146,58 @@ def create_coords_with_default_indexes(
11541146
return new_coords
11551147

11561148

1157-
def _coordinates_from_variable(variable: Variable) -> Coordinates:
1158-
from xarray.core.indexes import create_default_index_implicit
1149+
class CoordinateValidationError(ValueError):
1150+
"""Error class for Xarray coordinate validation failures."""
1151+
1152+
1153+
def validate_dataarray_coords(
1154+
shape: tuple[int, ...],
1155+
coords: Coordinates | Mapping[Hashable, Variable],
1156+
dim: tuple[Hashable, ...],
1157+
):
1158+
"""Validate coordinates ``coords`` to include in a DataArray defined by
1159+
``shape`` and dimensions ``dim``.
1160+
1161+
If a coordinate is associated with an index, the validation is performed by
1162+
the index. By default the coordinate dimensions must match (a subset of) the
1163+
array dimensions (in any order) to conform to the DataArray model. The index
1164+
may override this behavior with other validation rules, though.
1165+
1166+
Non-index coordinates must all conform to the DataArray model. Scalar
1167+
coordinates are always valid.
1168+
"""
1169+
sizes = dict(zip(dim, shape, strict=True))
1170+
dim_set = set(dim)
1171+
1172+
indexes: Mapping[Hashable, Index]
1173+
if isinstance(coords, Coordinates):
1174+
indexes = coords.xindexes
1175+
else:
1176+
indexes = {}
1177+
1178+
for k, v in coords.items():
1179+
if k in indexes:
1180+
invalid = not indexes[k].should_add_coord_to_array(k, v, dim_set)
1181+
else:
1182+
invalid = any(d not in dim for d in v.dims)
1183+
1184+
if invalid:
1185+
raise CoordinateValidationError(
1186+
f"coordinate {k} has dimensions {v.dims}, but these "
1187+
"are not a subset of the DataArray "
1188+
f"dimensions {dim}"
1189+
)
1190+
1191+
for d, s in v.sizes.items():
1192+
if d in sizes and s != sizes[d]:
1193+
raise CoordinateValidationError(
1194+
f"conflicting sizes for dimension {d!r}: "
1195+
f"length {sizes[d]} on the data but length {s} on "
1196+
f"coordinate {k!r}"
1197+
)
1198+
11591199

1200+
def coordinates_from_variable(variable: Variable) -> Coordinates:
11601201
(name,) = variable.dims
11611202
new_index, index_vars = create_default_index_implicit(variable)
11621203
indexes = dict.fromkeys(index_vars, new_index)

xarray/core/dataarray.py

+2-20
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
DataArrayCoordinates,
3434
assert_coordinate_consistent,
3535
create_coords_with_default_indexes,
36+
validate_dataarray_coords,
3637
)
3738
from xarray.core.dataset import Dataset
3839
from xarray.core.extension_array import PandasExtensionArray
@@ -124,25 +125,6 @@
124125
T_XarrayOther = TypeVar("T_XarrayOther", bound="DataArray" | Dataset)
125126

126127

127-
def _check_coords_dims(shape, coords, dim):
128-
sizes = dict(zip(dim, shape, strict=True))
129-
for k, v in coords.items():
130-
if any(d not in dim for d in v.dims):
131-
raise ValueError(
132-
f"coordinate {k} has dimensions {v.dims}, but these "
133-
"are not a subset of the DataArray "
134-
f"dimensions {dim}"
135-
)
136-
137-
for d, s in v.sizes.items():
138-
if s != sizes[d]:
139-
raise ValueError(
140-
f"conflicting sizes for dimension {d!r}: "
141-
f"length {sizes[d]} on the data but length {s} on "
142-
f"coordinate {k!r}"
143-
)
144-
145-
146128
def _infer_coords_and_dims(
147129
shape: tuple[int, ...],
148130
coords: (
@@ -206,7 +188,7 @@ def _infer_coords_and_dims(
206188
var.dims = (dim,)
207189
new_coords[dim] = var.to_index_variable()
208190

209-
_check_coords_dims(shape, new_coords, dims_tuple)
191+
validate_dataarray_coords(shape, new_coords, dims_tuple)
210192

211193
return new_coords, dims_tuple
212194

xarray/core/dataset.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,15 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
11591159
coords: dict[Hashable, Variable] = {}
11601160
# preserve ordering
11611161
for k in self._variables:
1162-
if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
1162+
if k in self._indexes:
1163+
add_coord = self._indexes[k].should_add_coord_to_array(
1164+
k, self._variables[k], needed_dims
1165+
)
1166+
else:
1167+
var_dims = set(self._variables[k].dims)
1168+
add_coord = k in self._coord_names and var_dims <= needed_dims
1169+
1170+
if add_coord:
11631171
coords[k] = self._variables[k]
11641172

11651173
indexes = filter_indexes_from_coords(self._indexes, set(coords))

xarray/core/groupby.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
DatasetGroupByAggregations,
2424
)
2525
from xarray.core.common import ImplementsArrayReduce, ImplementsDatasetReduce
26-
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
26+
from xarray.core.coordinates import Coordinates, coordinates_from_variable
2727
from xarray.core.duck_array_ops import where
2828
from xarray.core.formatting import format_array_flat
2929
from xarray.core.indexes import (
@@ -1147,7 +1147,7 @@ def _flox_reduce(
11471147
new_coords.append(
11481148
# Using IndexVariable here ensures we reconstruct PandasMultiIndex with
11491149
# all associated levels properly.
1150-
_coordinates_from_variable(
1150+
coordinates_from_variable(
11511151
IndexVariable(
11521152
dims=grouper.name,
11531153
data=output_index,

xarray/core/indexes.py

+43
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,49 @@ def create_variables(
196196
else:
197197
return {}
198198

199+
def should_add_coord_to_array(
200+
self,
201+
name: Hashable,
202+
var: Variable,
203+
dims: set[Hashable],
204+
) -> bool:
205+
"""Define whether or not an index coordinate variable should be added to
206+
a new DataArray.
207+
208+
This method is called repeatedly for each Variable associated with this
209+
index when creating a new DataArray (via its constructor or from a
210+
Dataset) or updating an existing one. The variables associated with this
211+
index are the ones passed to :py:meth:`Index.from_variables` and/or
212+
returned by :py:meth:`Index.create_variables`.
213+
214+
By default returns ``True`` if the dimensions of the coordinate variable
215+
are a subset of the array dimensions and ``False`` otherwise (DataArray
216+
model). This default behavior may be overridden in Index subclasses to
217+
bypass strict conformance with the DataArray model. This is useful for
218+
example to include the (n+1)-dimensional cell boundary coordinate
219+
associated with an interval index.
220+
221+
Returning ``False`` will either:
222+
223+
- raise a :py:class:`CoordinateValidationError` when passing the
224+
coordinate directly to a new or an existing DataArray, e.g., via
225+
``DataArray.__init__()`` or ``DataArray.assign_coords()``
226+
227+
- drop the coordinate (and therefore drop the index) when a new
228+
DataArray is constructed by indexing a Dataset
229+
230+
Parameters
231+
----------
232+
name : Hashable
233+
Name of a coordinate variable associated with this index.
234+
var : Variable
235+
Coordinate variable object.
236+
dims : set
237+
Dimensions of the new DataArray object being created.
238+
239+
"""
240+
return all(d in dims for d in var.dims)
241+
199242
def to_pandas_index(self) -> pd.Index:
200243
"""Cast this xarray index to a pandas.Index object or raise a
201244
``TypeError`` if this is not supported.

xarray/groupers.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
2020
from xarray.computation.apply_ufunc import apply_ufunc
21-
from xarray.core.coordinates import Coordinates, _coordinates_from_variable
21+
from xarray.core.coordinates import Coordinates, coordinates_from_variable
2222
from xarray.core.dataarray import DataArray
2323
from xarray.core.duck_array_ops import array_all, isnull
2424
from xarray.core.groupby import T_Group, _DummyGroup
@@ -115,7 +115,7 @@ def __init__(
115115

116116
if coords is None:
117117
assert not isinstance(self.unique_coord, _DummyGroup)
118-
self.coords = _coordinates_from_variable(self.unique_coord)
118+
self.coords = coordinates_from_variable(self.unique_coord)
119119
else:
120120
self.coords = coords
121121

@@ -252,7 +252,7 @@ def _factorize_unique(self) -> EncodedGroups:
252252
codes=codes,
253253
full_index=full_index,
254254
unique_coord=unique_coord,
255-
coords=_coordinates_from_variable(unique_coord),
255+
coords=coordinates_from_variable(unique_coord),
256256
)
257257

258258
def _factorize_dummy(self) -> EncodedGroups:
@@ -280,7 +280,7 @@ def _factorize_dummy(self) -> EncodedGroups:
280280
else:
281281
if TYPE_CHECKING:
282282
assert isinstance(unique_coord, Variable)
283-
coords = _coordinates_from_variable(unique_coord)
283+
coords = coordinates_from_variable(unique_coord)
284284

285285
return EncodedGroups(
286286
codes=codes,
@@ -417,7 +417,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
417417
codes=codes,
418418
full_index=full_index,
419419
unique_coord=unique_coord,
420-
coords=_coordinates_from_variable(unique_coord),
420+
coords=coordinates_from_variable(unique_coord),
421421
)
422422

423423

@@ -551,7 +551,7 @@ def factorize(self, group: T_Group) -> EncodedGroups:
551551
group_indices=group_indices,
552552
full_index=full_index,
553553
unique_coord=unique_coord,
554-
coords=_coordinates_from_variable(unique_coord),
554+
coords=coordinates_from_variable(unique_coord),
555555
)
556556

557557

xarray/testing/assertions.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -401,12 +401,12 @@ def _assert_dataarray_invariants(da: DataArray, check_default_indexes: bool):
401401

402402
assert isinstance(da._coords, dict), da._coords
403403
assert all(isinstance(v, Variable) for v in da._coords.values()), da._coords
404-
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
405-
da.dims,
406-
{k: v.dims for k, v in da._coords.items()},
407-
)
408404

409405
if check_default_indexes:
406+
assert all(set(v.dims) <= set(da.dims) for v in da._coords.values()), (
407+
da.dims,
408+
{k: v.dims for k, v in da._coords.items()},
409+
)
410410
assert all(
411411
isinstance(v, IndexVariable)
412412
for (k, v) in da._coords.items()

0 commit comments

Comments
 (0)