Skip to content

Commit a0ce8aa

Browse files
Extending the supported formats for bands to include list[int]
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 1dd6650 commit a0ce8aa

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

terratorch/datamodules/generic_pixel_wise_data_module.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def __init__(
9191
test_split: Path | None = None,
9292
ignore_split_file_extensions: bool = True,
9393
allow_substring_split_file: bool = True,
94-
dataset_bands: list[HLSBands | int] | None = None,
95-
predict_dataset_bands: list[HLSBands | int] | None = None,
96-
output_bands: list[HLSBands | int] | None = None,
94+
dataset_bands: list[HLSBands | int | list[int]] | None = None,
95+
predict_dataset_bands: list[HLSBands | int | list[int]] | None = None,
96+
output_bands: list[HLSBands | int | list[int]] | None = None,
9797
constant_scale: float = 1,
9898
rgb_indices: list[int] | None = None,
9999
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -330,9 +330,9 @@ def __init__(
330330
test_split: Path | None = None,
331331
ignore_split_file_extensions: bool = True,
332332
allow_substring_split_file: bool = True,
333-
dataset_bands: list[HLSBands | int] | None = None,
334-
predict_dataset_bands: list[HLSBands | int] | None = None,
335-
output_bands: list[HLSBands | int] | None = None,
333+
dataset_bands: list[HLSBands | int | list[int]] | None = None,
334+
predict_dataset_bands: list[HLSBands | int | list[int]] | None = None,
335+
output_bands: list[HLSBands | int | list[int]] | None = None,
336336
constant_scale: float = 1,
337337
rgb_indices: list[int] | None = None,
338338
train_transform: A.Compose | None | list[A.BasicTransform] = None,

terratorch/datasets/generic_pixel_wise_dataset.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def __init__(
4343
ignore_split_file_extensions: bool = True,
4444
allow_substring_split_file: bool = True,
4545
rgb_indices: list[int] | None = None,
46-
dataset_bands: list[HLSBands | int] | None = None,
47-
output_bands: list[HLSBands | int] | None = None,
46+
dataset_bands: list[HLSBands | int | list[int]] | None = None,
47+
output_bands: list[HLSBands | int | list[int]] | None = None,
4848
constant_scale: float = 1,
4949
transform: A.Compose | None = None,
5050
no_data_replace: float | None = None,
@@ -179,16 +179,18 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
179179
def _generate_bands_intervals(self, bands_intervals:List[List[int]] = None):
180180
bands = list()
181181
for b_interval in bands_intervals:
182-
b_interval[-1] += 1
183-
bands_sublist = np.arange(*b_interval).astype(int)
182+
bands_sublist = np.arange(b_interval[0], b_interval[1] + 1).astype(int).tolist()
184183
bands.append(bands_sublist)
185184
return sorted(sum(bands, []))
186185

187-
def _bands_defined_by_interval(self, bands_list: List[int] | List[List[int]] = None) -> bool:
186+
def _bands_defined_by_interval(self, bands_list: list[int] | list[list[int]] = None) -> bool:
188187
if all([type(band)==int or isinstance(band, HLSBands) for band in bands_list]):
189188
return False
190-
elif all([isinstance(band, list) for band in bands_list]):
191-
return True
189+
elif all([isinstance(subinterval, list) for subinterval in bands_list]):
190+
if all([type(band)==int for band in sum(bands_list, [])]):
191+
return True
192+
else:
193+
raise Exception(f"Whe using subintervals, the limits must be int.")
192194
else:
193195
raise Exception(f"Excpected List[int] or List[List[int]], but received {type(bands_list)}.")
194196

@@ -206,8 +208,8 @@ def __init__(
206208
ignore_split_file_extensions: bool = True,
207209
allow_substring_split_file: bool = True,
208210
rgb_indices: list[str] | None = None,
209-
dataset_bands: list[HLSBands | int] | None = None,
210-
output_bands: list[HLSBands | int] | None = None,
211+
dataset_bands: list[HLSBands | int | list[int]] | None = None,
212+
output_bands: list[HLSBands | int | list[int]] | None = None,
211213
class_names: list[str] | None = None,
212214
constant_scale: float = 1,
213215
transform: A.Compose | None = None,

0 commit comments

Comments
 (0)