Skip to content

Commit 295128f

Browse files
Testing the definition by interval using a dedicated yaml file
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent a0ce8aa commit 295128f

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: tests/
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 5
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: tests/
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 4
33+
train_transform:
34+
- class_path: albumentations.HorizontalFlip
35+
init_args:
36+
p: 0.5
37+
- class_path: albumentations.Rotate
38+
init_args:
39+
limit: 30
40+
border_mode: 0 # cv2.BORDER_CONSTANT
41+
value: 0
42+
# mask_value: 1
43+
p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- [0, 11]
47+
output_bands:
48+
- [1, 3]
49+
- [4, 6]
50+
rgb_indices:
51+
- 2
52+
- 1
53+
- 0
54+
train_data_root: tests/
55+
train_label_data_root: tests/
56+
val_data_root: tests/
57+
val_label_data_root: tests/
58+
test_data_root: tests/
59+
test_label_data_root: tests/
60+
img_grep: "regression*input*.tif"
61+
label_grep: "regression*label*.tif"
62+
means:
63+
- 547.36707
64+
- 898.5121
65+
- 1020.9082
66+
- 2665.5352
67+
- 2340.584
68+
- 1610.1407
69+
stds:
70+
- 411.4701
71+
- 558.54065
72+
- 815.94025
73+
- 812.4403
74+
- 1113.7145
75+
- 1067.641
76+
no_label_replace: -1
77+
no_data_replace: 0
78+
79+
model:
80+
class_path: terratorch.tasks.PixelwiseRegressionTask
81+
init_args:
82+
model_args:
83+
decoder: UperNetDecoder
84+
pretrained: true
85+
backbone: prithvi_swin_B
86+
backbone_pretrained_cfg_overlay:
87+
file: tests/prithvi_swin_B.pt
88+
backbone_drop_path_rate: 0.3
89+
# backbone_window_size: 8
90+
decoder_channels: 256
91+
in_channels: 6
92+
bands:
93+
- BLUE
94+
- GREEN
95+
- RED
96+
- NIR_NARROW
97+
- SWIR_1
98+
- SWIR_2
99+
num_frames: 1
100+
head_dropout: 0.5708022831486758
101+
head_final_act: torch.nn.ReLU
102+
head_learned_upscale_layers: 2
103+
loss: rmse
104+
#aux_heads:
105+
# - name: aux_head
106+
# decoder: IdentityDecoder
107+
# decoder_args:
108+
# decoder_out_index: 2
109+
# head_dropout: 0,5
110+
# head_channel_list:
111+
# - 64
112+
# head_final_act: torch.nn.ReLU
113+
#aux_loss:
114+
# aux_head: 0.4
115+
ignore_index: -1
116+
freeze_backbone: true
117+
freeze_decoder: false
118+
model_factory: PrithviModelFactory
119+
120+
# uncomment this block for tiled inference
121+
# tiled_inference_parameters:
122+
# h_crop: 224
123+
# h_stride: 192
124+
# w_crop: 224
125+
# w_stride: 192
126+
# average_patches: true
127+
optimizer:
128+
class_path: torch.optim.AdamW
129+
init_args:
130+
lr: 0.00013524680528283027
131+
weight_decay: 0.047782217873995426
132+
lr_scheduler:
133+
class_path: ReduceLROnPlateau
134+
init_args:
135+
monitor: val/loss
136+

tests/test_finetune.py

+13
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,19 @@ def test_finetune_multiple_backbones(model_name):
2323
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}.yaml"]
2424
_ = build_lightning_cli(command_list)
2525

26+
@pytest.mark.parametrize("model_name", ["prithvi_swin_B"])
27+
def test_finetune_bands_intervals(model_name):
28+
29+
model_instance = timm.create_model(model_name)
30+
31+
state_dict = model_instance.state_dict()
32+
33+
torch.save(state_dict, os.path.join("tests/", model_name + ".pt"))
34+
35+
# Running the terratorch CLI
36+
command_list = ["fit", "-c", f"tests/manufactured-finetune_{model_name}_band_interval.yaml"]
37+
_ = build_lightning_cli(command_list)
38+
2639
"""
2740
@pytest.mark.parametrize("model_name", ["prithvi_swin_B", "prithvi_swin_L", "prithvi_vit_100", "prithvi_vit_300"])
2841
def test_finetune_multiple_backbones(model_name):

0 commit comments

Comments
 (0)