Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds SMPModelFactory example #72

Closed
wants to merge 41 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
fb68d27
extends smp_model_factory class
PedroConrado Jul 19, 2024
52699a9
extends smp model factory and adds functionalities in prithvi model f…
PedroConrado Jul 28, 2024
93c523e
extends SMPModelFactory
PedroConrado Jul 28, 2024
e544105
Extends SMPModelFactory and smp decoder in PrithviModelFactoy
PedroConrado Jul 29, 2024
7edbbf8
adds SMPModelFactory tests and SMPModelFactory to model.md
PedroConrado Jul 29, 2024
97b4688
adds SMPModelFactory tests and SMPModelFactory to docs/model.md
PedroConrado Jul 29, 2024
4a29f56
adds smp_model_factory exaple
PedroConrado Aug 2, 2024
a0d19aa
Bands could be define by intervals
Joao-L-S-Almeida Jul 18, 2024
4991ff2
Constructing the bands using the definition by interval
Joao-L-S-Almeida Jul 19, 2024
beb2235
Extending the supported formats for bands to include list[int]
Joao-L-S-Almeida Jul 19, 2024
e8c80e9
Extending the supported formats for bands to include list[int]
Joao-L-S-Almeida Jul 19, 2024
6c37dc6
Testing the definition by interval using a dedicated yaml file
Joao-L-S-Almeida Jul 19, 2024
3a24877
Special case for bands_list=:None
Joao-L-S-Almeida Jul 19, 2024
ef1fef8
Basic support to use simple strings to name the bands
Joao-L-S-Almeida Jul 22, 2024
3726306
Strings are allowed to define bands
Joao-L-S-Almeida Jul 22, 2024
0fe28de
Testing to use strings to define a model
Joao-L-S-Almeida Jul 22, 2024
5c6e18c
Exception for None inputs
Joao-L-S-Almeida Jul 22, 2024
b31953e
Support for str
Joao-L-S-Almeida Jul 22, 2024
2aaf3f9
YAML file for testing string as bands
Joao-L-S-Almeida Jul 22, 2024
75e5ecf
This test is no longer required
Joao-L-S-Almeida Jul 22, 2024
77b63a3
Band intervals should be tuples with two entries
Joao-L-S-Almeida Jul 23, 2024
ddcf42c
More compact way to check if the bands are defined by interval
Joao-L-S-Almeida Jul 23, 2024
9e3e301
This warning is not necessary
Joao-L-S-Almeida Jul 23, 2024
8f202cb
Reformatting using black
Joao-L-S-Almeida Jul 23, 2024
f06c51e
Minor improvements
Joao-L-S-Almeida Jul 23, 2024
44b25bd
Missing imports
Joao-L-S-Almeida Jul 23, 2024
61d0061
More tests to check if the bands ar properly returned
Joao-L-S-Almeida Jul 24, 2024
6b870cf
accept mixed band specifications
CarlosGomes98 Jul 26, 2024
20753e4
improve docstring comments
CarlosGomes98 Jul 26, 2024
2212a30
add scale modules to upernet for vit backbone
CarlosGomes98 Jul 29, 2024
5d9350d
Metrics as std and mean could be informed via file
Joao-L-S-Almeida Jul 26, 2024
8b23422
Reading metrics (std and mean) from file
Joao-L-S-Almeida Jul 26, 2024
9d30be2
Auxiliary files
Joao-L-S-Almeida Jul 26, 2024
461a1f8
Minor issues solved
Joao-L-S-Almeida Jul 29, 2024
8a74d0f
space removed
Joao-L-S-Almeida Jul 29, 2024
c779521
This function must be shared
Joao-L-S-Almeida Jul 30, 2024
c626f77
This function should not be here
Joao-L-S-Almeida Jul 30, 2024
bd03080
pin albumentations
CarlosGomes98 Jul 31, 2024
c596fc6
add filename only after transforms
CarlosGomes98 Jul 31, 2024
9912e04
remove spaces in metric names
CarlosGomes98 Jul 31, 2024
e4789cf
extends smp_model_factory class (#56)
PedroConrado Aug 2, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ By passing a list of bands being used to the constructor, we automatically filte

## Model Factory
### :::terratorch.models.PrithviModelFactory
### :::terratorch.models.SMPModelFactory

# Adding new model types
Adding new model types is as simple as creating a new factory that produces models. See for instance the example below for a potential `SMPModelFactory`
Expand Down
93 changes: 93 additions & 0 deletions examples/confs/smp_model_factory.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
benchmark_suffix: smp_test
experiment_name: smp_test
backbone:
# backbone: resnet18
# backbone_args:
# pretrained: False
# output_stride: 2
# smp_decoder_channels: 512
# smp_encoder_depth: 5

backbone: swin3d.swin3d_backbone.Swin3dBackbone
backbone_args:
pretrained: False
output_stride: 2
out_channels:
- 192
- 384
- 768
- 768
smp_decoder_channels: 768
smp_encoder_depth: 5


tasks:
- name: cashew
type: segmentation
loss: ce
model_factory: SMPModelFactory
bands:
- RED
- GREEN
- BLUE
num_classes: 7
max_epochs: 60
direction: max
datamodule:
class_path: terratorch.datamodules.MBeninSmallHolderCashewsNonGeoDataModule
init_args:
batch_size: 16
num_workers: 4
train_transform:
- class_path: albumentations.Resize
init_args:
always_apply: True
height: 224
width: 224
- class_path: ToTensorV2
test_transform:
- class_path: albumentations.Resize
init_args:
always_apply: True
height: 224
width: 224
- class_path: ToTensorV2
val_transform:
- class_path: albumentations.Resize
init_args:
height: 224
width: 224
- class_path: ToTensorV2
data_root: "/dccstor/geofm-finetuning/geobench/segmentation_v1.0"
bands:
- "RED"
- "GREEN"
- "BLUE"
decoder: Unet
decoder_args:
channels: 128
metric: val/Multiclass Jaccard Index

n_trials: 16
save_models: False
storage_uri: /dccstor/geofm-finetuning/pedrohc/smp_test
optimization_space:
model:
- DeepLabV3
lr:
min: 6e-5
max: 1e-3
type: real
log: true
batch_size:
- 8
- 16
- 32
decoder_channels:
- 32
- 64
- 128
head_dropout:
min: 0.2
max: 0.8
type: real
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ dependencies = [
"geobench>=1.0.0",
"mlflow>=2.12.1",
# broken due to https://github.com/Lightning-AI/pytorch-lightning/issues/19977
"lightning>=2, <=2.2.5"
"lightning>=2, <=2.2.5",
# see issue #64
"albumentations<=1.4.10"
]

[project.optional-dependencies]
Expand Down
32 changes: 19 additions & 13 deletions terratorch/datamodules/generic_pixel_wise_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
"""
This module contains generic data modules for instantiation at runtime.
"""

import os
from collections.abc import Callable, Iterable
from pathlib import Path
from typing import Any

import numpy as np
import albumentations as A
import kornia.augmentation as K
import torch
Expand All @@ -17,7 +17,7 @@
from torchgeo.transforms import AugmentationSequential

from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands

from terratorch.io.file import load_from_file_or_attribute

def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
Expand Down Expand Up @@ -79,8 +79,8 @@ def __init__(
test_data_root: Path,
img_grep: str,
label_grep: str,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_label_data_root: Path | None = None,
Expand All @@ -91,9 +91,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down Expand Up @@ -198,6 +198,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
Expand Down Expand Up @@ -317,8 +320,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
predict_data_root: Path | None = None,
img_grep: str | None = "*",
label_grep: str | None = "*",
Expand All @@ -330,9 +333,9 @@ def __init__(
test_split: Path | None = None,
ignore_split_file_extensions: bool = True,
allow_substring_split_file: bool = True,
dataset_bands: list[HLSBands | int] | None = None,
predict_dataset_bands: list[HLSBands | int] | None = None,
output_bands: list[HLSBands | int] | None = None,
dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
predict_dataset_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
output_bands: list[HLSBands | int | tuple[int, int] | str ] | None = None,
constant_scale: float = 1,
rgb_indices: list[int] | None = None,
train_transform: A.Compose | None | list[A.BasicTransform] = None,
Expand Down Expand Up @@ -430,6 +433,9 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )
means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)
self.no_data_replace = no_data_replace
self.no_label_replace = no_label_replace
Expand Down
10 changes: 7 additions & 3 deletions terratorch/datamodules/generic_scalar_label_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
HLSBands,
)

from terratorch.io.file import load_from_file_or_attribute

def wrap_in_compose_is_list(transform_list):
# set check shapes to false because of the multitemporal case
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list


class Normalize(Callable):
def __init__(self, means, stds):
super().__init__()
Expand Down Expand Up @@ -68,8 +68,8 @@ def __init__(
train_data_root: Path,
val_data_root: Path,
test_data_root: Path,
means: list[float],
stds: list[float],
means: list[float] | str,
stds: list[float] | str,
num_classes: int,
predict_data_root: Path | None = None,
train_split: Path | None = None,
Expand Down Expand Up @@ -166,6 +166,10 @@ def __init__(
# K.Normalize(means, stds),
# data_keys=["image"],
# )

means = load_from_file_or_attribute(means)
stds = load_from_file_or_attribute(stds)

self.aug = Normalize(means, stds)

# self.aug = Normalize(means, stds)
Expand Down
Loading
Loading