Skip to content

Commit 50b909b

Browse files
Merge pull request #14 from IBM/feature/add_more_examples
Feature/add more examples
2 parents b699344 + d891589 commit 50b909b

File tree

3 files changed

+303
-5
lines changed

3 files changed

+303
-5
lines changed

docs/examples.md

+3-5
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@ For some examples of training using the existing tasks, check out the following
66

77
Under `examples/confs`
88

9-
* Flood Segmentation with ViT: `segmentation_config_vit.yaml`
9+
* Flood Segmentation with ViT: `sen1floods11_vit.yaml`
1010

1111
* Multitemporal Crop Segmentation: `multitemporal_crop.yaml`
1212

13-
* Scene Classification: `eurosat.yaml`
13+
* Burn Scar Segmentation: `burn_scars.yaml`
1414

15-
* Usage of an SMP backbone `geobench/segmentation/m_chesapeake_landcover_smp_resnet_unet.yaml`
16-
17-
* Usage of a timm backbone `geobench/classification/m_bigearthnet_timm_resnet.yaml`
15+
* Scene Classification: `eurosat.yaml`

examples/confs/burn_scars.yaml

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 0
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: <path>
13+
name: fire_scars
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
- class_path: EarlyStopping
20+
init_args:
21+
monitor: val/loss
22+
patience: 40
23+
24+
max_epochs: 200
25+
check_val_every_n_epoch: 1
26+
log_every_n_steps: 50
27+
enable_checkpointing: true
28+
default_root_dir: <path>
29+
30+
# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars
31+
data:
32+
class_path: GenericNonGeoSegmentationDataModule
33+
init_args:
34+
batch_size: 4
35+
num_workers: 8
36+
dataset_bands:
37+
- BLUE
38+
- GREEN
39+
- RED
40+
- NIR_NARROW
41+
- SWIR_1
42+
- SWIR_2
43+
output_bands:
44+
- BLUE
45+
- GREEN
46+
- RED
47+
- NIR_NARROW
48+
- SWIR_1
49+
- SWIR_2
50+
rgb_indices:
51+
- 0
52+
- 1
53+
- 2
54+
train_transform:
55+
- class_path: albumentations.RandomCrop
56+
init_args:
57+
height: 224
58+
width: 224
59+
- class_path: albumentations.HorizontalFlip
60+
init_args:
61+
p: 0.5
62+
- class_path: ToTensorV2
63+
no_data_replace: 0
64+
no_label_replace: -1
65+
train_data_root: <data_path>/training
66+
train_label_data_root: <data_path>/training
67+
val_data_root: <data_path>/validation
68+
val_label_data_root: <data_path>/validation
69+
test_data_root: <data_path>/validation
70+
test_label_data_root: <data_path>/validation
71+
img_grep: "*_merged.tif"
72+
label_grep: "*.mask.tif"
73+
means:
74+
- 0.033349706741586264
75+
- 0.05701185520536176
76+
- 0.05889748132001316
77+
- 0.2323245113436119
78+
- 0.1972854853760658
79+
- 0.11944914225186566
80+
stds:
81+
- 0.02269135568823774
82+
- 0.026807560223070237
83+
- 0.04004109844362779
84+
- 0.07791732423672691
85+
- 0.08708738838140137
86+
- 0.07241979477437814
87+
num_classes: 2
88+
89+
model:
90+
class_path: terratorch.tasks.SemanticSegmentationTask
91+
init_args:
92+
model_args:
93+
decoder: FCNDecoder
94+
pretrained: true
95+
backbone: prithvi_vit_100
96+
decoder_channels: 256
97+
in_channels: 6
98+
bands:
99+
- BLUE
100+
- GREEN
101+
- RED
102+
- NIR_NARROW
103+
- SWIR_1
104+
- SWIR_2
105+
num_frames: 1
106+
num_classes: 2
107+
head_dropout: 0.1
108+
decoder_num_convs: 4
109+
head_channel_list:
110+
- 256
111+
loss: dice
112+
plot_on_val: 10
113+
ignore_index: -1
114+
freeze_backbone: false
115+
freeze_decoder: false
116+
model_factory: PrithviModelFactory
117+
tiled_inference_parameters:
118+
h_crop: 512
119+
h_stride: 496
120+
w_crop: 512
121+
w_stride: 496
122+
average_patches: true
123+
optimizer:
124+
class_path: torch.optim.Adam
125+
init_args:
126+
lr: 1.5e-5
127+
weight_decay: 0.05
128+
lr_scheduler:
129+
class_path: ReduceLROnPlateau
130+
init_args:
131+
monitor: val/loss
+169
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 0
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
precision: 16-mixed
9+
logger:
10+
class_path: TensorBoardLogger
11+
init_args:
12+
save_dir: <path>
13+
name: replicate
14+
callbacks:
15+
- class_path: RichProgressBar
16+
- class_path: LearningRateMonitor
17+
init_args:
18+
logging_interval: epoch
19+
20+
max_epochs: 200
21+
check_val_every_n_epoch: 1
22+
log_every_n_steps: 50
23+
enable_checkpointing: true
24+
default_root_dir: <path>
25+
26+
# data available at: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification
27+
data:
28+
class_path: GenericNonGeoSegmentationDataModule
29+
init_args:
30+
batch_size: 8
31+
num_workers: 12
32+
train_transform:
33+
- class_path: FlattenTemporalIntoChannels
34+
- class_path: albumentations.Flip
35+
- class_path: ToTensorV2
36+
- class_path: UnflattenTemporalFromChannels
37+
init_args:
38+
n_timesteps: 3
39+
40+
dataset_bands:
41+
- BLUE
42+
- GREEN
43+
- RED
44+
- NIR_NARROW
45+
- SWIR_1
46+
- SWIR_2
47+
output_bands:
48+
- BLUE
49+
- GREEN
50+
- RED
51+
- NIR_NARROW
52+
- SWIR_1
53+
- SWIR_2
54+
rgb_indices:
55+
- 2
56+
- 1
57+
- 0
58+
reduce_zero_label: True
59+
expand_temporal_dimension: True
60+
train_data_root: <data_path>/training_chips
61+
train_label_data_root: <data_path>/training_chips
62+
val_data_root: <data_path>/validation_chips
63+
val_label_data_root: <data_path>/validation_chips
64+
test_data_root: <data_path>/validation_chips
65+
test_label_data_root: <data_path>/validation_chips
66+
train_split: <data_path>/training_chips/training_data.txt
67+
test_split: <data_path>/validation_chips/validation_data.txt
68+
val_split: <data_path>/validation_chips/validation_data.txt
69+
img_grep: "*_merged.tif"
70+
label_grep: "*.mask.tif"
71+
means:
72+
- 494.905781
73+
- 815.239594
74+
- 924.335066
75+
- 2968.881459
76+
- 2634.621962
77+
- 1739.579917
78+
stds:
79+
- 284.925432
80+
- 357.84876
81+
- 575.566823
82+
- 896.601013
83+
- 951.900334
84+
- 921.407808
85+
num_classes: 13
86+
87+
model:
88+
class_path: terratorch.tasks.SemanticSegmentationTask
89+
init_args:
90+
model_args:
91+
decoder: FCNDecoder
92+
pretrained: true
93+
backbone: prithvi_vit_100
94+
in_channels: 6
95+
rescale: False
96+
bands:
97+
- BLUE
98+
- GREEN
99+
- RED
100+
- NIR_NARROW
101+
- SWIR_1
102+
- SWIR_2
103+
num_frames: 3
104+
num_classes: 13
105+
head_dropout: 0.1
106+
decoder_channels: 512
107+
head_channel_list:
108+
- 128
109+
- 64
110+
loss: ce
111+
class_names:
112+
- Natural Vegetation
113+
- Forest
114+
- Corn
115+
- Soybeans
116+
- Wetlands
117+
- Developed/Barren
118+
- Open Water
119+
- Winter Wheat
120+
- Alfalfa
121+
- Fallow/Idle Cropland
122+
- Cotton
123+
- Sorghum
124+
- Other
125+
# aux_heads:
126+
# - name: aux_head
127+
# decoder: FCNDecoder
128+
# decoder_args:
129+
# decoder_channels: 256
130+
# decoder_in_index: 2
131+
# decoder_num_convs: 2
132+
# head_channel_list:
133+
# - 64
134+
# aux_loss:
135+
# aux_head: 1.0
136+
class_weights:
137+
- 0.386375
138+
- 0.661126
139+
- 0.548184
140+
- 0.640482
141+
- 0.876862
142+
- 0.925186
143+
- 3.249462
144+
- 1.542289
145+
- 2.175141
146+
- 2.272419
147+
- 3.062762
148+
- 3.626097
149+
- 1.198702
150+
151+
ignore_index: -1
152+
freeze_backbone: false
153+
freeze_decoder: false
154+
model_factory: PrithviModelFactory
155+
tiled_inference_parameters:
156+
h_crop: 224
157+
h_stride: 196
158+
w_crop: 224
159+
w_stride: 196
160+
average_patches: true
161+
optimizer:
162+
class_path: torch.optim.AdamW
163+
init_args:
164+
lr: 1.5e-5
165+
weight_decay: 0.05
166+
lr_scheduler:
167+
class_path: ReduceLROnPlateau
168+
init_args:
169+
monitor: val/loss

0 commit comments

Comments
 (0)