Skip to content

Commit 989bf80

Browse files
Exception for None inputs
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 174f2f1 commit 989bf80

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

terratorch/datasets/generic_pixel_wise_dataset.py

+14-10
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
174174
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0],
175175
"filename": self.image_files[index],
176176
}
177+
177178
if self.reduce_zero_label:
178179
output["mask"] -= 1
179180
if self.transform:
@@ -196,17 +197,20 @@ def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None):
196197
def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
197198

198199
band_type = [None, None]
199-
for b, bands_list in enumerate([dataset_bands, output_bands]):
200-
if all([type(band)==int for band in bands_list]):
201-
band_type[b] = int
202-
elif all([type(band)==str for band in bands_list]):
203-
band_type[b] = str
204-
else:
205-
pass
206-
if band_type.count(band_type[0]) == len(band_type):
207-
return band_type[0]
200+
if not dataset_bands and not output_bands:
201+
return None
208202
else:
209-
raise Exception("The bands must be or all str or all int.")
203+
for b, bands_list in enumerate([dataset_bands, output_bands]):
204+
if all([type(band)==int for band in bands_list]):
205+
band_type[b] = int
206+
elif all([type(band)==str for band in bands_list]):
207+
band_type[b] = str
208+
else:
209+
pass
210+
if band_type.count(band_type[0]) == len(band_type):
211+
return band_type[0]
212+
else:
213+
raise Exception("The bands must be or all str or all int.")
210214

211215
def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool:
212216
if not bands_list:

0 commit comments

Comments
 (0)