Skip to content

Commit 5dba481

Browse files
Strings are allowed to define bands
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent e48965f commit 5dba481

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

terratorch/datamodules/generic_pixel_wise_data_module.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def __init__(
9191
test_split: Path | None = None,
9292
ignore_split_file_extensions: bool = True,
9393
allow_substring_split_file: bool = True,
94-
dataset_bands: list[HLSBands | int | list[int]] | None = None,
95-
predict_dataset_bands: list[HLSBands | int | list[int]] | None = None,
96-
output_bands: list[HLSBands | int | list[int]] | None = None,
94+
dataset_bands: list[HLSBands | int | list[int] | str] | None = None,
95+
predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
96+
output_bands: list[HLSBands | int | list[int] | str] | None = None,
9797
constant_scale: float = 1,
9898
rgb_indices: list[int] | None = None,
9999
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -330,9 +330,9 @@ def __init__(
330330
test_split: Path | None = None,
331331
ignore_split_file_extensions: bool = True,
332332
allow_substring_split_file: bool = True,
333-
dataset_bands: list[HLSBands | int | list[int]] | None = None,
334-
predict_dataset_bands: list[HLSBands | int | list[int]] | None = None,
335-
output_bands: list[HLSBands | int | list[int]] | None = None,
333+
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
334+
predict_dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
335+
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
336336
constant_scale: float = 1,
337337
rgb_indices: list[int] | None = None,
338338
train_transform: A.Compose | None | list[A.BasicTransform] = None,

terratorch/datasets/generic_pixel_wise_dataset.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
expected 0. Defaults to False.
8989
"""
9090
super().__init__()
91+
9192
self.split_file = split
9293

9394
label_data_root = label_data_root if label_data_root is not None else data_root
@@ -136,7 +137,7 @@ def __init__(
136137
if bands_type == str:
137138
raise UserWarning("When the bands are defined as str, guarantee your input files"+
138139
"are organized by band and all have its specific name.")
139-
140+
140141
if self.output_bands and not self.dataset_bands:
141142
msg = "If output bands provided, dataset_bands must also be provided"
142143
return Exception(msg) # noqa: PLE0101
@@ -146,7 +147,9 @@ def __init__(
146147
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
147148
msg = "Output bands must be a subset of dataset bands"
148149
raise Exception(msg)
150+
149151
self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]
152+
150153
else:
151154
self.filter_indices = None
152155

@@ -176,7 +179,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
176179
if self.transform:
177180
output = self.transform(**output)
178181
return output
179-
182+
180183
def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArray:
181184
data = rioxarray.open_rasterio(path, masked=True)
182185
if nan_replace is not None:
@@ -200,7 +203,7 @@ def _bands_as_int_or_str(self, dataset_bands, output_bands) -> type:
200203
band_type[b] = str
201204
else:
202205
pass
203-
if band_type.cound(band_type[0]) == len(band_type)
206+
if band_type.cound(band_type[0]) == len(band_type):
204207
return band_type[0]
205208
else:
206209
raise Exception("The bands must be or all str or all int.")
@@ -232,8 +235,8 @@ def __init__(
232235
ignore_split_file_extensions: bool = True,
233236
allow_substring_split_file: bool = True,
234237
rgb_indices: list[str] | None = None,
235-
dataset_bands: list[HLSBands | int | list[int]] | None = None,
236-
output_bands: list[HLSBands | int | list[int]] | None = None,
238+
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
239+
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
237240
class_names: list[str] | None = None,
238241
constant_scale: float = 1,
239242
transform: A.Compose | None = None,
@@ -399,8 +402,8 @@ def __init__(
399402
ignore_split_file_extensions: bool = True,
400403
allow_substring_split_file: bool = True,
401404
rgb_indices: list[int] | None = None,
402-
dataset_bands: list[HLSBands | int | list[int]] | None = None,
403-
output_bands: list[HLSBands | int | list[int]] | None = None,
405+
dataset_bands: list[HLSBands | int | list[int] | str ] | None = None,
406+
output_bands: list[HLSBands | int | list[int] | str ] | None = None,
404407
constant_scale: float = 1,
405408
transform: A.Compose | None = None,
406409
no_data_replace: float | None = None,

0 commit comments

Comments
 (0)