Skip to content

Commit de533dd

Browse files
YAML file for testing string as bands
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
1 parent 831c662 commit de533dd

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
# precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: tests/
13+
name: all_ecos_random
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 100
23+
max_epochs: 5
24+
check_val_every_n_epoch: 1
25+
log_every_n_steps: 20
26+
enable_checkpointing: true
27+
default_root_dir: tests/
28+
data:
29+
class_path: GenericNonGeoPixelwiseRegressionDataModule
30+
init_args:
31+
batch_size: 2
32+
num_workers: 4
33+
train_transform:
34+
- class_path: albumentations.HorizontalFlip
35+
init_args:
36+
p: 0.5
37+
- class_path: albumentations.Rotate
38+
init_args:
39+
limit: 30
40+
border_mode: 0 # cv2.BORDER_CONSTANT
41+
value: 0
42+
# mask_value: 1
43+
p: 0.5
44+
- class_path: ToTensorV2
45+
dataset_bands:
46+
- "band_1"
47+
- "band_2"
48+
- "band_3"
49+
- "band_4"
50+
- "band_5"
51+
- "band_6"
52+
- "band_7"
53+
- "band_8"
54+
- "band_9"
55+
- "band_10"
56+
output_bands:
57+
- "band_2"
58+
- "band_3"
59+
- "band_4"
60+
- "band_5"
61+
- "band_6"
62+
- "band_7"
63+
rgb_indices:
64+
- 2
65+
- 1
66+
- 0
67+
train_data_root: tests/
68+
train_label_data_root: tests/
69+
val_data_root: tests/
70+
val_label_data_root: tests/
71+
test_data_root: tests/
72+
test_label_data_root: tests/
73+
img_grep: "regression*input*.tif"
74+
label_grep: "regression*label*.tif"
75+
means:
76+
- 547.36707
77+
- 898.5121
78+
- 1020.9082
79+
- 2665.5352
80+
- 2340.584
81+
- 1610.1407
82+
stds:
83+
- 411.4701
84+
- 558.54065
85+
- 815.94025
86+
- 812.4403
87+
- 1113.7145
88+
- 1067.641
89+
no_label_replace: -1
90+
no_data_replace: 0
91+
92+
model:
93+
class_path: terratorch.tasks.PixelwiseRegressionTask
94+
init_args:
95+
model_args:
96+
decoder: UperNetDecoder
97+
pretrained: true
98+
backbone: prithvi_swin_B
99+
backbone_pretrained_cfg_overlay:
100+
file: tests/prithvi_swin_B.pt
101+
backbone_drop_path_rate: 0.3
102+
# backbone_window_size: 8
103+
decoder_channels: 256
104+
in_channels: 6
105+
bands:
106+
- BLUE
107+
- GREEN
108+
- RED
109+
- NIR_NARROW
110+
- SWIR_1
111+
- SWIR_2
112+
num_frames: 1
113+
head_dropout: 0.5708022831486758
114+
head_final_act: torch.nn.ReLU
115+
head_learned_upscale_layers: 2
116+
loss: rmse
117+
#aux_heads:
118+
# - name: aux_head
119+
# decoder: IdentityDecoder
120+
# decoder_args:
121+
# decoder_out_index: 2
122+
# head_dropout: 0,5
123+
# head_channel_list:
124+
# - 64
125+
# head_final_act: torch.nn.ReLU
126+
#aux_loss:
127+
# aux_head: 0.4
128+
ignore_index: -1
129+
freeze_backbone: true
130+
freeze_decoder: false
131+
model_factory: PrithviModelFactory
132+
133+
# uncomment this block for tiled inference
134+
# tiled_inference_parameters:
135+
# h_crop: 224
136+
# h_stride: 192
137+
# w_crop: 224
138+
# w_stride: 192
139+
# average_patches: true
140+
optimizer:
141+
class_path: torch.optim.AdamW
142+
init_args:
143+
lr: 0.00013524680528283027
144+
weight_decay: 0.047782217873995426
145+
lr_scheduler:
146+
class_path: ReduceLROnPlateau
147+
init_args:
148+
monitor: val/loss
149+

0 commit comments

Comments
 (0)