Skip to content

Commit 3670c7b

Browse files
Merge pull request #62 from IBM/feature/upernet_scale_modules
add scale modules to upernet for vit backbone
2 parents da480f0 + 57b8851 commit 3670c7b

File tree

3 files changed

+162
-229
lines changed

3 files changed

+162
-229
lines changed

terratorch/models/decoders/upernet_decoder.py

+30-103
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
1-
# Copyright contributors to the Terratorch project
2-
31
import torch
42
import torch.nn.functional as F # noqa: N812
53
from torch import Tensor, nn
64

75
"""
86
Adapted from https://github.com/yassouali/pytorch-segmentation/blob/master/models/upernet.py
97
"""
10-
11-
128
class ConvModule(nn.Module):
139
def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=False) -> None: # noqa: FBT002
1410
super().__init__()
@@ -19,103 +15,6 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=Fa
1915
def forward(self, x):
2016
return self.act(self.norm(self.conv(x)))
2117

22-
23-
# class PSPModule(nn.Module):
24-
# # In the original inmplementation they use precise RoI pooling
25-
# # Instead of using adaptative average pooling
26-
# def __init__(self, in_channels: int, bin_sizes: list[int] | None = None):
27-
# super().__init__()
28-
# if bin_sizes is None:
29-
# bin_sizes = [1, 2, 3, 6]
30-
# out_channels = in_channels // len(bin_sizes)
31-
# self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) for b_s in bin_sizes])
32-
# self.bottleneck = nn.Sequential(
33-
# nn.Conv2d(
34-
# in_channels + (out_channels * len(bin_sizes)),
35-
# in_channels,
36-
# kernel_size=3,
37-
# padding=1,
38-
# bias=False,
39-
# ),
40-
# nn.BatchNorm2d(in_channels),
41-
# nn.ReLU(inplace=True),
42-
# nn.Dropout2d(0.1),
43-
# )
44-
45-
# def _make_stages(self, in_channels, out_channels, bin_sz):
46-
# prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
47-
# conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
48-
# bn = nn.BatchNorm2d(out_channels)
49-
# relu = nn.ReLU(inplace=True)
50-
# return nn.Sequential(prior, conv, bn, relu)
51-
52-
# def forward(self, features):
53-
# h, w = features.size()[2], features.size()[3]
54-
# pyramids = [features]
55-
# pyramids.extend(
56-
# [F.interpolate(stage(features), size=(h, w), mode="bilinear", align_corners=True) for stage in self.stages]
57-
# )
58-
# output = self.bottleneck(torch.cat(pyramids, dim=1))
59-
# return output
60-
61-
62-
# def up_and_add(x, y):
63-
# return F.interpolate(x, size=(y.size(2), y.size(3)), mode="bilinear", align_corners=True) + y
64-
65-
66-
# class FPNFuse(nn.Module):
67-
# def __init__(self, feature_channels=None, fpn_out=256):
68-
# super().__init__()
69-
# if feature_channels is None:
70-
# feature_channels = [256, 512, 1024, 2048]
71-
# if not feature_channels[0] == fpn_out:
72-
# msg = f"First index of feature channel ({feature_channels[0]}) did not match fpn_out ({fpn_out})"
73-
# raise Exception(msg)
74-
# self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1) for ft_size in feature_channels[1:]])
75-
# self.smooth_conv = nn.ModuleList(
76-
# [nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)] * (len(feature_channels) - 1)
77-
# )
78-
# self.conv_fusion = nn.Sequential(
79-
# nn.Conv2d(
80-
# len(feature_channels) * fpn_out,
81-
# fpn_out,
82-
# kernel_size=3,
83-
# padding=1,
84-
# bias=False,
85-
# ),
86-
# nn.BatchNorm2d(fpn_out),
87-
# nn.ReLU(inplace=True),
88-
# )
89-
90-
# def forward(self, features):
91-
# features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1, strict=False)]
92-
# p = [up_and_add(features[i], features[i - 1]) for i in reversed(range(1, len(features)))]
93-
# p = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, p, strict=False)]
94-
# p = list(reversed(p))
95-
# p.append(features[-1]) # P = [P1, P2, P3, P4]
96-
# h, w = p[0].size(2), p[0].size(3)
97-
# p[1:] = [F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=True) for feature in p[1:]]
98-
99-
# x = self.conv_fusion(torch.cat(p, dim=1))
100-
# return x
101-
102-
103-
# class UperNetDecoder(nn.Module):
104-
# def __init__(self, embed_dim: list[int]) -> None:
105-
# super().__init__()
106-
# self.embed_dim = embed_dim
107-
# self.output_embed_dim = embed_dim[0]
108-
# self.PPN = PSPModule(embed_dim[-1])
109-
# self.FPN = FPNFuse(embed_dim, fpn_out=self.output_embed_dim)
110-
111-
# def forward(self, x: Tensor):
112-
# x = [f.clone() for f in x]
113-
# x[-1] = self.PPN(x[-1])
114-
# x = self.FPN(x)
115-
116-
# return x
117-
118-
11918
# Adapted from MMSegmentation
12019
class UperNetDecoder(nn.Module):
12120
"""UperNetDecoder. Adapted from MMSegmentation."""
@@ -126,6 +25,7 @@ def __init__(
12625
pool_scales: tuple[int] = (1, 2, 3, 6),
12726
channels: int = 256,
12827
align_corners: bool = True, # noqa: FBT001, FBT002
28+
scale_modules: bool = False
12929
):
13030
"""Constructor
13131
@@ -134,10 +34,29 @@ def __init__(
13434
pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
13535
Module applied on the last feature. Default: (1, 2, 3, 6).
13636
channels (int, optional): Channels used in the decoder. Defaults to 256.
137-
align_corners (bool, optional): Whter to align corners in rescaling. Defaults to True.
37+
align_corners (bool, optional): Wheter to align corners in rescaling. Defaults to True.
38+
scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
39+
Defaults to False.
13840
"""
13941
super().__init__()
140-
self.embed_dim = embed_dim
42+
self.scale_modules = scale_modules
43+
if scale_modules:
44+
self.fpn1 = nn.Sequential(
45+
nn.ConvTranspose2d(embed_dim[0],
46+
embed_dim[0] // 2, 2, 2),
47+
nn.BatchNorm2d(embed_dim[0] // 2),
48+
nn.GELU(),
49+
nn.ConvTranspose2d(embed_dim[0] // 2,
50+
embed_dim[0] // 4, 2, 2))
51+
self.fpn2 = nn.Sequential(
52+
nn.ConvTranspose2d(embed_dim[1],
53+
embed_dim[1] // 2, 2, 2))
54+
self.fpn3 = nn.Sequential(nn.Identity())
55+
self.fpn4 = nn.Sequential(nn.MaxPool2d(kernel_size=2, stride=2))
56+
self.embed_dim = [embed_dim[0] // 4, embed_dim[1] // 2, embed_dim[2], embed_dim[3]]
57+
else:
58+
self.embed_dim = embed_dim
59+
14160
self.output_embed_dim = channels
14261
self.channels = channels
14362
self.align_corners = align_corners
@@ -192,6 +111,14 @@ def forward(self, inputs):
192111
feats (Tensor): A tensor of shape (batch_size, self.channels,
193112
H, W) which is feature map for last layer of decoder head.
194113
"""
114+
115+
if self.scale_modules:
116+
scaled_inputs = []
117+
scaled_inputs.append(self.fpn1(inputs[0]))
118+
scaled_inputs.append(self.fpn2(inputs[1]))
119+
scaled_inputs.append(self.fpn3(inputs[2]))
120+
scaled_inputs.append(self.fpn4(inputs[3]))
121+
inputs = scaled_inputs
195122
# build laterals
196123
laterals = [lateral_conv(inputs[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
197124
laterals.append(self.psp_forward(inputs))

0 commit comments

Comments
 (0)