Skip to content

Commit 23944e3

Browse files
Merge pull request #237 from blumenstiel/multimodal
[WIP] Generic multimodal dataset and MultiMAE
2 parents 378990f + 98416ab commit 23944e3

19 files changed

+4807
-8
lines changed
+161
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 0
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: output
13+
name: multimae_sen1floods11
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 40
23+
24+
max_epochs: 2
25+
check_val_every_n_epoch: 1
26+
log_every_n_steps: 50
27+
enable_checkpointing: true
28+
default_root_dir: output/multimae_sen1floods11/
29+
30+
data:
31+
class_path: GenericMultiModalDataModule
32+
init_args:
33+
task: 'segmentation'
34+
batch_size: 4
35+
num_workers: 0
36+
modalities:
37+
- S2L2A
38+
- S1
39+
- LULC
40+
rgb_modality: S2L2A # If not provided, uses first modality
41+
rgb_indices:
42+
- 3
43+
- 2
44+
- 1
45+
46+
train_data_root:
47+
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
48+
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
49+
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
50+
train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
51+
val_data_root:
52+
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
53+
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
54+
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
55+
val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
56+
test_data_root:
57+
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
58+
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
59+
LULC: data/sen1floods11/data/data/flood_events/HandLabeled/LULCHand
60+
test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
61+
62+
train_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_train.txt
63+
val_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_valid.txt
64+
test_split: data/sen1floods11/splits/splits/flood_handlabeled/dev_test.txt
65+
66+
allow_substring_file_names: True
67+
image_grep:
68+
S2L2A: "*_S2L2AHand.tif"
69+
S1: "*_S1Hand.tif"
70+
LULC: "*_LULCHand.npy"
71+
label_grep: "*_LabelHand.tif"
72+
no_label_replace: -1
73+
no_data_replace: 0
74+
75+
means:
76+
S2L2A:
77+
- 1793.243
78+
- 1924.863
79+
- 2184.553
80+
- 2340.936
81+
- 2671.402
82+
- 3240.082
83+
- 3468.412
84+
- 3563.244
85+
- 3627.704
86+
- 3711.071
87+
- 3416.714
88+
- 2849.625
89+
S1:
90+
- -12.577
91+
- -20.265
92+
93+
stds:
94+
S2L2A:
95+
- 1160.144
96+
- 1201.092
97+
- 1219.943
98+
- 1397.225
99+
- 1400.035
100+
- 1373.136
101+
- 1429.17
102+
- 1485.025
103+
- 1447.836
104+
- 1652.703
105+
- 1471.002
106+
- 1365.30
107+
S1:
108+
- 5.179
109+
- 5.872
110+
111+
num_classes: 2
112+
113+
train_transform:
114+
- class_path: albumentations.RandomCrop
115+
init_args:
116+
height: 224
117+
width: 224
118+
- class_path: albumentations.D4
119+
- class_path: ToTensorV2
120+
121+
122+
model:
123+
class_path: terratorch.tasks.SemanticSegmentationTask
124+
init_args:
125+
model_factory: EncoderDecoderFactory
126+
model_args:
127+
backbone_pretrained: false
128+
backbone: multimae_base
129+
backbone_input_adapters:
130+
- S1
131+
- S2L2A
132+
- LULC
133+
decoder: FCNDecoder # UperNetDecoder
134+
decoder_num_convs: 4 # only for FCNDecoder
135+
# decoder_scale_modules: True # only for UperNetDecoder
136+
decoder_channels: 256
137+
num_classes: 2
138+
head_dropout: 0.1
139+
head_channel_list:
140+
- 256
141+
loss: ce
142+
ignore_index: -1
143+
class_weights:
144+
- 0.3
145+
- 0.7
146+
class_names:
147+
- Others
148+
- Flood
149+
freeze_backbone: false
150+
freeze_decoder: false
151+
152+
optimizer:
153+
class_path: torch.optim.AdamW
154+
init_args:
155+
lr: 6.e-5
156+
weight_decay: 0.05
157+
lr_scheduler:
158+
class_path: ReduceLROnPlateau
159+
init_args:
160+
monitor: val/loss
161+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 0
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: output
13+
name: multimodal_prithvi_sen1floods11
14+
version: test_best
15+
callbacks:
16+
- class_path: RichProgressBar
17+
- class_path: LearningRateMonitor
18+
init_args:
19+
logging_interval: epoch
20+
- class_path: EarlyStopping
21+
init_args:
22+
monitor: val/loss
23+
patience: 40
24+
25+
max_epochs: 100
26+
check_val_every_n_epoch: 1
27+
log_every_n_steps: 50
28+
enable_checkpointing: True
29+
default_root_dir: output/multimodal_prithvi_sen1floods11/
30+
31+
data:
32+
class_path: GenericMultiModalDataModule
33+
init_args:
34+
task: 'segmentation'
35+
batch_size: 16
36+
num_workers: 4
37+
modalities: # Define names of modalities
38+
- S2L2A
39+
- S1
40+
rgb_modality: S2L2A # If not provided, uses first modality
41+
rgb_indices:
42+
- 3
43+
- 2
44+
- 1
45+
46+
# Data roots are defined as dicts with modalities as keys
47+
train_data_root:
48+
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
49+
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
50+
train_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
51+
val_data_root:
52+
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
53+
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
54+
val_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
55+
test_data_root:
56+
S2L2A: data/sen1floods11/data/data/flood_events/HandLabeled/S2L2AHand
57+
S1: data/sen1floods11/data/data/flood_events/HandLabeled/S1Hand
58+
test_label_data_root: data/sen1floods11/data/data/flood_events/HandLabeled/LabelHand
59+
60+
train_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_train_data.txt
61+
val_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_valid_data.txt
62+
test_split: data/sen1floods11/splits/splits/flood_handlabeled/flood_test_data.txt
63+
64+
allow_substring_file_names: True
65+
image_grep:
66+
S2L2A: "*_S2L2AHand.tif"
67+
S1: "*_S1Hand.tif"
68+
label_grep: "*_LabelHand.tif"
69+
no_label_replace: -1
70+
no_data_replace: 0
71+
concat_bands: true # Concatenate modalities along band dim for single-modal models like Prithvi
72+
73+
# Define standardization values as dicts (no scaling if modality is not included)
74+
means:
75+
S2L2A:
76+
- 1793.243
77+
- 1924.863
78+
- 2184.553
79+
- 2340.936
80+
- 2671.402
81+
- 3240.082
82+
- 3468.412
83+
- 3563.244
84+
- 3627.704
85+
- 3711.071
86+
- 3416.714
87+
- 2849.625
88+
S1:
89+
- -12.577
90+
- -20.265
91+
92+
stds:
93+
S2L2A:
94+
- 1160.144
95+
- 1201.092
96+
- 1219.943
97+
- 1397.225
98+
- 1400.035
99+
- 1373.136
100+
- 1429.17
101+
- 1485.025
102+
- 1447.836
103+
- 1652.703
104+
- 1471.002
105+
- 1365.30
106+
S1:
107+
- 5.179
108+
- 5.872
109+
110+
num_classes: 2
111+
112+
# Transforms are shared between all image modalities (e.g. same crop area)
113+
train_transform:
114+
- class_path: albumentations.RandomCrop
115+
init_args:
116+
height: 224
117+
width: 224
118+
- class_path: albumentations.D4
119+
- class_path: ToTensorV2
120+
121+
122+
model:
123+
class_path: terratorch.tasks.SemanticSegmentationTask
124+
init_args:
125+
model_factory: EncoderDecoderFactory
126+
model_args:
127+
backbone: prithvi_vit_100
128+
backbone_pretrained: false
129+
backbone_bands:
130+
- COASTAL_AEROSOL
131+
- BLUE
132+
- GREEN
133+
- RED
134+
- RED_EDGE_1
135+
- RED_EDGE_2
136+
- RED_EDGE_3
137+
- NIR_BROAD
138+
- NIR_NARROW
139+
- CIRRUS
140+
- SWIR_1
141+
- SWIR_2
142+
- VV
143+
- VH
144+
decoder: FCNDecoder # FCNDecoder
145+
decoder_num_convs: 4 # only for FCNDecoder
146+
# decoder_scale_modules: True # only for UperNetDecoder
147+
decoder_channels: 256
148+
num_classes: 2
149+
head_dropout: 0.1
150+
head_channel_list:
151+
- 256
152+
153+
loss: dice
154+
ignore_index: -1
155+
class_weights:
156+
- 0.3
157+
- 0.7
158+
class_names:
159+
- Others
160+
- Flood
161+
freeze_backbone: false
162+
freeze_decoder: false
163+
164+
optimizer:
165+
class_path: torch.optim.AdamW
166+
init_args:
167+
lr: 6.e-5
168+
weight_decay: 0.05
169+
lr_scheduler:
170+
class_path: ReduceLROnPlateau
171+
init_args:
172+
monitor: val/loss
173+

terratorch/datamodules/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from terratorch.datamodules.sen1floods11 import Sen1Floods11NonGeoDataModule
4040
from terratorch.datamodules.sen4agrinet import Sen4AgriNetDataModule
4141
from terratorch.datamodules.torchgeo_data_module import TorchGeoDataModule, TorchNonGeoDataModule
42+
from terratorch.datamodules.generic_multimodal_data_module import GenericMultiModalDataModule
4243

4344

4445
# miscellaneous datamodules
@@ -74,7 +75,8 @@
7475
"OpenEarthMapModule"
7576
"OpenSentinelMapDataModule",
7677
"PASTISDataModule",
77-
"Sen4AgriNetDataModule"
78+
"Sen4AgriNetDataModule",
79+
"GenericMultiModalDataModule",
7880
)
7981

8082
if wxc_present:

0 commit comments

Comments
 (0)