Skip to content

Commit 712afb7

Browse files
committed
Fix nan masking in prepare_data_for_ml, add read_data_for_evaluation utility function
1 parent 2cc7bf9 commit 712afb7

File tree

1 file changed

+70
-1
lines changed

1 file changed

+70
-1
lines changed

eis_toolkit/prediction/machine_learning_general.py

+70-1
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,14 @@ def _read_and_stack_feature_raster(filepath: Union[str, os.PathLike]) -> Tuple[n
178178
raster_reshaped = raster.reshape(raster.shape[0], -1).T
179179
reshaped_data.append(raster_reshaped)
180180

181+
nan_mask = (raster_reshaped == np.nan).any(axis=1)
182+
combined_mask = nan_mask if nodata_mask is None else nodata_mask | nan_mask
183+
181184
if nodata is not None:
182185
raster_mask = (raster_reshaped == nodata).any(axis=1)
183-
nodata_mask = raster_mask if nodata_mask is None else nodata_mask | raster_mask
186+
combined_mask = combined_mask | raster_mask
187+
188+
nodata_mask = combined_mask
184189

185190
X = np.concatenate(reshaped_data, axis=1)
186191

@@ -197,6 +202,7 @@ def _read_and_stack_feature_raster(filepath: Union[str, os.PathLike]) -> Tuple[n
197202
with rasterio.open(label_file) as label_raster:
198203
y = label_raster.read(1) # Assuming labels are in the first band
199204
label_nodata = label_raster.nodata
205+
profiles = list(profiles)
200206
profiles.append(label_raster.profile)
201207
if not check_raster_grids(profiles, same_extent=True):
202208
raise NonMatchingRasterMetadataException(
@@ -218,6 +224,69 @@ def _read_and_stack_feature_raster(filepath: Union[str, os.PathLike]) -> Tuple[n
218224
return X, y, reference_profile, nodata_mask
219225

220226

227+
@beartype
def read_data_for_evaluation(
    rasters: Sequence[Union[str, os.PathLike]]
) -> Tuple[Sequence[np.ndarray], rasterio.profiles.Profile, Any]:
    """
    Read rasters and mask them consistently for evaluating modeling outputs.

    The first band of every raster is read. A pixel is dropped from all outputs if it is
    NaN, or equal to that raster's own nodata value, in any of the input rasters. The
    remaining pixels of each raster are returned as a flattened array.

    Args:
        rasters: Filepaths of the input rasters. At least two are required, and they are
            expected to share the same grid properties and extent.

    Returns:
        List of flattened raster arrays with the masked pixels removed, in input order.
        Reference raster profile (the profile of the first raster).
        Flattened boolean mask that was applied; True marks a removed pixel.

    Raises:
        InvalidDatasetException: Fewer than two raster paths were given.
        NonMatchingRasterMetadataException: Input rasters don't have same grid properties.
    """
    if len(rasters) < 2:
        raise InvalidDatasetException(f"Expected more than one raster file: {len(rasters)}.")

    # Read the first band and the profile of every raster up front, so the grid check
    # below sees all profiles before any masking is attempted.
    bands = []
    profiles = []
    for path in rasters:
        with rasterio.open(path) as dataset:
            bands.append(dataset.read(1))
            profiles.append(dataset.profile)

    if not check_raster_grids(profiles, same_extent=True):
        raise NonMatchingRasterMetadataException(f"Input rasters should have the same grid properties: {profiles}.")

    # Accumulate one mask over all rasters: a pixel is invalid as soon as any raster
    # has NaN or its own nodata value at that position.
    invalid = None
    for band, band_profile in zip(bands, profiles):
        band_invalid = np.isnan(band)

        band_nodata = band_profile.get("nodata")
        if band_nodata is not None:
            band_invalid = band_invalid | (band == band_nodata)

        invalid = band_invalid if invalid is None else invalid | band_invalid

    nodata_mask = invalid.flatten()

    # Apply the same flattened mask to every raster so the outputs stay aligned.
    valid = ~nodata_mask
    masked_data = [band.flatten()[valid] for band in bands]

    return masked_data, profiles[0], nodata_mask
288+
289+
221290
@beartype
222291
def _train_and_validate_sklearn_model(
223292
X: Union[np.ndarray, pd.DataFrame],

0 commit comments

Comments
 (0)