Skip to content

Commit 4eee3c9

Browse files
committed
accept mixed band specifications
1 parent 5708e0a commit 4eee3c9

File tree

3 files changed

+184
-157
lines changed

3 files changed

+184
-157
lines changed

terratorch/datasets/generic_pixel_wise_dataset.py

+25-73
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
# Copyright contributors to the Terratorch project
22

3-
"""Module containing generic dataset classes
4-
"""
3+
"""Module containing generic dataset classes"""
4+
55
import glob
6+
import operator
67
import os
78
from abc import ABC
8-
from functools import partial
9-
from pathlib import Path
10-
from typing import Any, List, Union
119
from functools import reduce
12-
import operator
10+
from pathlib import Path
11+
from typing import Any
12+
1313
import albumentations as A
1414
import matplotlib as mpl
1515
import numpy as np
1616
import rioxarray
17-
import torch
1817
import xarray as xr
19-
from albumentations.pytorch import ToTensorV2
2018
from einops import rearrange
21-
from matplotlib import cm
2219
from matplotlib import pyplot as plt
2320
from matplotlib.figure import Figure
2421
from matplotlib.patches import Rectangle
@@ -122,15 +119,8 @@ def __init__(
122119
)
123120
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices
124121

125-
is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands)
126-
127-
# If the bands are defined by sub-intervals or not.
128-
if is_bands_by_interval:
129-
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
130-
self.output_bands = self._generate_bands_intervals(output_bands)
131-
else:
132-
self.dataset_bands = dataset_bands
133-
self.output_bands = output_bands
122+
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
123+
self.output_bands = self._generate_bands_intervals(output_bands)
134124

135125
if self.output_bands and not self.dataset_bands:
136126
msg = "If output bands provided, dataset_bands must also be provided"
@@ -183,63 +173,25 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
183173
data = data.fillna(nan_replace)
184174
return data
185175

186-
def _generate_bands_intervals(self, bands_intervals: List[List[int]] = None):
187-
bands = []
188-
for b_interval in bands_intervals:
189-
bands_sublist = list(range(b_interval[0], b_interval[1] + 1))
190-
bands.append(bands_sublist)
191-
return reduce(operator.iadd, bands, [])
192-
193-
def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
194-
195-
band_type = [None, None]
196-
if not dataset_bands and not output_bands:
176+
def _generate_bands_intervals(self, bands_intervals: list[int | str | HLSBands | tuple[int]] | None = None):
177+
if bands_intervals is None:
197178
return None
198-
else:
199-
for b, bands_list in enumerate([dataset_bands, output_bands]):
200-
if all([type(band) == int for band in bands_list]):
201-
band_type[b] = int
202-
elif all([type(band) == str for band in bands_list]):
203-
band_type[b] = str
204-
else:
205-
pass
206-
if band_type.count(band_type[0]) == len(band_type):
207-
return band_type[0]
208-
else:
209-
raise Exception("The bands must be or all str or all int.")
210-
211-
def _check_if_its_defined_by_interval(
212-
self, dataset_bands: list[int] | list[tuple[int]] = None, output_bands: list[int] | list[tuple[int]] = None
213-
) -> bool:
214-
215-
is_dataset_bands_defined = self._bands_defined_by_interval(bands_list=dataset_bands)
216-
is_output_bands_defined = self._bands_defined_by_interval(bands_list=output_bands)
217-
218-
if is_dataset_bands_defined and is_output_bands_defined:
219-
return True
220-
elif not is_dataset_bands_defined and not is_output_bands_defined:
221-
return False
222-
else:
223-
raise Exception(
224-
f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}"
225-
)
226-
227-
def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool:
228-
if not bands_list:
229-
return False
230-
elif all([type(band) == int or type(band) == str or isinstance(band, HLSBands) for band in bands_list]):
231-
return False
232-
elif all([isinstance(subinterval, tuple) for subinterval in bands_list]):
233-
bands_list_ = [list(subinterval) for subinterval in bands_list]
234-
if all([type(band) == int for band in sum(bands_list_, [])]):
235-
return True
179+
bands = []
180+
for element in bands_intervals:
181+
# if it's an interval
182+
if isinstance(element, tuple):
183+
if len(element) != 2: # noqa: PLR2004
184+
msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
185+
raise Exception(msg)
186+
expanded_element = list(range(element[0], element[1] + 1))
187+
bands.extend(expanded_element)
236188
else:
237-
raise Exception(f"Whe using subintervals, the limits must be int.")
238-
else:
239-
raise Exception(
240-
f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}."
241-
)
242-
189+
bands.append(element)
190+
# check that the expansion didn't result in duplicate elements
191+
if len(set(bands)) != len(bands):
192+
msg = "Duplicate indices detected. Indices must be unique."
193+
raise Exception(msg)
194+
return bands
243195

244196
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
245197
"""GenericNonGeoSegmentationDataset"""

terratorch/datasets/generic_scalar_label_dataset.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -110,17 +110,21 @@ def is_valid_file(x):
110110

111111
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices
112112

113-
self.dataset_bands = dataset_bands
114-
self.output_bands = output_bands
113+
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
114+
self.output_bands = self._generate_bands_intervals(output_bands)
115+
115116
if self.output_bands and not self.dataset_bands:
116117
msg = "If output bands provided, dataset_bands must also be provided"
117118
return Exception(msg) # noqa: PLE0101
118119

120+
# There is a special condition if the bands are defined as simple strings.
119121
if self.output_bands:
120122
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
121123
msg = "Output bands must be a subset of dataset bands"
122124
raise Exception(msg)
125+
123126
self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]
127+
124128
else:
125129
self.filter_indices = None
126130
# If no transform is given, only apply the conversion to a torch tensor

0 commit comments

Comments
 (0)