Skip to content

Commit e48965f

Browse files
Basic support to use simple strings to name the bands
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 64dcf5d commit e48965f

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

terratorch/datasets/generic_pixel_wise_dataset.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,18 @@ def __init__(
130130
else:
131131
self.dataset_bands = dataset_bands
132132
self.output_bands = output_bands
133+
134+
bands_type = self._bands_as_int_or_str(dataset_bands, output_bands)
135+
136+
if bands_type == str:
137+
raise UserWarning("When the bands are defined as str, guarantee your input files"+
138+
"are organized by band and all have its specific name.")
133139

134140
if self.output_bands and not self.dataset_bands:
135141
msg = "If output bands provided, dataset_bands must also be provided"
136142
return Exception(msg) # noqa: PLE0101
137143

144+
# There is a special condition if the bands are defined as simple strings.
138145
if self.output_bands:
139146
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
140147
msg = "Output bands must be a subset of dataset bands"
@@ -183,18 +190,33 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None):
183190
bands.append(bands_sublist)
184191
return sorted(sum(bands, []))
185192

193+
def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
194+
195+
band_type = [None, None]
196+
for b, bands_list in enumerate([dataset_bands, output_bands]):
197+
if all([type(band)==int for band in bands_list]):
198+
band_type[b] = int
199+
elif all([type(band)==str for band in bands_list]):
200+
band_type[b] = str
201+
else:
202+
pass
203+
if band_type.cound(band_type[0]) == len(band_type)
204+
return band_type[0]
205+
else:
206+
raise Exception("The bands must be or all str or all int.")
207+
186208
def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool:
187209
if not bands_list:
188210
return False
189-
elif all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]):
211+
elif all([type(band)==int or type(band)==str or isinstance(band, HLSBands) for band in bands_list]):
190212
return False
191213
elif all([isinstance(subinterval, list) for subinterval in bands_list]):
192214
if all([type(band)==int for band in sum(bands_list, [])]):
193215
return True
194216
else:
195217
raise Exception(f"Whe using subintervals, the limits must be int.")
196218
else:
197-
raise Exception(f"Excpected List[int] or List[List[int]], but received {type(bands_list)}.")
219+
raise Exception(f"Excpected List[int] or List[str] or List[List[int]], but received {type(bands_list)}.")
198220

199221
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
200222
"""GenericNonGeoSegmentationDataset"""

0 commit comments

Comments
 (0)