Skip to content

Commit 0a84c42

Browse files
Band intervals should be tuples with two entries
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 7a94a82 commit 0a84c42

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

terratorch/datamodules/generic_pixel_wise_data_module.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def __init__(
9191
test_split: Path | None = None,
9292
ignore_split_file_extensions: bool = True,
9393
allow_substring_split_file: bool = True,
94-
dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
95-
predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
96-
output_bands: list[HLSBands | int | list[int] | str] | None = None,
94+
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
95+
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
96+
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
9797
constant_scale: float = 1,
9898
rgb_indices: list[int] | None = None,
9999
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -330,9 +330,9 @@ def __init__(
330330
test_split: Path | None = None,
331331
ignore_split_file_extensions: bool = True,
332332
allow_substring_split_file: bool = True,
333-
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
334-
predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
335-
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
333+
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
334+
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
335+
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
336336
constant_scale: float = 1,
337337
rgb_indices: list[int] | None = None,
338338
train_transform: A.Compose | None | list[A.BasicTransform] = None,

terratorch/datasets/generic_pixel_wise_dataset.py

+12-10
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def __init__(
4343
ignore_split_file_extensions: bool = True,
4444
allow_substring_split_file: bool = True,
4545
rgb_indices: list[int] | None = None,
46-
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
47-
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
46+
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
47+
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
4848
constant_scale: float = 1,
4949
transform: A.Compose | None = None,
5050
no_data_replace: float | None = None,
@@ -212,18 +212,20 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
212212
else:
213213
raise Exception("The bands must be or all str or all int.")
214214

215-
def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool:
215+
def _bands_defined_by_interval(self, bands_list: list[int] | list[tuple[int]] = None) -> bool:
216216
if not bands_list:
217217
return False
218218
elif all([type(band)==int or type(band)==str or isinstance(band, HLSBands) for band in bands_list]):
219219
return False
220-
elif all([isinstance(subinterval, list) for subinterval in bands_list]):
221-
if all([type(band)==int for band in sum(bands_list, [])]):
220+
elif all([isinstance(subinterval, tuple) for subinterval in bands_list]):
221+
bands_list_ = [list(subinterval) for subinterval in bands_list]
222+
if all([type(band)==int for band in sum(bands_list_, [])]):
222223
return True
223224
else:
224225
raise Exception(f"Whe using subintervals, the limits must be int.")
225226
else:
226-
raise Exception(f"Excpected List[int] or List[str] or List[List[int]], but received {type(bands_list)}.")
227+
print(bands_list)
228+
raise Exception(f"Excpected List[int] or List[str] or List[tuple[int, int]], but received {type(bands_list)}.")
227229

228230
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
229231
"""GenericNonGeoSegmentationDataset"""
@@ -239,8 +241,8 @@ def __init__(
239241
ignore_split_file_extensions: bool = True,
240242
allow_substring_split_file: bool = True,
241243
rgb_indices: list[str] | None = None,
242-
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
243-
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
244+
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
245+
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
244246
class_names: list[str] | None = None,
245247
constant_scale: float = 1,
246248
transform: A.Compose | None = None,
@@ -406,8 +408,8 @@ def __init__(
406408
ignore_split_file_extensions: bool = True,
407409
allow_substring_split_file: bool = True,
408410
rgb_indices: list[int] | None = None,
409-
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
410-
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
411+
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
412+
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
411413
constant_scale: float = 1,
412414
transform: A.Compose | None = None,
413415
no_data_replace: float | None = None,

0 commit comments

Comments
 (0)