Skip to content

Commit 2888252

Browse files
More compact way to check if the bands are defined by interval
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 0a84c42 commit 2888252

File tree

1 file changed

+15
-4
lines changed

1 file changed

+15
-4
lines changed

terratorch/datasets/generic_pixel_wise_dataset.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,10 @@ def __init__(
121121
)
122122
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices
123123

124-
bands_by_interval = (self._bands_defined_by_interval(bands_list=dataset_bands) and
125-
self._bands_defined_by_interval(bands_list=output_bands))
124+
is_bands_by_interval = self._check_if_its_defined_by_interval(dataset_bands, output_bands)
126125

127126
# If the bands are defined by sub-intervals or not.
128-
if bands_by_interval:
127+
if is_bands_by_interval:
129128
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
130129
self.output_bands = self._generate_bands_intervals(output_bands)
131130
else:
@@ -212,6 +211,19 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
212211
else:
213212
raise Exception("The bands must be or all str or all int.")
214213

214+
def _check_if_its_defined_by_interval(self, dataset_bands: list[int] | list[tuple[int]] = None,
215+
output_bands: list[int] | list[tuple[int]] = None) -> bool:
216+
217+
is_dataset_bands_defined = self._bands_defined_by_interval(bands_list=dataset_bands)
218+
is_output_bands_defined = self._bands_defined_by_interval(bands_list=output_bands)
219+
220+
if is_dataset_bands_defined and is_output_bands_defined:
221+
return True
222+
elif not is_dataset_bands_defined and not is_output_bands_defined:
223+
return False
224+
else:
225+
raise Exception(f"Both dataset_bands and output_bands must have the same type, but received {dataset_bands} and {output_bands}")
226+
215227
def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool:
216228
if not bands_list:
217229
return False
@@ -224,7 +236,6 @@ def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] =
224236
else:
225237
raise Exception(f"Whe using subintervals, the limits must be int.")
226238
else:
227-
print(bands_list)
228239
raise Exception(f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}.")
229240

230241
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):

0 commit comments

Comments
 (0)