|
| 1 | +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import paddle |
| 16 | +import paddle.nn as nn |
| 17 | +import paddle.nn.functional as F |
| 18 | + |
| 19 | +from paddleseg.cvlibs import manager |
| 20 | +from paddleseg.models import layers |
| 21 | +from paddleseg.utils import utils |
| 22 | + |
| 23 | + |
| 24 | +@manager.MODELS.add_component |
| 25 | +class FastFCN(nn.Layer): |
| 26 | + """ |
| 27 | + The FastFCN implementation based on PaddlePaddle. |
| 28 | +
|
| 29 | + The original article refers to |
| 30 | + Huikai Wu, Junge Zhang, Kaiqi Huang. "FastFCN: Rethinking Dilated Convolution in the Backbone for Semantic Segmentation". |
| 31 | +
|
| 32 | + Args: |
| 33 | + num_classes (int): The unique number of target classes. |
| 34 | + backbone (Paddle.nn.Layer): A backbone network. |
| 35 | + backbone_indices (tuple): The values in the tuple indicate the indices of |
| 36 | + output of backbone. |
| 37 | + num_codes (int): The number of encoded words. Default: 32. |
| 38 | + mid_channels (int): The channels of middle layers. Default: 512. |
| 39 | + use_jpu (bool): Whether use jpu module. Default: True. |
| 40 | + aux_loss (bool): Whether use auxiliary head loss. Default: True. |
| 41 | + use_se_loss (int): Whether use semantic encoding loss. Default: True. |
| 42 | + add_lateral (int): Whether use lateral convolution layers. Default: False. |
| 43 | + pretrained (str, optional): The path or url of pretrained model. Default: None. |
| 44 | + """ |
| 45 | + def __init__(self, |
| 46 | + num_classes, |
| 47 | + backbone, |
| 48 | + num_codes=32, |
| 49 | + mid_channels=512, |
| 50 | + use_jpu=True, |
| 51 | + aux_loss=True, |
| 52 | + use_se_loss=True, |
| 53 | + add_lateral=False, |
| 54 | + pretrained=None): |
| 55 | + super().__init__() |
| 56 | + self.add_lateral = add_lateral |
| 57 | + self.num_codes = num_codes |
| 58 | + self.backbone = backbone |
| 59 | + self.use_jpu = use_jpu |
| 60 | + in_channels = self.backbone.feat_channels |
| 61 | + |
| 62 | + if use_jpu: |
| 63 | + self.jpu_layer = layers.JPU(in_channels, mid_channels) |
| 64 | + in_channels[-1] = mid_channels * 4 |
| 65 | + self.bottleneck = layers.ConvBNReLU( |
| 66 | + in_channels[-1], |
| 67 | + mid_channels, |
| 68 | + 1, |
| 69 | + padding=0, |
| 70 | + bias_attr=False, |
| 71 | + ) |
| 72 | + else: |
| 73 | + self.bottleneck = layers.ConvBNReLU( |
| 74 | + in_channels[-1], |
| 75 | + mid_channels, |
| 76 | + 3, |
| 77 | + padding=1, |
| 78 | + bias_attr=False, |
| 79 | + ) |
| 80 | + if self.add_lateral: |
| 81 | + self.lateral_convs = nn.LayerList([ |
| 82 | + layers.ConvBNReLU(in_channels[0], |
| 83 | + mid_channels, |
| 84 | + 1, |
| 85 | + bias_attr=False), |
| 86 | + layers.ConvBNReLU(in_channels[1], |
| 87 | + mid_channels, |
| 88 | + 1, |
| 89 | + bias_attr=False), |
| 90 | + ]) |
| 91 | + |
| 92 | + self.fusion = layers.ConvBNReLU( |
| 93 | + 3 * mid_channels, |
| 94 | + mid_channels, |
| 95 | + 3, |
| 96 | + padding=1, |
| 97 | + bias_attr=False, |
| 98 | + ) |
| 99 | + |
| 100 | + self.enc_module = EncModule(mid_channels, num_codes) |
| 101 | + self.cls_seg = nn.Conv2D(mid_channels, num_classes, 1) |
| 102 | + |
| 103 | + self.aux_loss = aux_loss |
| 104 | + if self.aux_loss: |
| 105 | + self.fcn_head = layers.AuxLayer(in_channels[-2], mid_channels, |
| 106 | + num_classes) |
| 107 | + |
| 108 | + self.use_se_loss = use_se_loss |
| 109 | + if use_se_loss: |
| 110 | + self.se_layer = nn.Linear(mid_channels, num_classes) |
| 111 | + |
| 112 | + self.pretrained = pretrained |
| 113 | + self.init_weight() |
| 114 | + |
| 115 | + def init_weight(self): |
| 116 | + if self.pretrained is not None: |
| 117 | + utils.load_entire_model(self, self.pretrained) |
| 118 | + |
| 119 | + def forward(self, inputs): |
| 120 | + imsize = paddle.shape(inputs)[2:] |
| 121 | + feats = self.backbone(inputs) |
| 122 | + if self.use_jpu: |
| 123 | + feats = self.jpu_layer(*feats) |
| 124 | + |
| 125 | + fcn_feat = feats[2] |
| 126 | + |
| 127 | + feat = self.bottleneck(feats[-1]) |
| 128 | + if self.add_lateral: |
| 129 | + laterals = [] |
| 130 | + for i, lateral_conv in enumerate(self.lateral_convs): |
| 131 | + laterals.append( |
| 132 | + F.interpolate(lateral_conv(feats[i]), |
| 133 | + size=paddle.shape(feat)[2:], |
| 134 | + mode='bilinear', |
| 135 | + align_corners=False)) |
| 136 | + feat = self.fusion(paddle.concat([feat, *laterals], 1)) |
| 137 | + encode_feat, feat = self.enc_module(feat) |
| 138 | + out = self.cls_seg(feat) |
| 139 | + out = F.interpolate(out, |
| 140 | + size=imsize, |
| 141 | + mode='bilinear', |
| 142 | + align_corners=False) |
| 143 | + output = [out] |
| 144 | + |
| 145 | + if self.training: |
| 146 | + fcn_out = self.fcn_head(fcn_feat) |
| 147 | + fcn_out = F.interpolate(fcn_out, |
| 148 | + size=imsize, |
| 149 | + mode='bilinear', |
| 150 | + align_corners=False) |
| 151 | + output.append(fcn_out) |
| 152 | + if self.use_se_loss: |
| 153 | + se_out = self.se_layer(encode_feat) |
| 154 | + output.append(se_out) |
| 155 | + return output |
| 156 | + return output |
| 157 | + |
| 158 | + |
| 159 | +class Encoding(nn.Layer): |
| 160 | + def __init__(self, channels, num_codes): |
| 161 | + super().__init__() |
| 162 | + self.channels, self.num_codes = channels, num_codes |
| 163 | + |
| 164 | + std = 1 / ((channels * num_codes)**0.5) |
| 165 | + self.codewords = self.create_parameter( |
| 166 | + shape=(num_codes, channels), |
| 167 | + default_initializer=nn.initializer.Uniform(-std, std), |
| 168 | + ) |
| 169 | + self.scale = self.create_parameter( |
| 170 | + shape=(num_codes, ), |
| 171 | + default_initializer=nn.initializer.Uniform(-1, 0), |
| 172 | + ) |
| 173 | + |
| 174 | + def scaled_l2(self, x, codewords, scale): |
| 175 | + num_codes, channels = paddle.shape(codewords) |
| 176 | + reshaped_scale = scale.reshape([1, 1, num_codes]) |
| 177 | + expanded_x = paddle.tile(x.unsqueeze(2), [1, 1, num_codes, 1]) |
| 178 | + reshaped_codewords = codewords.reshape([1, 1, num_codes, channels]) |
| 179 | + |
| 180 | + scaled_l2_norm = reshaped_scale * ( |
| 181 | + expanded_x - reshaped_codewords).pow(2).sum(axis=3) |
| 182 | + return scaled_l2_norm |
| 183 | + |
| 184 | + def aggregate(self, assignment_weights, x, codewords): |
| 185 | + num_codes, channels = paddle.shape(codewords) |
| 186 | + reshaped_codewords = codewords.reshape([1, 1, num_codes, channels]) |
| 187 | + expanded_x = paddle.tile( |
| 188 | + x.unsqueeze(2), |
| 189 | + [1, 1, num_codes, 1], |
| 190 | + ) |
| 191 | + encoded_feat = (assignment_weights.unsqueeze(3) * |
| 192 | + (expanded_x - reshaped_codewords)).sum(axis=1) |
| 193 | + return encoded_feat |
| 194 | + |
| 195 | + def forward(self, x): |
| 196 | + x_dims = x.ndim |
| 197 | + assert x_dims == 4, "The dimension of input tensor must equal 4, but got {}.".format( |
| 198 | + x_dims) |
| 199 | + assert paddle.shape( |
| 200 | + x |
| 201 | + )[1] == self.channels, "Encoding channels error, excepted {} but got {}.".format( |
| 202 | + self.channels, |
| 203 | + paddle.shape(x)[1]) |
| 204 | + batch_size = paddle.shape(x)[0] |
| 205 | + x = x.reshape([batch_size, self.channels, -1]).transpose([0, 2, 1]) |
| 206 | + assignment_weights = F.softmax(self.scaled_l2(x, self.codewords, |
| 207 | + self.scale), |
| 208 | + axis=2) |
| 209 | + |
| 210 | + encoded_feat = self.aggregate(assignment_weights, x, self.codewords) |
| 211 | + encoded_feat = encoded_feat.reshape([batch_size, self.num_codes, -1]) |
| 212 | + return encoded_feat |
| 213 | + |
| 214 | + |
| 215 | +class EncModule(nn.Layer): |
| 216 | + def __init__(self, in_channels, num_codes): |
| 217 | + super().__init__() |
| 218 | + self.encoding_project = layers.ConvBNReLU( |
| 219 | + in_channels, |
| 220 | + in_channels, |
| 221 | + 1, |
| 222 | + ) |
| 223 | + self.encoding = nn.Sequential( |
| 224 | + Encoding(channels=in_channels, num_codes=num_codes), |
| 225 | + nn.BatchNorm1D(num_codes), |
| 226 | + nn.ReLU(), |
| 227 | + ) |
| 228 | + self.fc = nn.Sequential( |
| 229 | + nn.Linear(in_channels, in_channels), |
| 230 | + nn.Sigmoid(), |
| 231 | + ) |
| 232 | + |
| 233 | + def forward(self, x): |
| 234 | + encoding_projection = self.encoding_project(x) |
| 235 | + encoding_feat = self.encoding(encoding_projection).mean(axis=1) |
| 236 | + batch_size, channels, _, _ = paddle.shape(x) |
| 237 | + gamma = self.fc(encoding_feat) |
| 238 | + y = gamma.reshape([batch_size, channels, 1, 1]) |
| 239 | + output = F.relu(x + x * y) |
| 240 | + return encoding_feat, output |
0 commit comments