Skip to content

Commit 91b4043

Browse files
Configuration template for Swin
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent dd40613 commit 91b4043

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed

examples/confs/sen1floods11_swin.yaml

+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 0
3+
# save the model checkpoints to a separate location then the logger
4+
# ModelCheckpoint:
5+
# dirpath: /dccstor/geofm-finetuning/carlosgomes/torchgeo_floods/sen1floods11_terratorch
6+
trainer:
7+
accelerator: auto
8+
strategy: auto
9+
devices: auto
10+
num_nodes: 1
11+
precision: 16-mixed
12+
logger:
13+
class_path: lightning.pytorch.loggers.mlflow.MLFlowLogger
14+
init_args:
15+
experiment_name: "test_experiment_mlflow"
16+
run_name: "plot_every_2_epochs"
17+
save_dir: <your_path_here>/mlflow
18+
callbacks:
19+
- class_path: RichProgressBar
20+
- class_path: LearningRateMonitor
21+
init_args:
22+
logging_interval: epoch
23+
- class_path: EarlyStopping
24+
init_args:
25+
monitor: val/loss
26+
patience: 20
27+
28+
max_epochs: 200
29+
check_val_every_n_epoch: 1
30+
log_every_n_steps: 50
31+
enable_checkpointing: true
32+
default_root_dir: <your_path_here>/torchgeo_floods
33+
data:
34+
class_path: GenericNonGeoSegmentationDataModule
35+
init_args:
36+
batch_size: 4
37+
num_workers: 8
38+
constant_scale: 0.0001
39+
dataset_bands:
40+
- RED
41+
- GREEN
42+
- BLUE
43+
- NIR_NARROW
44+
- SWIR_1
45+
- SWIR_2
46+
output_bands:
47+
- BLUE
48+
- GREEN
49+
- RED
50+
- NIR_NARROW
51+
- SWIR_1
52+
- SWIR_2
53+
rgb_indices:
54+
- 2
55+
- 1
56+
- 0
57+
train_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
58+
train_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
59+
val_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
60+
val_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
61+
test_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/S2Hand/
62+
test_label_data_root: <sen1floods11_root>/v1.1/data/flood_events/HandLabeled/LabelHand
63+
# these must be obtained by running terratorch/examples/scripts/convert_sen1floods11_splits.py on the original split csv files
64+
train_split: <sen1floods11_root>/splits/splits/flood_handlabeled/flood_train_data.txt
65+
test_split: <sen1floods11_root>/splits/splits/flood_handlabeled/flood_test_data.txt
66+
val_split: <sen1floods11_root>/splits/splits/flood_handlabeled/flood_valid_data.txt
67+
img_grep: "*_S2GeodnHand.tif"
68+
label_grep: "*_LabelHand.tif"
69+
no_label_replace: -1
70+
no_data_replace: 0
71+
means:
72+
- 0.107582
73+
- 0.13471393
74+
- 0.12520133
75+
- 0.3236181
76+
- 0.2341743
77+
- 0.15878009
78+
stds:
79+
- 0.07145836
80+
- 0.06783548
81+
- 0.07323416
82+
- 0.09489725
83+
- 0.07938496
84+
- 0.07089546
85+
num_classes: 2
86+
87+
model:
88+
class_path: terratorch.tasks.SemanticSegmentationTask
89+
init_args:
90+
model_args:
91+
decoder: UperNetDecoder
92+
pretrained: true
93+
backbone: prithvi_swin_B
94+
backbone_drop_path_rate: 0.3
95+
backbone_window_size: 7
96+
decoder_channels: 256
97+
in_channels: 6
98+
bands:
99+
- BLUE
100+
- GREEN
101+
- RED
102+
- NIR_NARROW
103+
- SWIR_1
104+
- SWIR_2
105+
num_frames: 1
106+
num_classes: 2
107+
head_dropout: 0.1
108+
head_channel_list:
109+
- 256
110+
loss: ce
111+
plot_on_val: 2
112+
# aux_heads:
113+
# - name: aux_head
114+
# decoder: FCNDecoder
115+
# decoder_args:
116+
# decoder_channels: 256
117+
# decoder_in_index: 2
118+
# decoder_num_convs: 1
119+
# head_channel_list:
120+
# - 64
121+
# aux_loss:
122+
# aux_head: 1.0
123+
ignore_index: -1
124+
class_weights:
125+
- 0.3
126+
- 0.7
127+
freeze_backbone: false
128+
freeze_decoder: false
129+
model_factory: PrithviModelFactory
130+
# tiled_inference_parameters:
131+
# h_crop: 512
132+
# h_stride: 512
133+
# w_crop: 512
134+
# w_stride: 512
135+
# average_patches: true
136+
optimizer:
137+
class_path: torch.optim.AdamW
138+
init_args:
139+
lr: 6.e-5
140+
weight_decay: 0.05
141+
lr_scheduler:
142+
class_path: ReduceLROnPlateau
143+
init_args:
144+
monitor: val/loss

0 commit comments

Comments
 (0)