Skip to content

Commit cbc98f1

Browse files
Merge pull request #67 from IBM/add/unet
Add/unet
2 parents a09f8e9 + 9f7eca0 commit cbc98f1

16 files changed

+451
-1
lines changed

.github/workflows/test.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ jobs:
2828
run: |
2929
python -m pip install --upgrade pip
3030
pip install -r requirements/required.txt -r requirements/test.txt
31+
mim install mmsegmentation
3132
- name: List pip dependencies
3233
run: pip list
3334
- name: Test with pytest

pyproject.toml

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
requires = [ "setuptools" ]
33
build-backend = 'setuptools.build_meta'
44

5+
# It allows installation via `pip install -e`
6+
[tool.setuptools]
7+
py-modules = []
8+
59
[project]
610
name = "terratorch"
711
version = "0.99.1"

requirements/required.txt

+7-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ lightly==1.4.25
1010
h5py==3.10.0
1111
geobench==1.0.0
1212
mlflow==2.14.3
13-
lightning==2.2.5
13+
lightning==2.2.5
14+
mmcv==2.1.0
15+
# Extra dependencies required by mmseg
16+
ftfy
17+
regex
18+
openmim
19+
#mim mmsegmentation

terratorch/models/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from terratorch.models.scalemae_model_factory import ScaleMAEModelFactory
66
from terratorch.models.smp_model_factory import SMPModelFactory
77
from terratorch.models.timm_model_factory import TimmModelFactory
8+
from terratorch.models.generic_unet_model_factory import GenericUnetModelFactory
89

