Skip to content

Commit aefcf02

Browse files
Merge branch 'main' of github.com:IBM/terratorch into add/unet
2 parents 162da74 + 9df5f5f commit aefcf02

10 files changed

+352
-239
lines changed

terratorch/datamodules/generic_pixel_wise_data_module.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
"""
44
This module contains generic data modules for instantiation at runtime.
55
"""
6-
6+
import os
77
from collections.abc import Callable, Iterable
88
from pathlib import Path
99
from typing import Any
10-
10+
import numpy as np
1111
import albumentations as A
1212
import kornia.augmentation as K
1313
import torch
@@ -17,7 +17,7 @@
1717
from torchgeo.transforms import AugmentationSequential
1818

1919
from terratorch.datasets import GenericNonGeoPixelwiseRegressionDataset, GenericNonGeoSegmentationDataset, HLSBands
20-
20+
from terratorch.io.file import load_from_file_or_attribute
2121

2222
def wrap_in_compose_is_list(transform_list):
2323
# set check shapes to false because of the multitemporal case
@@ -79,8 +79,8 @@ def __init__(
7979
test_data_root: Path,
8080
img_grep: str,
8181
label_grep: str,
82-
means: list[float],
83-
stds: list[float],
82+
means: list[float] | str,
83+
stds: list[float] | str,
8484
num_classes: int,
8585
predict_data_root: Path | None = None,
8686
train_label_data_root: Path | None = None,
@@ -198,6 +198,9 @@ def __init__(
198198
# K.Normalize(means, stds),
199199
# data_keys=["image"],
200200
# )
201+
means = load_from_file_or_attribute(means)
202+
stds = load_from_file_or_attribute(stds)
203+
201204
self.aug = Normalize(means, stds)
202205

203206
# self.aug = Normalize(means, stds)
@@ -317,8 +320,8 @@ def __init__(
317320
train_data_root: Path,
318321
val_data_root: Path,
319322
test_data_root: Path,
320-
means: list[float],
321-
stds: list[float],
323+
means: list[float] | str,
324+
stds: list[float] | str,
322325
predict_data_root: Path | None = None,
323326
img_grep: str | None = "*",
324327
label_grep: str | None = "*",
@@ -430,6 +433,9 @@ def __init__(
430433
# K.Normalize(means, stds),
431434
# data_keys=["image"],
432435
# )
436+
means = load_from_file_or_attribute(means)
437+
stds = load_from_file_or_attribute(stds)
438+
433439
self.aug = Normalize(means, stds)
434440
self.no_data_replace = no_data_replace
435441
self.no_label_replace = no_label_replace

terratorch/datamodules/generic_scalar_label_data_module.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@
2222
HLSBands,
2323
)
2424

25+
from terratorch.io.file import load_from_file_or_attribute
2526

2627
def wrap_in_compose_is_list(transform_list):
2728
# set check shapes to false because of the multitemporal case
2829
return A.Compose(transform_list, is_check_shapes=False) if isinstance(transform_list, Iterable) else transform_list
2930

30-
3131
class Normalize(Callable):
3232
def __init__(self, means, stds):
3333
super().__init__()
@@ -68,8 +68,8 @@ def __init__(
6868
train_data_root: Path,
6969
val_data_root: Path,
7070
test_data_root: Path,
71-
means: list[float],
72-
stds: list[float],
71+
means: list[float] | str,
72+
stds: list[float] | str,
7373
num_classes: int,
7474
predict_data_root: Path | None = None,
7575
train_split: Path | None = None,
@@ -166,6 +166,10 @@ def __init__(
166166
# K.Normalize(means, stds),
167167
# data_keys=["image"],
168168
# )
169+
170+
means = load_from_file_or_attribute(means)
171+
stds = load_from_file_or_attribute(stds)
172+
169173
self.aug = Normalize(means, stds)
170174

171175
# self.aug = Normalize(means, stds)

terratorch/io/file.py

+19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import importlib
33
from torch import nn
4+
import numpy as np
45

56
def open_generic_torch_model(model: type | str = None,
67
model_kwargs: dict = None,
@@ -51,3 +52,21 @@ def load_torch_weights(model:nn.Module=None, save_dir: str = None, name: str = N
5152
)
5253

5354
return model
55+
56+
def load_from_file_or_attribute(value: list[float]|str):
57+
58+
if isinstance(value, list):
59+
return value
60+
elif isinstance(value, str): # It can be the path for a file
61+
if os.path.isfile(value):
62+
try:
63+
print(value)
64+
content = np.genfromtxt(value).tolist()
65+
except:
66+
raise Exception(f"File must be txt, but received {value}")
67+
else:
68+
raise Exception(f"The input {value} does not exist or is not a file.")
69+
70+
return content
71+
72+

terratorch/models/decoders/upernet_decoder.py

+30-103
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
# Copyright contributors to the Terratorch project
2-
31
import torch
42
import torch.nn.functional as F # noqa: N812
53
from torch import Tensor, nn
64

75
"""
86
Adapted from https://github.com/yassouali/pytorch-segmentation/blob/master/models/upernet.py
97
"""
10-
11-
128
class ConvModule(nn.Module):
139
def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=False) -> None: # noqa: FBT002
1410
super().__init__()
@@ -19,103 +15,6 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=Fa
1915
def forward(self, x):
2016
return self.act(self.norm(self.conv(x)))
2117

22-
23-
# class PSPModule(nn.Module):
24-
# # In the original inmplementation they use precise RoI pooling
25-
# # Instead of using adaptative average pooling
26-
# def __init__(self, in_channels: int, bin_sizes: list[int] | None = None):
27-
# super().__init__()
28-
# if bin_sizes is None:
29-
# bin_sizes = [1, 2, 3, 6]
30-
# out_channels = in_channels // len(bin_sizes)
31-
# self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) for b_s in bin_sizes])
32-
# self.bottleneck = nn.Sequential(
33-
# nn.Conv2d(
34-
# in_channels + (out_channels * len(bin_sizes)),
35-
# in_channels,
36-
# kernel_size=3,
37-
# padding=1,
38-
# bias=False,
39-
# ),
40-
# nn.BatchNorm2d(in_channels),
41-
# nn.ReLU(inplace=True),
42-
# nn.Dropout2d(0.1),
43-
# )
44-
45-
# def _make_stages(self, in_channels, out_channels, bin_sz):
46-
# prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
47-
# conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
48-
# bn = nn.BatchNorm2d(out_channels)
49-
# relu = nn.ReLU(inplace=True)
50-
# return nn.Sequential(prior, conv, bn, relu)
51-
52-
# def forward(self, features):
53-
# h, w = features.size()[2], features.size()[3]
54-
# pyramids = [features]
55-
# pyramids.extend(
56-
# [F.interpolate(stage(features), size=(h, w), mode="bilinear", align_corners=True) for stage in self.stages]
57-
# )
58-
# output = self.bottleneck(torch.cat(pyramids, dim=1))
59-
# return output
60-
61-
62-
# def up_and_add(x, y):
63-
# return F.interpolate(x, size=(y.size(2), y.size(3)), mode="bilinear", align_corners=True) + y
64-
65-
66-
# class FPNFuse(nn.Module):
67-
# def __init__(self, feature_channels=None, fpn_out=256):
68-
# super().__init__()
69-
# if feature_channels is None:
70-
# feature_channels = [256, 512, 1024, 2048]
71-
# if not feature_channels[0] == fpn_out:
72-
# msg = f"First index of feature channel ({feature_channels[0]}) did not match fpn_out ({fpn_out})"
73-
# raise Exception(msg)
74-
# self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1) for ft_size in feature_channels[1:]])
75-
# self.smooth_conv = nn.ModuleList(
76-
# [nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)] * (len(feature_channels) - 1)
77-
# )
78-
# self.conv_fusion = nn.Sequential(
79-
# nn.Conv2d(
80-
# len(feature_channels) * fpn_out,
81-
# fpn_out,
82-
# kernel_size=3,
83-
# padding=1,
84-
# bias=False,
85-
# ),
86-
# nn.BatchNorm2d(fpn_out),
87-
# nn.ReLU(inplace=True),
88-
# )
89-
90-
# def forward(self, features):
91-
# features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1, strict=False)]
92-
# p = [up_and_add(features[i], features[i - 1]) for i in reversed(range(1, len(features)))]
93-
# p = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, p, strict=False)]
94-
# p = list(reversed(p))
95-
# p.append(features[-1]) # P = [P1, P2, P3, P4]
96-
# h, w = p[0].size(2), p[0].size(3)
97-
# p[1:] = [F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=True) for feature in p[1:]]
98-
99-
# x = self.conv_fusion(torch.cat(p, dim=1))
100-
# return x
101-
102-
103-
# class UperNetDecoder(nn.Module):
104-
# def __init__(self, embed_dim: list[int]) -> None:
105-
# super().__init__()
106-
# self.embed_dim = embed_dim
107-
# self.output_embed_dim = embed_dim[0]
108-
# self.PPN = PSPModule(embed_dim[-1])
109-
# self.FPN = FPNFuse(embed_dim, fpn_out=self.output_embed_dim)
110-
111-
# def forward(self, x: Tensor):
112-
# x = [f.clone() for f in x]
113-
# x[-1] = self.PPN(x[-1])
114-
# x = self.FPN(x)
115-
116-
# return x
117-
118-
11918
# Adapted from MMSegmentation
12019
class UperNetDecoder(nn.Module):
12120
"""UperNetDecoder. Adapted from MMSegmentation."""
@@ -126,6 +25,7 @@ def __init__(
12625
pool_scales: tuple[int] = (1, 2, 3, 6),
12726
channels: int = 256,
12827
align_corners: bool = True, # noqa: FBT001, FBT002
28+
scale_modules: bool = False
12929
):
13030
"""Constructor
13131
@@ -134,10 +34,29 @@ def __init__(
13434
pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
13535
Module applied on the last feature. Default: (1, 2, 3, 6).
13636
channels (int, optional): Channels used in the decoder. Defaults to 256.
137-
align_corners (bool, optional): Whter to align corners in rescaling. Defaults to True.
37+
align_corners (bool, optional): Wheter to align corners in rescaling. Defaults to True.
38+
scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
39+
Defaults to False.
13840
"""
13941
super().__init__()
140-
self.embed_dim = embed_dim
42+
self.scale_modules = scale_modules
43+
if scale_modules:
44+
self.fpn1 = nn.Sequential(
45+
nn.ConvTranspose2d(embed_dim[0],
46+
embed_dim[0] // 2, 2, 2),
47+
nn.BatchNorm2d(embed_dim[0] // 2),
48+
nn.GELU(),
49+
nn.ConvTranspose2d(embed_dim[0] // 2,
50+
embed_dim[0] // 4, 2, 2))
51+
self.fpn2 = nn.Sequential(
52+
nn.ConvTranspose2d(embed_dim[1],
53+
embed_dim[1] // 2, 2, 2))
54+
self.fpn3 = nn.Sequential(nn.Identity())
55+
self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
56+
self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
57+
else:
58+
self.embed_dim = embed_dim
59+
14160
self.output_embed_dim = channels
14261
self.channels = channels
14362
self.align_corners = align_corners
@@ -192,6 +111,14 @@ def forward(self, inputs):
192111
feats (Tensor): A tensor of shape (batch_size, self.channels,
193112
H, W) which is feature map for last layer of decoder head.
194113
"""
114+
115+
if self.scale_modules:
116+
scaled_inputs = []
117+
scaled_inputs.append(self.fpn1(inputs[0]))
118+
scaled_inputs.append(self.fpn2(inputs[1]))
119+
scaled_inputs.append(self.fpn3(inputs[2]))
120+
scaled_inputs.append(self.fpn4(inputs[3]))
121+
inputs = scaled_inputs
195122
# build laterals
196123
laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
197124
laterals.append(self.psp_forward(inputs))

0 commit comments

Comments
 (0)