1
- # Copyright contributors to the Terratorch project
2
-
3
1
import torch
4
2
import torch .nn .functional as F # noqa: N812
5
3
from torch import Tensor , nn
6
4
7
5
"""
8
6
Adapted from https://github.com/yassouali/pytorch-segmentation/blob/master/models/upernet.py
9
7
"""
10
-
11
-
12
8
class ConvModule (nn .Module ):
13
9
def __init__ (self , in_channels , out_channels , kernel_size , padding = 0 , inplace = False ) -> None : # noqa: FBT002
14
10
super ().__init__ ()
@@ -19,103 +15,6 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, inplace=Fa
19
15
def forward (self , x ):
20
16
return self .act (self .norm (self .conv (x )))
21
17
22
-
23
- # class PSPModule(nn.Module):
24
- # # In the original inmplementation they use precise RoI pooling
25
- # # Instead of using adaptative average pooling
26
- # def __init__(self, in_channels: int, bin_sizes: list[int] | None = None):
27
- # super().__init__()
28
- # if bin_sizes is None:
29
- # bin_sizes = [1, 2, 3, 6]
30
- # out_channels = in_channels // len(bin_sizes)
31
- # self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s) for b_s in bin_sizes])
32
- # self.bottleneck = nn.Sequential(
33
- # nn.Conv2d(
34
- # in_channels + (out_channels * len(bin_sizes)),
35
- # in_channels,
36
- # kernel_size=3,
37
- # padding=1,
38
- # bias=False,
39
- # ),
40
- # nn.BatchNorm2d(in_channels),
41
- # nn.ReLU(inplace=True),
42
- # nn.Dropout2d(0.1),
43
- # )
44
-
45
- # def _make_stages(self, in_channels, out_channels, bin_sz):
46
- # prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)
47
- # conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
48
- # bn = nn.BatchNorm2d(out_channels)
49
- # relu = nn.ReLU(inplace=True)
50
- # return nn.Sequential(prior, conv, bn, relu)
51
-
52
- # def forward(self, features):
53
- # h, w = features.size()[2], features.size()[3]
54
- # pyramids = [features]
55
- # pyramids.extend(
56
- # [F.interpolate(stage(features), size=(h, w), mode="bilinear", align_corners=True) for stage in self.stages]
57
- # )
58
- # output = self.bottleneck(torch.cat(pyramids, dim=1))
59
- # return output
60
-
61
-
62
- # def up_and_add(x, y):
63
- # return F.interpolate(x, size=(y.size(2), y.size(3)), mode="bilinear", align_corners=True) + y
64
-
65
-
66
- # class FPNFuse(nn.Module):
67
- # def __init__(self, feature_channels=None, fpn_out=256):
68
- # super().__init__()
69
- # if feature_channels is None:
70
- # feature_channels = [256, 512, 1024, 2048]
71
- # if not feature_channels[0] == fpn_out:
72
- # msg = f"First index of feature channel ({feature_channels[0]}) did not match fpn_out ({fpn_out})"
73
- # raise Exception(msg)
74
- # self.conv1x1 = nn.ModuleList([nn.Conv2d(ft_size, fpn_out, kernel_size=1) for ft_size in feature_channels[1:]])
75
- # self.smooth_conv = nn.ModuleList(
76
- # [nn.Conv2d(fpn_out, fpn_out, kernel_size=3, padding=1)] * (len(feature_channels) - 1)
77
- # )
78
- # self.conv_fusion = nn.Sequential(
79
- # nn.Conv2d(
80
- # len(feature_channels) * fpn_out,
81
- # fpn_out,
82
- # kernel_size=3,
83
- # padding=1,
84
- # bias=False,
85
- # ),
86
- # nn.BatchNorm2d(fpn_out),
87
- # nn.ReLU(inplace=True),
88
- # )
89
-
90
- # def forward(self, features):
91
- # features[1:] = [conv1x1(feature) for feature, conv1x1 in zip(features[1:], self.conv1x1, strict=False)]
92
- # p = [up_and_add(features[i], features[i - 1]) for i in reversed(range(1, len(features)))]
93
- # p = [smooth_conv(x) for smooth_conv, x in zip(self.smooth_conv, p, strict=False)]
94
- # p = list(reversed(p))
95
- # p.append(features[-1]) # P = [P1, P2, P3, P4]
96
- # h, w = p[0].size(2), p[0].size(3)
97
- # p[1:] = [F.interpolate(feature, size=(h, w), mode="bilinear", align_corners=True) for feature in p[1:]]
98
-
99
- # x = self.conv_fusion(torch.cat(p, dim=1))
100
- # return x
101
-
102
-
103
- # class UperNetDecoder(nn.Module):
104
- # def __init__(self, embed_dim: list[int]) -> None:
105
- # super().__init__()
106
- # self.embed_dim = embed_dim
107
- # self.output_embed_dim = embed_dim[0]
108
- # self.PPN = PSPModule(embed_dim[-1])
109
- # self.FPN = FPNFuse(embed_dim, fpn_out=self.output_embed_dim)
110
-
111
- # def forward(self, x: Tensor):
112
- # x = [f.clone() for f in x]
113
- # x[-1] = self.PPN(x[-1])
114
- # x = self.FPN(x)
115
-
116
- # return x
117
-
118
-
119
18
# Adapted from MMSegmentation
120
19
class UperNetDecoder (nn .Module ):
121
20
"""UperNetDecoder. Adapted from MMSegmentation."""
@@ -126,6 +25,7 @@ def __init__(
126
25
pool_scales : tuple [int ] = (1 , 2 , 3 , 6 ),
127
26
channels : int = 256 ,
128
27
align_corners : bool = True , # noqa: FBT001, FBT002
28
+ scale_modules : bool = False
129
29
):
130
30
"""Constructor
131
31
@@ -134,10 +34,29 @@ def __init__(
134
34
pool_scales (tuple[int], optional): Pooling scales used in Pooling Pyramid
135
35
Module applied on the last feature. Default: (1, 2, 3, 6).
136
36
channels (int, optional): Channels used in the decoder. Defaults to 256.
137
- align_corners (bool, optional): Whter to align corners in rescaling. Defaults to True.
37
+ align_corners (bool, optional): Whether to align corners in rescaling. Defaults to True.
38
+ scale_modules (bool, optional): Whether to apply scale modules to the inputs. Needed for plain ViT.
39
+ Defaults to False.
138
40
"""
139
41
super ().__init__ ()
140
- self .embed_dim = embed_dim
42
+ self .scale_modules = scale_modules
43
+ if scale_modules :
44
+ self .fpn1 = nn .Sequential (
45
+ nn .ConvTranspose2d (embed_dim [0 ],
46
+ embed_dim [0 ] // 2 , 2 , 2 ),
47
+ nn .BatchNorm2d (embed_dim [0 ] // 2 ),
48
+ nn .GELU (),
49
+ nn .ConvTranspose2d (embed_dim [0 ] // 2 ,
50
+ embed_dim [0 ] // 4 , 2 , 2 ))
51
+ self .fpn2 = nn .Sequential (
52
+ nn .ConvTranspose2d (embed_dim [1 ],
53
+ embed_dim [1 ] // 2 , 2 , 2 ))
54
+ self .fpn3 = nn .Sequential (nn .Identity ())
55
+ self .fpn4 = nn .Sequential (nn .MaxPool2d (kernel_size = 2 , stride = 2 ))
56
+ self .embed_dim = [embed_dim [0 ] // 4 , embed_dim [1 ] // 2 , embed_dim [2 ], embed_dim [3 ]]
57
+ else :
58
+ self .embed_dim = embed_dim
59
+
141
60
self .output_embed_dim = channels
142
61
self .channels = channels
143
62
self .align_corners = align_corners
@@ -192,6 +111,14 @@ def forward(self, inputs):
192
111
feats (Tensor): A tensor of shape (batch_size, self.channels,
193
112
H, W) which is feature map for last layer of decoder head.
194
113
"""
114
+
115
+ if self .scale_modules :
116
+ scaled_inputs = []
117
+ scaled_inputs .append (self .fpn1 (inputs [0 ]))
118
+ scaled_inputs .append (self .fpn2 (inputs [1 ]))
119
+ scaled_inputs .append (self .fpn3 (inputs [2 ]))
120
+ scaled_inputs .append (self .fpn4 (inputs [3 ]))
121
+ inputs = scaled_inputs
195
122
# build laterals
196
123
laterals = [lateral_conv (inputs [i ]) for i , lateral_conv in enumerate (self .lateral_convs )]
197
124
laterals .append (self .psp_forward (inputs ))
0 commit comments