910
__all__ = (
1011
"PrithviModelFactory",
1112
"ClayModelFactory",
1213
"SatMAEModelFactory",
1314
"ScaleMAEModelFactory",
1415
"SMPModelFactory",
16+
"GenericUnetModelFactory",
1517
"TimmModelFactory",
1618
"AuxiliaryHead",
1719
"AuxiliaryHeadWithDecoderWithoutInstantiatedHead",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright contributors to the Terratorch project
2+
3+
"""
4+
This is just an example of a possible structure to include SMP models
5+
Right now it always returns a UNET, but could easily be extended to many of the models provided by SMP.
6+
"""
7+
8+
from torch import nn
9+
import torch
10+
from terratorch.models.model import Model, ModelFactory, ModelOutput, register_factory
11+
from terratorch.tasks.segmentation_tasks import to_segmentation_prediction
12+
13+
import importlib
14+
15+
def freeze_module(module: nn.Module):
16+
for param in module.parameters():
17+
param.requires_grad_(False)
18+
19+
@register_factory
20+
class GenericUnetModelFactory(ModelFactory):
21+
def build_model(
22+
self,
23+
task: str = "segmentation",
24+
backbone: str = None,
25+
decoder: str = None,
26+
dilations: tuple[int] = (1, 6, 12, 18),
27+
in_channels: int = 6,
28+
pretrained: str | bool | None = True,
29+
num_classes: int = 1,
30+
regression_relu: bool = False,
31+
**kwargs,
32+
) -> Model:
33+
"""Factory to create model based on SMP.
34+
35+
Args:
36+
task (str): Must be "segmentation".
37+
model (str): Decoder architecture. Currently only supports "unet".
38+
in_channels (int): Number of input channels.
39+
pretrained(str | bool): Which weights to use for the backbone. If true, will use "imagenet". If false or None, random weights. Defaults to True.
40+
num_classes (int): Number of classes.
41+
regression_relu (bool). Whether to apply a ReLU if task is regression. Defaults to False.
42+
43+
Returns:
44+
Model: SMP model wrapped in SMPModelWrapper.
45+
"""
46+
if task not in ["segmentation", "regression"]:
47+
msg = f"SMP models can only perform pixel wise tasks, but got task {task}"
48+
raise Exception(msg)
49+
50+
mmseg_decoders = importlib.import_module("mmseg.models.decode_heads")
51+
mmseg_encoders = importlib.import_module("mmseg.models.backbones")
52+
53+
if backbone:
54+
backbone_kwargs = _extract_prefix_keys(kwargs, "backbone_")
55+
model = backbone
56+
model_kwargs = backbone_kwargs
57+
mmseg = mmseg_encoders
58+
elif decoder:
59+
decoder_kwargs = _extract_prefix_keys(kwargs, "decoder_")
60+
model = decoder
61+
model_kwargs = decoder_kwargs
62+
mmseg = mmseg_decoders
63+
else:
64+
print("It is necessary to define a backbone and/or a decoder.")
65+
66+
model_class = getattr(mmseg, model)
67+
68+
model = model_class(
69+
**model_kwargs,
70+
)
71+
72+
return GenericUnetModelWrapper(
73+
model, relu=task == "regression" and regression_relu, squeeze_single_class=task == "regression"
74+
)
75+
76+
class GenericUnetModelWrapper(Model, nn.Module):
77+
def __init__(self, unet_model, relu=False, squeeze_single_class=False) -> None:
78+
super().__init__()
79+
self.unet_model = unet_model
80+
self.final_act = nn.ReLU() if relu else nn.Identity()
81+
self.squeeze_single_class = squeeze_single_class
82+
83+
def forward(self, *args, **kwargs):
84+
85+
# It supposes the input has dimension (B, C, H, W)
86+
input_data = [args[0]] # It adapts the input to became a list of time 'snapshots'
87+
args = (input_data,)
88+
89+
unet_output = self.unet_model(*args, **kwargs)
90+
unet_output = self.final_act(unet_output)
91+
92+
if unet_output.shape[1] == 1 and self.squeeze_single_class:
93+
unet_output = unet_output.squeeze(1)
94+
95+
model_output = ModelOutput(unet_output)
96+
97+
return model_output
98+
99+
def freeze_encoder(self):
100+
raise NotImplementedError()
101+
102+
def freeze_decoder(self):
103+
raise freeze_module(self.unet_model)
104+
105+
106+
def _extract_prefix_keys(d: dict, prefix: str) -> dict:
107+
extracted_dict = {}
108+
keys_to_del = []
109+
for k, v in d.items():
110+
if k.startswith(prefix):
111+
extracted_dict[k.split(prefix)[1]] = v
112+
keys_to_del.append(k)
113+
114+
for k in keys_to_del:
115+
del d[k]
116+
117+
return extracted_dict
+148
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# lightning.pytorch==2.1.1
2+
seed_everything: 42
3+
trainer:
4+
accelerator: auto
5+
strategy: auto
6+
devices: auto
7+
num_nodes: 1
8+
#precision: 16-mixed
9+
# precision: 16-mixed
10+
logger:
11+
class_path: TensorBoardLogger
12+
init_args:
13+
save_dir: tests/
14+
name: all_ecos_random
15+
callbacks:
16+
- class_path: RichProgressBar
17+
- class_path: LearningRateMonitor
18+
init_args:
19+
logging_interval: epoch
20+
- class_path: EarlyStopping
21+
init_args:
22+
monitor: val/loss
23+
patience: 100
24+
max_epochs: 5
25+
check_val_every_n_epoch: 1
26+
log_every_n_steps: 20
27+
enable_checkpointing: true
28+
default_root_dir: tests/
29+
data:
30+
class_path: GenericNonGeoSegmentationDataModule
31+
init_args:
32+
batch_size: 2
33+
num_workers: 4
34+
train_transform:
35+
- class_path: albumentations.HorizontalFlip
36+
init_args:
37+
p: 0.5
38+
- class_path: albumentations.Rotate
39+
init_args:
40+
limit: 30
41+
border_mode: 0 # cv2.BORDER_CONSTANT
42+
value: 0
43+
# mask_value: 1
44+
p: 0.5
45+
- class_path: ToTensorV2
46+
dataset_bands:
47+
- COASTAL_AEROSOL
48+
- BLUE
49+
- GREEN
50+
- RED
51+
- NIR_NARROW
52+
- SWIR_1
53+
- SWIR_2
54+
- CIRRUS
55+
- THEMRAL_INFRARED_1
56+
- THEMRAL_INFRARED_2
57+
output_bands:
58+
- BLUE
59+
- GREEN
60+
- RED
61+
- NIR_NARROW
62+
- SWIR_1
63+
- SWIR_2
64+
rgb_indices:
65+
- 2
66+
- 1
67+
- 0
68+
train_data_root: tests/
69+
train_label_data_root: tests/
70+
val_data_root: tests/
71+
val_label_data_root: tests/
72+
test_data_root: tests/
73+
test_label_data_root: tests/
74+
img_grep: "segmentation*input*.tif"
75+
label_grep: "segmentation*label*.tif"
76+
means:
77+
- 547.36707
78+
- 898.5121
79+
- 1020.9082
80+
- 2665.5352
81+
- 2340.584
82+
- 1610.1407
83+
stds:
84+
- 411.4701
85+
- 558.54065
86+
- 815.94025
87+
- 812.4403
88+
- 1113.7145
89+
- 1067.641
90+
no_label_replace: -1
91+
no_data_replace: 0
92+
num_classes: 2
93+
model:
94+
class_path: terratorch.tasks.SemanticSegmentationTask
95+
init_args:
96+
model_args:
97+
decoder: "ASPPHead"
98+
decoder_dilations: [1, 6, 12, 18]
99+
decoder_channels: 256
100+
decoder_in_channels: 6
101+
decoder_num_classes: 2
102+
in_channels: 6
103+
bands:
104+
- BLUE
105+
- GREEN
106+
- RED
107+
- NIR_NARROW
108+
- SWIR_1
109+
- SWIR_2
110+
#num_frames: 1
111+
head_dropout: 0.5708022831486758
112+
head_final_act: torch.nn.ReLU
113+
head_learned_upscale_layers: 2
114+
num_classes: 2
115+
loss: ce
116+
#aux_heads:
117+
# - name: aux_head
118+
# decoder: IdentityDecoder
119+
# decoder_args:
120+
# decoder_out_index: 2
121+
# head_dropout: 0,5
122+
# head_channel_list:
123+
# - 64
124+
# head_final_act: torch.nn.ReLU
125+
#aux_loss:
126+
# aux_head: 0.4
127+
ignore_index: -1
128+
#freeze_encoder: false #true
129+
#freeze_decoder: false
130+
model_factory: GenericUnetModelFactory
131+
132+
# uncomment this block for tiled inference
133+
# tiled_inference_parameters:
134+
# h_crop: 224
135+
# h_stride: 192
136+
# w_crop: 224
137+
# w_stride: 192
138+
# average_patches: true
139+
optimizer:
140+
class_path: torch.optim.AdamW
141+
init_args:
142+
lr: 0.00013524680528283027
143+
weight_decay: 0.047782217873995426
144+
lr_scheduler:
145+
class_path: ReduceLROnPlateau
146+
init_args:
147+
monitor: val/loss
148+

0 commit comments

Comments
 (0)