@@ -178,9 +178,14 @@ def _read_and_stack_feature_raster(filepath: Union[str, os.PathLike]) -> Tuple[n
178
178
raster_reshaped = raster .reshape (raster .shape [0 ], - 1 ).T
179
179
reshaped_data .append (raster_reshaped )
180
180
181
+ nan_mask = (raster_reshaped == np .nan ).any (axis = 1 )
182
+ combined_mask = nan_mask if nodata_mask is None else nodata_mask | nan_mask
183
+
181
184
if nodata is not None :
182
185
raster_mask = (raster_reshaped == nodata ).any (axis = 1 )
183
- nodata_mask = raster_mask if nodata_mask is None else nodata_mask | raster_mask
186
+ combined_mask = combined_mask | raster_mask
187
+
188
+ nodata_mask = combined_mask
184
189
185
190
X = np .concatenate (reshaped_data , axis = 1 )
186
191
@@ -197,6 +202,7 @@ def _read_and_stack_feature_raster(filepath: Union[str, os.PathLike]) -> Tuple[n
197
202
with rasterio .open (label_file ) as label_raster :
198
203
y = label_raster .read (1 ) # Assuming labels are in the first band
199
204
label_nodata = label_raster .nodata
205
+ profiles = list (profiles )
200
206
profiles .append (label_raster .profile )
201
207
if not check_raster_grids (profiles , same_extent = True ):
202
208
raise NonMatchingRasterMetadataException (
@@ -218,6 +224,69 @@ def _read_and_stack_feature_raster(filepath: Union[str, os.PathLike]) -> Tuple[n
218
224
return X , y , reference_profile , nodata_mask
219
225
220
226
227
@beartype
def read_data_for_evaluation(
    rasters: Sequence[Union[str, os.PathLike]]
) -> Tuple[Sequence[np.ndarray], rasterio.profiles.Profile, Any]:
    """
    Load rasters and strip invalid pixels so that the outputs can be compared pixel by pixel.

    The first band of every input raster is read. A pixel is considered invalid if, in any
    of the rasters, it is NaN or equals that raster's nodata value (when the raster profile
    defines one). The per-raster masks are OR-ed into one shared mask, and every band is
    flattened and filtered with that shared mask.

    Args:
        rasters: Filepaths of the input rasters. At least two are required, and the grid
            check below must accept their profiles.

    Returns:
        List of flattened first-band arrays with the invalid pixels removed.
        Reference raster profile (profile of the first input raster).
        Flattened boolean mask, True where a pixel was treated as invalid.

    Raises:
        InvalidDatasetException: Fewer than two raster paths were given.
        NonMatchingRasterMetadataException: check_raster_grids rejected the raster profiles.
    """
    raster_count = len(rasters)
    if raster_count < 2:
        raise InvalidDatasetException(f"Expected more than one raster file: {raster_count} .")

    # Read the first band and the profile of every raster up front.
    profiles = []
    bands = []
    for path in rasters:
        with rasterio.open(path) as dataset:
            bands.append(dataset.read(1))
            profiles.append(dataset.profile)

    if not check_raster_grids(profiles, same_extent=True):
        raise NonMatchingRasterMetadataException(
            f"Input rasters should have the same grid properties: {profiles} ."
        )

    # Accumulate one mask over all rasters: a pixel flagged in any raster is dropped everywhere.
    invalid_pixels = None
    for band, band_profile in zip(bands, profiles):
        band_invalid = np.isnan(band)

        nodata = band_profile.get("nodata")
        if nodata is not None:
            band_invalid = band_invalid | (band == nodata)

        if invalid_pixels is None:
            invalid_pixels = band_invalid
        else:
            invalid_pixels = invalid_pixels | band_invalid

    nodata_mask = invalid_pixels.flatten()

    # Keep only pixels that are valid in every raster.
    valid_pixels = ~nodata_mask
    masked_data = [band.flatten()[valid_pixels] for band in bands]

    return masked_data, profiles[0], nodata_mask
288
+
289
+
221
290
@beartype
222
291
def _train_and_validate_sklearn_model (
223
292
X : Union [np .ndarray , pd .DataFrame ],
0 commit comments