@@ -486,7 +486,7 @@ def identical(self, other: Self) -> bool:
486
486
return self .to_dataset ().identical (other .to_dataset ())
487
487
488
488
def _update_coords (
489
- self , coords : dict [Hashable , Variable ], indexes : Mapping [ Any , Index ]
489
+ self , coords : dict [Hashable , Variable ], indexes : dict [ Hashable , Index ]
490
490
) -> None :
491
491
# redirect to DatasetCoordinates._update_coords
492
492
self ._data .coords ._update_coords (coords , indexes )
@@ -780,7 +780,7 @@ def to_dataset(self) -> Dataset:
780
780
return self ._data ._copy_listed (names )
781
781
782
782
def _update_coords (
783
- self , coords : dict [Hashable , Variable ], indexes : Mapping [ Any , Index ]
783
+ self , coords : dict [Hashable , Variable ], indexes : dict [ Hashable , Index ]
784
784
) -> None :
785
785
variables = self ._data ._variables .copy ()
786
786
variables .update (coords )
@@ -880,7 +880,7 @@ def to_dataset(self) -> Dataset:
880
880
return self ._data .dataset ._copy_listed (self ._names )
881
881
882
882
def _update_coords (
883
- self , coords : dict [Hashable , Variable ], indexes : Mapping [ Any , Index ]
883
+ self , coords : dict [Hashable , Variable ], indexes : dict [ Hashable , Index ]
884
884
) -> None :
885
885
from xarray .core .datatree import check_alignment
886
886
@@ -964,22 +964,14 @@ def __getitem__(self, key: Hashable) -> T_DataArray:
964
964
return self ._data ._getitem_coord (key )
965
965
966
966
def _update_coords(
    self, coords: dict[Hashable, Variable], indexes: dict[Hashable, Index]
) -> None:
    """Replace the coordinates and indexes of the wrapped DataArray.

    The ``coords``/``indexes`` pair is first checked against the array's
    shape and dimensions with ``validate_dataarray_coords``. Nothing is
    assigned before that call returns, so if it raises the array keeps its
    previous coordinates and indexes.

    Parameters
    ----------
    coords : dict
        Complete mapping of coordinate names to variables to store.
    indexes : dict
        Complete mapping of coordinate names to indexes to store.
    """
    array = self._data
    shape = array.shape
    candidate = Coordinates._construct_direct(coords, indexes)
    validate_dataarray_coords(shape, candidate, self.dims)

    # Both mappings are stored as given (no merge with the previous state).
    array._coords = coords
    array._indexes = indexes
983
975
984
976
def _drop_coords (self , coord_names ):
985
977
# should drop indexed coordinates only
@@ -1154,9 +1146,58 @@ def create_coords_with_default_indexes(
1154
1146
return new_coords
1155
1147
1156
1148
1157
- def _coordinates_from_variable (variable : Variable ) -> Coordinates :
1158
- from xarray .core .indexes import create_default_index_implicit
1149
class CoordinateValidationError(ValueError):
    """Raised when coordinates fail Xarray's validation checks.

    Derives from :class:`ValueError`, so handlers written as
    ``except ValueError`` also catch it.
    """
1151
+
1152
+
1153
def validate_dataarray_coords(
    shape: tuple[int, ...],
    coords: Coordinates | Mapping[Hashable, Variable],
    dim: tuple[Hashable, ...],
) -> None:
    """Validate coordinates ``coords`` to include in a DataArray defined by
    ``shape`` and dimensions ``dim``.

    If a coordinate is associated with an index, the validation is performed by
    the index. By default the coordinate dimensions must match (a subset of) the
    array dimensions (in any order) to conform to the DataArray model. The index
    may override this behavior with other validation rules, though.

    Non-index coordinates must all conform to the DataArray model. Scalar
    coordinates are always valid.

    Parameters
    ----------
    shape : tuple of int
        Shape of the array data; must have exactly one entry per name in
        ``dim``.
    coords : Coordinates or mapping of hashable to Variable
        Coordinates to check. Indexes are only consulted when a
        ``Coordinates`` object is given; a plain mapping is treated as
        having no indexes.
    dim : tuple of hashable
        Dimension names of the array, in the same order as ``shape``.

    Raises
    ------
    CoordinateValidationError
        If a coordinate is rejected (by its index, or because its dimensions
        are not a subset of ``dim``), or if a coordinate's size along one of
        the array dimensions differs from the size of the data.
    ValueError
        If ``shape`` and ``dim`` have different lengths (raised by
        ``zip(..., strict=True)``).
    """
    sizes = dict(zip(dim, shape, strict=True))
    dim_set = set(dim)

    indexes: Mapping[Hashable, Index]
    if isinstance(coords, Coordinates):
        indexes = coords.xindexes
    else:
        indexes = {}

    for k, v in coords.items():
        if k in indexes:
            # Indexed coordinate: the index decides whether it may be added.
            invalid = not indexes[k].should_add_coord_to_array(k, v, dim_set)
        else:
            # Non-indexed coordinate: every dimension must be an array
            # dimension. Use the precomputed set rather than scanning the
            # ``dim`` tuple once per coordinate dimension.
            invalid = not dim_set.issuperset(v.dims)

        if invalid:
            raise CoordinateValidationError(
                f"coordinate {k} has dimensions {v.dims}, but these "
                "are not a subset of the DataArray "
                f"dimensions {dim}"
            )

        # Sizes are only compared along dimensions shared with the data; an
        # index may have accepted a coordinate with extra dimensions above.
        for d, s in v.sizes.items():
            if d in sizes and s != sizes[d]:
                raise CoordinateValidationError(
                    f"conflicting sizes for dimension {d!r}: "
                    f"length {sizes[d]} on the data but length {s} on "
                    f"coordinate {k!r}"
                )
1198
+
1159
1199
1200
+ def coordinates_from_variable (variable : Variable ) -> Coordinates :
1160
1201
(name ,) = variable .dims
1161
1202
new_index , index_vars = create_default_index_implicit (variable )
1162
1203
indexes = dict .fromkeys (index_vars , new_index )
0 commit comments