Skip to content

Commit da480f0

Browse files
Merge pull request #54 from IBM/improve/bands_definition
Improve/bands definition
2 parents 3b76391 + 1a254e4 commit da480f0

7 files changed

+533
-41
lines changed

terratorch/datamodules/generic_pixel_wise_data_module.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ def __init__(
9191
test_split: Path | None = None,
9292
ignore_split_file_extensions: bool = True,
9393
allow_substring_split_file: bool = True,
94-
dataset_bands: list[HLSBands | int] | None = None,
95-
predict_dataset_bands: list[HLSBands | int] | None = None,
96-
output_bands: list[HLSBands | int] | None = None,
94+
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
95+
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
96+
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
9797
constant_scale: float = 1,
9898
rgb_indices: list[int] | None = None,
9999
train_transform: A.Compose | None | list[A.BasicTransform] = None,
@@ -330,9 +330,9 @@ def __init__(
330330
test_split: Path | None = None,
331331
ignore_split_file_extensions: bool = True,
332332
allow_substring_split_file: bool = True,
333-
dataset_bands: list[HLSBands | int] | None = None,
334-
predict_dataset_bands: list[HLSBands | int] | None = None,
335-
output_bands: list[HLSBands | int] | None = None,
333+
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
334+
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
335+
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
336336
constant_scale: float = 1,
337337
rgb_indices: list[int] | None = None,
338338
train_transform: A.Compose | None | list[A.BasicTransform] = None,

terratorch/datasets/generic_pixel_wise_dataset.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,19 @@
11
# Copyright contributors to the Terratorch project
22

3-
"""Module containing generic dataset classes
4-
"""
3+
"""Module containing generic dataset classes"""
4+
55
import glob
66
import os
77
from abc import ABC
8-
from functools import partial
98
from pathlib import Path
109
from typing import Any
1110

1211
import albumentations as A
1312
import matplotlib as mpl
1413
import numpy as np
1514
import rioxarray
16-
import torch
1715
import xarray as xr
18-
from albumentations.pytorch import ToTensorV2
1916
from einops import rearrange
20-
from matplotlib import cm
2117
from matplotlib import pyplot as plt
2218
from matplotlib.figure import Figure
2319
from matplotlib.patches import Rectangle
@@ -43,8 +39,8 @@ def __init__(
4339
ignore_split_file_extensions: bool = True,
4440
allow_substring_split_file: bool = True,
4541
rgb_indices: list[int] | None = None,
46-
dataset_bands: list[HLSBands | int] | None = None,
47-
output_bands: list[HLSBands | int] | None = None,
42+
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
43+
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
4844
constant_scale: float = 1,
4945
transform: A.Compose | None = None,
5046
no_data_replace: float | None = None,
@@ -73,8 +69,8 @@ def __init__(
7369
that must be present in file names to be included (as in mmsegmentation), or exact
7470
matches (e.g. eurosat). Defaults to True.
7571
rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
76-
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
77-
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
72+
dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
73+
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
7874
constant_scale (float): Factor to multiply image values by. Defaults to 1.
7975
transform (Albumentations.Compose | None): Albumentations transform to be applied.
8076
Should end with ToTensorV2(). If used through the generic_data_module,
@@ -88,6 +84,7 @@ def __init__(
8884
expected 0. Defaults to False.
8985
"""
9086
super().__init__()
87+
9188
self.split_file = split
9289

9390
label_data_root = label_data_root if label_data_root is not None else data_root
@@ -120,19 +117,24 @@ def __init__(
120117
)
121118
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices
122119

123-
self.dataset_bands = dataset_bands
124-
self.output_bands = output_bands
120+
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
121+
self.output_bands = self._generate_bands_intervals(output_bands)
122+
125123
if self.output_bands and not self.dataset_bands:
126124
msg = "If output bands provided, dataset_bands must also be provided"
127125
return Exception(msg) # noqa: PLE0101
128126

127+
# There is a special condition if the bands are defined as simple strings.
129128
if self.output_bands:
130129
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
131130
msg = "Output bands must be a subset of dataset bands"
132131
raise Exception(msg)
132+
133133
self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]
134+
134135
else:
135136
self.filter_indices = None
137+
136138
# If no transform is given, only apply the conversion to a torch tensor
137139
self.transform = transform if transform else lambda **batch: to_tensor(batch)
138140
# self.transform = transform if transform else ToTensorV2()
@@ -141,7 +143,7 @@ def __len__(self) -> int:
141143
return len(self.image_files)
142144

143145
def __getitem__(self, index: int) -> dict[str, Any]:
144-
image = self._load_file(self.image_files[index], nan_replace = self.no_data_replace).to_numpy()
146+
image = self._load_file(self.image_files[index], nan_replace=self.no_data_replace).to_numpy()
145147
# to channels last
146148
if self.expand_temporal_dimension:
147149
image = rearrange(image, "(channels time) h w -> channels time h w", channels=len(self.output_bands))
@@ -151,9 +153,12 @@ def __getitem__(self, index: int) -> dict[str, Any]:
151153
image = image[..., self.filter_indices]
152154
output = {
153155
"image": image.astype(np.float32) * self.constant_scale,
154-
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace = self.no_label_replace).to_numpy()[0],
156+
"mask": self._load_file(self.segmentation_mask_files[index], nan_replace=self.no_label_replace).to_numpy()[
157+
0
158+
],
155159
"filename": self.image_files[index],
156160
}
161+
157162
if self.reduce_zero_label:
158163
output["mask"] -= 1
159164
if self.transform:
@@ -166,6 +171,26 @@ def _load_file(self, path, nan_replace: int | float | None = None) -> xr.DataArr
166171
data = data.fillna(nan_replace)
167172
return data
168173

174+
def _generate_bands_intervals(
    self, bands_intervals: list[int | str | HLSBands | tuple[int, int]] | None = None
) -> list[int | str | HLSBands] | None:
    """Expand interval entries of a band specification into a flat list of bands.

    Args:
        bands_intervals: Band specification. An entry that is a tuple (or a list)
            of two integers is treated as an inclusive ``(start, end)`` interval
            and replaced by every index in that range. Any other entry is kept
            unchanged. Defaults to None.

    Returns:
        The flat list of bands, in the order given, or None if
        ``bands_intervals`` is None.

    Raises:
        ValueError: If an interval does not contain exactly two elements, or if
            the expanded list contains duplicate entries.
    """
    if bands_intervals is None:
        return None

    bands = []
    for element in bands_intervals:
        # An interval written as a sequence in a config file (e.g. ``[0, 11]``)
        # may arrive as a list instead of a tuple, so both are accepted here.
        # Without this, a list entry would only fail later, at the duplicate
        # check, with an unrelated "unhashable type" error.
        if isinstance(element, (tuple, list)):
            if len(element) != 2:  # noqa: PLR2004
                msg = "When defining an interval, a tuple of two integers should be passed, defining start and end indices inclusive"
                raise ValueError(msg)
            start, end = element
            # The end index is inclusive, hence the + 1.
            bands.extend(range(start, end + 1))
        else:
            bands.append(element)

    # Overlapping intervals, or an index repeated outside an interval, would make
    # the band list ambiguous, so duplicates are rejected.
    if len(set(bands)) != len(bands):
        msg = "Duplicate indices detected. Indices must be unique."
        raise ValueError(msg)
    return bands
193+
169194

170195
class GenericNonGeoSegmentationDataset(GenericPixelWiseDataset):
171196
"""GenericNonGeoSegmentationDataset"""
@@ -181,8 +206,8 @@ def __init__(
181206
ignore_split_file_extensions: bool = True,
182207
allow_substring_split_file: bool = True,
183208
rgb_indices: list[str] | None = None,
184-
dataset_bands: list[HLSBands | int] | None = None,
185-
output_bands: list[HLSBands | int] | None = None,
209+
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
210+
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
186211
class_names: list[str] | None = None,
187212
constant_scale: float = 1,
188213
transform: A.Compose | None = None,
@@ -348,8 +373,8 @@ def __init__(
348373
ignore_split_file_extensions: bool = True,
349374
allow_substring_split_file: bool = True,
350375
rgb_indices: list[int] | None = None,
351-
dataset_bands: list[HLSBands | int] | None = None,
352-
output_bands: list[HLSBands | int] | None = None,
376+
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
377+
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
353378
constant_scale: float = 1,
354379
transform: A.Compose | None = None,
355380
no_data_replace: float | None = None,

terratorch/datasets/generic_scalar_label_dataset.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def __init__(
4242
ignore_split_file_extensions: bool = True,
4343
allow_substring_split_file: bool = True,
4444
rgb_indices: list[int] | None = None,
45-
dataset_bands: list[HLSBands | int] | None = None,
46-
output_bands: list[HLSBands | int] | None = None,
45+
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
46+
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
4747
constant_scale: float = 1,
4848
transform: A.Compose | None = None,
4949
no_data_replace: float = 0,
@@ -64,8 +64,8 @@ def __init__(
6464
that must be present in file names to be included (as in mmsegmentation), or exact
6565
matches (e.g. eurosat). Defaults to True.
6666
rgb_indices (list[int], optional): Indices of RGB channels. Defaults to [0, 1, 2].
67-
dataset_bands (list[HLSBands | int] | None): Bands present in the dataset.
68-
output_bands (list[HLSBands | int] | None): Bands that should be output by the dataset.
67+
dataset_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands present in the dataset. This parameter names input channels (bands) using HLSBands, ints, int ranges, or strings, so that they can then be referred to by output_bands. Defaults to None.
68+
output_bands (list[HLSBands | int | tuple[int, int] | str] | None): Bands that should be output by the dataset as named by dataset_bands.
6969
constant_scale (float): Factor to multiply image values by. Defaults to 1.
7070
transform (Albumentations.Compose | None): Albumentations transform to be applied.
7171
Should end with ToTensorV2(). If used through the generic_data_module,
@@ -110,17 +110,21 @@ def is_valid_file(x):
110110

111111
self.rgb_indices = [0, 1, 2] if rgb_indices is None else rgb_indices
112112

113-
self.dataset_bands = dataset_bands
114-
self.output_bands = output_bands
113+
self.dataset_bands = self._generate_bands_intervals(dataset_bands)
114+
self.output_bands = self._generate_bands_intervals(output_bands)
115+
115116
if self.output_bands and not self.dataset_bands:
116117
msg = "If output bands provided, dataset_bands must also be provided"
117118
return Exception(msg) # noqa: PLE0101
118119

120+
# There is a special condition if the bands are defined as simple strings.
119121
if self.output_bands:
120122
if len(set(self.output_bands) & set(self.dataset_bands)) != len(self.output_bands):
121123
msg = "Output bands must be a subset of dataset bands"
122124
raise Exception(msg)
125+
123126
self.filter_indices = [self.dataset_bands.index(band) for band in self.output_bands]
127+
124128
else:
125129
self.filter_indices = None
126130
# If no transform is given, only apply the conversion to a torch tensor
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: tests/
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 5
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: tests/
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 4
33+
train_transform:
34+
- class_path: albumentations.HorizontalFlip
35+
init_args:
36+
p: 0.5
37+
- class_path: albumentations.Rotate
38+
init_args:
39+
limit: 30
40+
border_mode: 0 # cv2.BORDER_CONSTANT
41+
value: 0
42+
# mask_value: 1
43+
p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- [0, 11]
47+
output_bands:
48+
- [1, 3]
49+
- [4, 6]
50+
rgb_indices:
51+
- 2
52+
- 1
53+
- 0
54+
train_data_root: tests/
55+
train_label_data_root: tests/
56+
val_data_root: tests/
57+
val_label_data_root: tests/
58+
test_data_root: tests/
59+
test_label_data_root: tests/
60+
img_grep: "regression*input*.tif"
61+
label_grep: "regression*label*.tif"
62+
means:
63+
- 547.36707
64+
- 898.5121
65+
- 1020.9082
66+
- 2665.5352
67+
- 2340.584
68+
- 1610.1407
69+
stds:
70+
- 411.4701
71+
- 558.54065
72+
- 815.94025
73+
- 812.4403
74+
- 1113.7145
75+
- 1067.641
76+
no_label_replace: -1
77+
no_data_replace: 0
78+
79+
model:
80+
class_path: terratorch.tasks.PixelwiseRegressionTask
81+
init_args:
82+
model_args:
83+
decoder: UperNetDecoder
84+
pretrained: true
85+
backbone: prithvi_swin_B
86+
backbone_pretrained_cfg_overlay:
87+
file: tests/prithvi_swin_B.pt
88+
backbone_drop_path_rate: 0.3
89+
# backbone_window_size: 8
90+
decoder_channels: 256
91+
in_channels: 6
92+
bands:
93+
- BLUE
94+
- GREEN
95+
- RED
96+
- NIR_NARROW
97+
- SWIR_1
98+
- SWIR_2
99+
num_frames: 1
100+
head_dropout: 0.5708022831486758
101+
head_final_act: torch.nn.ReLU
102+
head_learned_upscale_layers: 2
103+
loss: rmse
104+
#aux_heads:
105+
# - name: aux_head
106+
# decoder: IdentityDecoder
107+
# decoder_args:
108+
# decoder_out_index: 2
109+
# head_dropout: 0.5
110+
# head_channel_list:
111+
# - 64
112+
# head_final_act: torch.nn.ReLU
113+
#aux_loss:
114+
# aux_head: 0.4
115+
ignore_index: -1
116+
freeze_backbone: true
117+
freeze_decoder: false
118+
model_factory: PrithviModelFactory
119+
120+
# uncomment this block for tiled inference
121+
# tiled_inference_parameters:
122+
# h_crop: 224
123+
# h_stride: 192
124+
# w_crop: 224
125+
# w_stride: 192
126+
# average_patches: true
127+
optimizer:
128+
class_path: torch.optim.AdamW
129+
init_args:
130+
lr: 0.00013524680528283027
131+
weight_decay: 0.047782217873995426
132+
lr_scheduler:
133+
class_path: ReduceLROnPlateau
134+
init_args:
135+
monitor: val/loss
136+

0 commit comments

Comments
 (0)