
Commit 4384db5

merging
Signed-off-by: João Lucas de Sousa Almeida <[email protected]>
2 parents e0b59dd + 308d540 commit 4384db5

File tree

115 files changed: +3539 −1217 lines


.github/dependabot.yaml

+24
@@ -0,0 +1,24 @@
+# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
+# mostly from https://github.com/microsoft/torchgeo/blob/main/.github/dependabot.yml
+version: 2
+updates:
+  - package-ecosystem: "github-actions"
+    directory: "/"
+    schedule:
+      interval: "weekly"
+  - package-ecosystem: "pip"
+    directory: "/"
+    schedule:
+      interval: "daily"
+    groups:
+      # torchvision pins torch, must update in unison
+      torch:
+        patterns:
+          - "torch"
+          - "torchvision"
+    ignore:
+      # setuptools releases new versions almost daily
+      - dependency-name: "setuptools"
+        update-types: ["version-update:semver-patch"]
+      # segmentation-models-pytorch pins timm, must update in unison
+      - dependency-name: "timm"

.github/workflows/test.yaml

+14-6
@@ -1,17 +1,24 @@
 name: terratorch tuning toolkit
 
-on: [pull_request]
+on:
+  push:
+    branches:
+      - main
+
+  pull_request:
+    branches:
+      - main
 
 jobs:
   build:
-
     runs-on: ubuntu-latest
     strategy:
       matrix:
         python-version: ["3.10", "3.11"]
 
     steps:
-      - uses: actions/checkout@v3
+      - name: Clone repo
+        uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v4
         with:
@@ -20,8 +27,9 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install pytest
-          pip install -e .
+          pip install -r requirements/required.txt -r requirements/test.txt
+      - name: List pip dependencies
+        run: pip list
       - name: Test with pytest
         run: |
-          pytest tests
+          pytest -s tests

.gitignore

+2-1
@@ -5,4 +5,5 @@ dist/*
 *.egg-info
 *.coverage.*
 **/*.pt
-
+*.ipynb_checkpoints
+**/*pth

README.md

+8-2
@@ -16,9 +16,15 @@ The library provides:
 ### Pip
 In order to use the file `pyproject.toml` it is necessary to guarantee `pip>=21.8`. If necessary, upgrade `pip` using `python -m pip install --upgrade pip`.
 
-Install the library with `pip install git+https://github.com/IBM/terratorch.git`
+For a stable point-release, use `pip install terratorch`.
+If you prefer to get the most recent version of the main branch, install the library with `pip install git+https://github.com/IBM/terratorch.git`.
 
-To install as a developer (e.g. to extend the library) clone this repo, install dependencies using `pip install -r requirements.txt` and run `pip install -e .`
+Another alternative is to install using [pipx](https://github.com/pypa/pipx) via `pipx install terratorch`, which creates an isolated environment and allows the user to run the application as
+a common CLI tool, with no need to install dependencies or activate environments.
+
+TerraTorch requires GDAL to be installed, which can be quite a complex process. If you don't have GDAL set up on your system, we recommend using a conda environment and installing it with `conda install -c conda-forge gdal`.
+
+To install as a developer (e.g. to extend the library), clone this repo, install dependencies using `pip install -r requirements/required.txt -r requirements/dev.txt`, and run `pip install -e .`
 
 ## Quick start
 

docs/examples.md

+3-5
@@ -6,12 +6,10 @@ For some examples of training using the existing tasks, check out the following
 
 Under `examples/confs`
 
-* Flood Segmentation with ViT: `segmentation_config_vit.yaml`
+* Flood Segmentation with ViT: `sen1floods11_vit.yaml`
 
 * Multitemporal Crop Segmentation: `multitemporal_crop.yaml`
 
-* Scene Classification: `eurosat.yaml`
+* Burn Scar Segmentation: `burn_scars.yaml`
 
-* Usage of an SMP backbone `geobench/segmentation/m_chesapeake_landcover_smp_resnet_unet.yaml`
-
-* Usage of a timm backbone `geobench/classification/m_bigearthnet_timm_resnet.yaml`
+* Scene Classification: `eurosat.yaml`

docs/models.md

+7
@@ -34,6 +34,13 @@ We also provide a model factory that can build a task specific model for a downs
 
 By passing a list of bands being used to the constructor, we automatically filter out unused bands, and randomly initialize weights for new bands that were not pretrained on.
 
+!!! info
+
+    To pass your own path from where to load the weights with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
+    E.g. to pass a local path, you can pass the parameter `backbone_pretrained_cfg_overlay = {"file": "<local_path>"}` to the model factory.
+
+    Besides `file`, you can also pass `url`, `hf_hub_id`, amongst others. Check timm's documentation for full details.
+
 :::terratorch.models.backbones.prithvi_select_patch_embed_weights
 
 ## Decoders
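For illustration, a minimal sketch of how the overlay described in the added info box might be passed through the factory in Python. The import path and the `build_model` argument names (apart from `backbone_pretrained_cfg_overlay`, which is quoted from the note) are assumptions modelled on the `model_args` in `examples/confs/burn_scars.yaml` elsewhere in this commit, not a verbatim API reference:

```python
# Hypothetical sketch: argument names mirror model_args in examples/confs/burn_scars.yaml;
# the exact build_model signature is an assumption.
from terratorch.models import PrithviModelFactory

factory = PrithviModelFactory()
model = factory.build_model(
    task="segmentation",
    backbone="prithvi_vit_100",
    decoder="FCNDecoder",
    bands=["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    num_classes=2,
    pretrained=True,
    # timm overlay: load backbone weights from a local file instead of the default source
    backbone_pretrained_cfg_overlay={"file": "<local_path>"},
)
```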

docs/quick_start.md

+19-3
@@ -1,8 +1,13 @@
 # Quick start
 We suggest using Python==3.10.
-To get started, make sure to have `PyTorch >= 2` [installed](https://pytorch.org/get-started/locally/).
+To get started, make sure to have [PyTorch](https://pytorch.org/get-started/locally/) >= 2.0.0 and [GDAL](https://gdal.org/index.html) installed.
 
-To install the package, clone the repository and install it with `pip install -e .` from within the repository directory.
+Installing GDAL can be quite a complex process. If you don't have GDAL set up on your system, we recommend using a conda environment and installing it with `conda install -c conda-forge gdal`.
+
+For a stable point-release, use `pip install terratorch`.
+If you prefer to get the most recent version of the main branch, install the library with `pip install git+https://github.com/IBM/terratorch.git`.
+
+To install as a developer (e.g. to extend the library), clone this repo and run `pip install -e .`.
 
 You can interact with the library at several levels of abstraction. Each deeper level of abstraction trades off some amount of flexibility for ease of use and configuration.
 
@@ -105,6 +110,17 @@ task = PixelwiseRegressionTask(
 
 At this level of abstraction, you can also provide a configuration file (see [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html#lightning-cli)) with all the details of the training. See an example for semantic segmentation below:
 
+!!! info
+
+    To pass your own path from where to load the weights with the PrithviModelFactory, you can make use of timm's `pretrained_cfg_overlay`.
+    E.g. to pass a local path, you can add, under model_args:
+
+    ```yaml
+    backbone_pretrained_cfg_overlay:
+      file: <local_path>
+    ```
+    Besides `file`, you can also pass `url`, `hf_hub_id`, amongst others. Check timm's documentation for full details.
+
 ```yaml title="Configuration file for a Semantic Segmentation Task"
 # lightning.pytorch==2.1.1
 seed_everything: 0
@@ -220,4 +236,4 @@ To run this training task, simply execute `terratorch fit --config <path_to_conf
 
 To test your model on the test set, execute `terratorch test --config <path_to_config_file> --ckpt_path <path_to_checkpoint_file>`
 
-For inference, execute `terratorch predict -c <path_to_config_file> --ckpt_path<path_to_checkpoint> --predict_output_dir <path_to_output_dir> --data.init_args.predict_data_root <path_to_input_dir> --data.init_args.predict_dataset_bands <all bands in the predicted dataset, e.g. [BLUE,GREEN,RED,NIR_NARROW,SWIR_1,SWIR_2,0]>`
+For inference, execute `terratorch predict -c <path_to_config_file> --ckpt_path <path_to_checkpoint> --predict_output_dir <path_to_output_dir> --data.init_args.predict_data_root <path_to_input_dir> --data.init_args.predict_dataset_bands <all bands in the predicted dataset, e.g. [BLUE,GREEN,RED,NIR_NARROW,SWIR_1,SWIR_2,0]>`
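The YAML snippet in the added info box can also be expressed at the task level. Below is a hedged sketch with values copied from `examples/confs/burn_scars.yaml` in this commit; `SemanticSegmentationTask`, `model_args` and `model_factory` appear in that config, but treating them as directly constructible like this, and the remaining details, are assumptions:

```python
# Sketch: the YAML overlay expressed through the task API instead of a config file.
# Argument values are taken from examples/confs/burn_scars.yaml; the exact constructor
# signature beyond model_args/model_factory is an assumption.
from terratorch.tasks import SemanticSegmentationTask

task = SemanticSegmentationTask(
    model_args={
        "backbone": "prithvi_vit_100",
        "decoder": "FCNDecoder",
        "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
        "num_classes": 2,
        "pretrained": True,
        # equivalent of `backbone_pretrained_cfg_overlay: {file: <local_path>}` in the YAML above
        "backbone_pretrained_cfg_overlay": {"file": "<local_path>"},
    },
    model_factory="PrithviModelFactory",
)
```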

examples/confs/README.md

+1
@@ -1,3 +1,4 @@
 Instructions to download the sen1floods11 data are [here](https://github.com/cloudtostreet/Sen1Floods11).
 Split files must be converted using `terratorch/examples/scripts/convert_sen1floods11_splits.py`.
 EOFM checkpoints can be downloaded from [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11/tree/main).
+Copy the `*txt` files to `<senfloods_root>/v1.1/splits/flood_handlabeled/` before starting the job.

examples/confs/burn_scars.yaml

+131
@@ -0,0 +1,131 @@
+# lightning.pytorch==2.1.1
+seed_everything: 0
+trainer:
+  accelerator: auto
+  strategy: auto
+  devices: auto
+  num_nodes: 1
+  precision: 16-mixed
+  logger:
+    class_path: TensorBoardLogger
+    init_args:
+      save_dir: <path>
+      name: fire_scars
+  callbacks:
+    - class_path: RichProgressBar
+    - class_path: LearningRateMonitor
+      init_args:
+        logging_interval: epoch
+    - class_path: EarlyStopping
+      init_args:
+        monitor: val/loss
+        patience: 40
+
+  max_epochs: 200
+  check_val_every_n_epoch: 1
+  log_every_n_steps: 50
+  enable_checkpointing: true
+  default_root_dir: <path>
+
+# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars
+data:
+  class_path: GenericNonGeoSegmentationDataModule
+  init_args:
+    batch_size: 4
+    num_workers: 8
+    dataset_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    output_bands:
+      - BLUE
+      - GREEN
+      - RED
+      - NIR_NARROW
+      - SWIR_1
+      - SWIR_2
+    rgb_indices:
+      - 0
+      - 1
+      - 2
+    train_transform:
+      - class_path: albumentations.RandomCrop
+        init_args:
+          height: 224
+          width: 224
+      - class_path: albumentations.HorizontalFlip
+        init_args:
+          p: 0.5
+      - class_path: ToTensorV2
+    no_data_replace: 0
+    no_label_replace: -1
+    train_data_root: <data_path>/training
+    train_label_data_root: <data_path>/training
+    val_data_root: <data_path>/validation
+    val_label_data_root: <data_path>/validation
+    test_data_root: <data_path>/validation
+    test_label_data_root: <data_path>/validation
+    img_grep: "*_merged.tif"
+    label_grep: "*.mask.tif"
+    means:
+      - 0.033349706741586264
+      - 0.05701185520536176
+      - 0.05889748132001316
+      - 0.2323245113436119
+      - 0.1972854853760658
+      - 0.11944914225186566
+    stds:
+      - 0.02269135568823774
+      - 0.026807560223070237
+      - 0.04004109844362779
+      - 0.07791732423672691
+      - 0.08708738838140137
+      - 0.07241979477437814
+    num_classes: 2
+
+model:
+  class_path: terratorch.tasks.SemanticSegmentationTask
+  init_args:
+    model_args:
+      decoder: FCNDecoder
+      pretrained: true
+      backbone: prithvi_vit_100
+      decoder_channels: 256
+      in_channels: 6
+      bands:
+        - BLUE
+        - GREEN
+        - RED
+        - NIR_NARROW
+        - SWIR_1
+        - SWIR_2
+      num_frames: 1
+      num_classes: 2
+      head_dropout: 0.1
+      decoder_num_convs: 4
+      head_channel_list:
+        - 256
+    loss: dice
+    plot_on_val: 10
+    ignore_index: -1
+    freeze_backbone: false
+    freeze_decoder: false
+    model_factory: PrithviModelFactory
+    tiled_inference_parameters:
+      h_crop: 512
+      h_stride: 496
+      w_crop: 512
+      w_stride: 496
+      average_patches: true
+optimizer:
+  class_path: torch.optim.Adam
+  init_args:
+    lr: 1.5e-5
+    weight_decay: 0.05
+lr_scheduler:
+  class_path: ReduceLROnPlateau
+  init_args:
+    monitor: val/loss
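Usage note: according to the quick start updated in this same commit, a config like this is launched with `terratorch fit --config <path_to_config_file>` after the `<path>` and `<data_path>` placeholders are filled in; here the config file would presumably be `examples/confs/burn_scars.yaml`.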
