Skip to content

Commit 41aa312

Browse files
Testing all the backbones
Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Testing to save and load checkpoints Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Testing finetuning for Swin Signed-off-by: João Lucas de Sousa Almeida <[email protected]> More config files used for executing the manufactured tests Signed-off-by: João Lucas de Sousa Almeida <[email protected]> More input/target files to perform the manufactured tests Signed-off-by: João Lucas de Sousa Almeida <[email protected]> Automatically testing fine-tuning Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 4cbf229 commit 41aa312

8 files changed

+662
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: /tmp
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 5
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: /tmp
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 8
33+
train_transform:
34+
- class_path: albumentations.HorizontalFlip
35+
init_args:
36+
p: 0.5
37+
- class_path: albumentations.Rotate
38+
init_args:
39+
limit: 30
40+
border_mode: 0 # cv2.BORDER_CONSTANT
41+
value: 0
42+
# mask_value: 1
43+
p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- 0
47+
- BLUE
48+
- GREEN
49+
- RED
50+
- NIR_NARROW
51+
- SWIR_1
52+
- SWIR_2
53+
- 1
54+
- 2
55+
- 3
56+
- 4
57+
output_bands:
58+
- BLUE
59+
- GREEN
60+
- RED
61+
- NIR_NARROW
62+
- SWIR_1
63+
- SWIR_2
64+
rgb_indices:
65+
- 2
66+
- 1
67+
- 0
68+
train_data_root: tests/
69+
train_label_data_root: tests/
70+
val_data_root: tests/
71+
val_label_data_root: tests/
72+
test_data_root: tests/
73+
test_label_data_root: tests/
74+
img_grep: "regression*input*.tif"
75+
label_grep: "regression*label*.tif"
76+
means:
77+
- 547.36707
78+
- 898.5121
79+
- 1020.9082
80+
- 2665.5352
81+
- 2340.584
82+
- 1610.1407
83+
stds:
84+
- 411.4701
85+
- 558.54065
86+
- 815.94025
87+
- 812.4403
88+
- 1113.7145
89+
- 1067.641
90+
no_label_replace: -1
91+
no_data_replace: 0
92+
93+
model:
94+
class_path: terratorch.tasks.PixelwiseRegressionTask
95+
init_args:
96+
model_args:
97+
decoder: UperNetDecoder
98+
pretrained: true
99+
backbone: prithvi_swin_B
100+
backbone_pretrained_cfg_overlay:
101+
file: /tmp/prithvi_swin_B.pt
102+
backbone_drop_path_rate: 0.3
103+
# backbone_window_size: 8
104+
decoder_channels: 256
105+
in_channels: 6
106+
bands:
107+
- BLUE
108+
- GREEN
109+
- RED
110+
- NIR_NARROW
111+
- SWIR_1
112+
- SWIR_2
113+
num_frames: 1
114+
head_dropout: 0.5708022831486758
115+
head_final_act: torch.nn.ReLU
116+
head_learned_upscale_layers: 2
117+
loss: rmse
118+
#aux_heads:
119+
# - name: aux_head
120+
# decoder: IdentityDecoder
121+
# decoder_args:
122+
# decoder_out_index: 2
123+
# head_dropout: 0,5
124+
# head_channel_list:
125+
# - 64
126+
# head_final_act: torch.nn.ReLU
127+
#aux_loss:
128+
# aux_head: 0.4
129+
ignore_index: -1
130+
freeze_backbone: false
131+
freeze_decoder: false
132+
model_factory: PrithviModelFactory
133+
134+
# uncomment this block for tiled inference
135+
# tiled_inference_parameters:
136+
# h_crop: 224
137+
# h_stride: 192
138+
# w_crop: 224
139+
# w_stride: 192
140+
# average_patches: true
141+
optimizer:
142+
class_path: torch.optim.AdamW
143+
init_args:
144+
lr: 0.00013524680528283027
145+
weight_decay: 0.047782217873995426
146+
lr_scheduler:
147+
class_path: ReduceLROnPlateau
148+
init_args:
149+
monitor: val/loss
150+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: /tmp
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 5
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: /tmp
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 8
33+
train_transform:
34+
- class_path: albumentations.HorizontalFlip
35+
init_args:
36+
p: 0.5
37+
- class_path: albumentations.Rotate
38+
init_args:
39+
limit: 30
40+
border_mode: 0 # cv2.BORDER_CONSTANT
41+
value: 0
42+
# mask_value: 1
43+
p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- 0
47+
- BLUE
48+
- GREEN
49+
- RED
50+
- NIR_NARROW
51+
- SWIR_1
52+
- SWIR_2
53+
- 1
54+
- 2
55+
- 3
56+
- 4
57+
output_bands:
58+
- BLUE
59+
- GREEN
60+
- RED
61+
- NIR_NARROW
62+
- SWIR_1
63+
- SWIR_2
64+
rgb_indices:
65+
- 2
66+
- 1
67+
- 0
68+
train_data_root: tests/
69+
train_label_data_root: tests/
70+
val_data_root: tests/
71+
val_label_data_root: tests/
72+
test_data_root: tests/
73+
test_label_data_root: tests/
74+
img_grep: "regression*input*.tif"
75+
label_grep: "regression*label*.tif"
76+
means:
77+
- 547.36707
78+
- 898.5121
79+
- 1020.9082
80+
- 2665.5352
81+
- 2340.584
82+
- 1610.1407
83+
stds:
84+
- 411.4701
85+
- 558.54065
86+
- 815.94025
87+
- 812.4403
88+
- 1113.7145
89+
- 1067.641
90+
no_label_replace: -1
91+
no_data_replace: 0
92+
93+
model:
94+
class_path: terratorch.tasks.PixelwiseRegressionTask
95+
init_args:
96+
model_args:
97+
decoder: UperNetDecoder
98+
pretrained: true
99+
backbone: prithvi_swin_L
100+
backbone_pretrained_cfg_overlay:
101+
file: /tmp/prithvi_swin_L.pt
102+
backbone_drop_path_rate: 0.3
103+
# backbone_window_size: 8
104+
decoder_channels: 64
105+
in_channels: 6
106+
bands:
107+
- BLUE
108+
- GREEN
109+
- RED
110+
- NIR_NARROW
111+
- SWIR_1
112+
- SWIR_2
113+
num_frames: 1
114+
head_dropout: 0.5708022831486758
115+
head_final_act: torch.nn.ReLU
116+
head_learned_upscale_layers: 2
117+
loss: rmse
118+
#aux_heads:
119+
# - name: aux_head
120+
# decoder: IdentityDecoder
121+
# decoder_args:
122+
# decoder_out_index: 2
123+
# head_dropout: 0,5
124+
# head_channel_list:
125+
# - 64
126+
# head_final_act: torch.nn.ReLU
127+
#aux_loss:
128+
# aux_head: 0.4
129+
ignore_index: -1
130+
freeze_backbone: false
131+
freeze_decoder: false
132+
model_factory: PrithviModelFactory
133+
134+
# uncomment this block for tiled inference
135+
# tiled_inference_parameters:
136+
# h_crop: 224
137+
# h_stride: 192
138+
# w_crop: 224
139+
# w_stride: 192
140+
# average_patches: true
141+
optimizer:
142+
class_path: torch.optim.AdamW
143+
init_args:
144+
lr: 0.00013524680528283027
145+
weight_decay: 0.047782217873995426
146+
lr_scheduler:
147+
class_path: ReduceLROnPlateau
148+
init_args:
149+
monitor: val/loss
150+

0 commit comments

Comments
 (0)