|
| 1 | +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import paddle |
| 16 | +import paddle.nn as nn |
| 17 | +import paddle.nn.functional as F |
| 18 | +import numpy as np |
| 19 | + |
| 20 | +from paddleseg.utils import utils |
| 21 | +from paddleseg.cvlibs import manager, param_init |
| 22 | +from paddleseg.models.backbones import vision_transformer, transformer_utils |
| 23 | + |
| 24 | +__all__ = ['LinearSegmenter', 'MaskSegmenter'] |
| 25 | + |
| 26 | + |
| 27 | +@manager.MODELS.add_component |
| 28 | +class LinearSegmenter(nn.Layer): |
| 29 | + ''' |
| 30 | + The implementation of segmenter with linear head based on PaddlePaddle. |
| 31 | +
|
| 32 | + The original article refers to Strudel, Robin, et al. "Segmenter: Transformer |
| 33 | + for Semantic Segmentation." arXiv preprint arXiv:2105.05633 (2021). |
| 34 | +
|
| 35 | + Args: |
| 36 | + num_classes (int): The unique number of target classes. |
| 37 | + backbone (nn.Layer): The backbone transformer network. |
| 38 | + pretrained (str, optional): The path or url of pretrained model. Default: None. |
| 39 | + ''' |
| 40 | + |
| 41 | + def __init__(self, num_classes, backbone, pretrained=None): |
| 42 | + super().__init__() |
| 43 | + self.backbone = backbone |
| 44 | + self.head = SegmenterLinearHead(num_classes, backbone.embed_dim) |
| 45 | + self.pretrained = pretrained |
| 46 | + self.init_weight() |
| 47 | + |
| 48 | + def init_weight(self): |
| 49 | + if self.pretrained is not None: |
| 50 | + utils.load_entire_model(self, self.pretrained) |
| 51 | + |
| 52 | + def forward(self, x): |
| 53 | + x_shape = paddle.shape(x) |
| 54 | + |
| 55 | + feats, shape = self.backbone(x) |
| 56 | + logits = self.head(feats[-1], shape[2:]) |
| 57 | + |
| 58 | + logit_list = [ |
| 59 | + F.interpolate(logit, x_shape[2:], mode='bilinear') |
| 60 | + for logit in logits |
| 61 | + ] |
| 62 | + |
| 63 | + return logit_list |
| 64 | + |
| 65 | + |
| 66 | +@manager.MODELS.add_component |
| 67 | +class MaskSegmenter(nn.Layer): |
| 68 | + ''' |
| 69 | + The implementation of segmenter with mask head based on PaddlePaddle. |
| 70 | +
|
| 71 | + The original article refers to Strudel, Robin, et al. "Segmenter: Transformer |
| 72 | + for Semantic Segmentation." arXiv preprint arXiv:2105.05633 (2021). |
| 73 | +
|
| 74 | + Args: |
| 75 | + num_classes (int): The unique number of target classes. |
| 76 | + backbone (nn.Layer): The backbone transformer network. |
| 77 | + h_embed_dim (int): The embedding dim in mask head. |
| 78 | + h_depth (int): The num of layers in mask head. |
| 79 | + h_num_heads (int): The num of heads of MSA in mask head. |
| 80 | + h_mlp_ratio (int, optional): Ratio of MLP dim in mask head. Default: 4. |
| 81 | + h_drop_rate (float, optional): Drop rate of MLP in mask head. Default: 0.0. |
| 82 | + h_drop_path_rate (float, optional): Drop path rate in mask head. Default: 0.0. |
| 83 | + h_attn_drop_rate (float, optional): Attenation drop rate in mask head. Default: 0.0. |
| 84 | + h_qkv_bias (bool, optional): Whether add bias in mask head. Default: False. |
| 85 | + pretrained (str, optional): The path or url of pretrained model. Default: None. |
| 86 | + ''' |
| 87 | + |
| 88 | + def __init__(self, |
| 89 | + num_classes, |
| 90 | + backbone, |
| 91 | + h_embed_dim, |
| 92 | + h_depth, |
| 93 | + h_num_heads, |
| 94 | + h_mlp_ratio=4, |
| 95 | + h_drop_rate=0.0, |
| 96 | + h_drop_path_rate=0.0, |
| 97 | + h_attn_drop_rate=0.0, |
| 98 | + h_qkv_bias=False, |
| 99 | + pretrained=None): |
| 100 | + super().__init__() |
| 101 | + self.backbone = backbone |
| 102 | + self.head = SegmenterMaskHead( |
| 103 | + num_classes, backbone.embed_dim, h_embed_dim, h_depth, h_num_heads, |
| 104 | + h_mlp_ratio, h_drop_rate, h_drop_path_rate, h_attn_drop_rate, |
| 105 | + h_qkv_bias) |
| 106 | + self.pretrained = pretrained |
| 107 | + self.init_weight() |
| 108 | + |
| 109 | + def init_weight(self): |
| 110 | + if self.pretrained is not None: |
| 111 | + utils.load_entire_model(self, self.pretrained) |
| 112 | + |
| 113 | + def forward(self, x): |
| 114 | + x_shape = paddle.shape(x) |
| 115 | + |
| 116 | + feats, shape = self.backbone(x) |
| 117 | + logits = self.head(feats[-1], shape[2:]) |
| 118 | + |
| 119 | + logit_list = [ |
| 120 | + F.interpolate(logit, x_shape[2:], mode='bilinear') |
| 121 | + for logit in logits |
| 122 | + ] |
| 123 | + |
| 124 | + return logit_list |
| 125 | + |
| 126 | + |
| 127 | +class SegmenterLinearHead(nn.Layer): |
| 128 | + ''' |
| 129 | + The linear head of Segmenter. |
| 130 | + Args: |
| 131 | + num_classes (int): The unique number of target classes. |
| 132 | + in_dim (int): The embed dim of input. |
| 133 | + ''' |
| 134 | + |
| 135 | + def __init__(self, num_classes, in_dim): |
| 136 | + super().__init__() |
| 137 | + self.head = nn.Linear(in_dim, num_classes) |
| 138 | + self.apply(transformer_utils.init_weights) |
| 139 | + |
| 140 | + def forward(self, x, patch_embed_size): |
| 141 | + """ Forward function. |
| 142 | + Args: |
| 143 | + x (Tensor): Input tensor of decoder. |
| 144 | + patch_embed_size (Tensor): The height and width of the patch embed tensor. |
| 145 | + Returns: |
| 146 | + list[Tensor]: Segmentation results. |
| 147 | + """ |
| 148 | + masks = self.head(x) |
| 149 | + |
| 150 | + #[b, (h w), c] -> [b, c, h, w] |
| 151 | + h, w = patch_embed_size[0], patch_embed_size[1] |
| 152 | + masks = masks.reshape((0, h, w, paddle.shape(masks)[-1])) |
| 153 | + masks = masks.transpose((0, 3, 1, 2)) |
| 154 | + |
| 155 | + return [masks] |
| 156 | + |
| 157 | + |
| 158 | +class SegmenterMaskHead(nn.Layer): |
| 159 | + ''' |
| 160 | + The mask head of segmenter. |
| 161 | + Args: |
| 162 | + num_classes (int): The unique number of target classes. |
| 163 | + in_dim (int): The embed dim of input. |
| 164 | + embed_dim (int): Embedding dim of mask transformer. |
| 165 | + depth (int): The num of layers in Transformer. |
| 166 | + num_heads (int): The num of heads in MSA. |
| 167 | + mlp_ratio (int, optional): Ratio of MLP dim. Default: 4. |
| 168 | + drop_rate (float, optional): Drop rate of MLP in MSA. Default: 0.0. |
| 169 | + drop_path_rate (float, optional): Drop path rate in MSA. Default: 0.0. |
| 170 | + attn_drop_rate (float, optional): Attenation drop rate in MSA. Default: 0.0. |
| 171 | + qkv_bias (bool, optional): Whether add bias in qkv linear. Default: False. |
| 172 | + ''' |
| 173 | + |
| 174 | + def __init__(self, |
| 175 | + num_classes, |
| 176 | + in_dim, |
| 177 | + embed_dim, |
| 178 | + depth, |
| 179 | + num_heads, |
| 180 | + mlp_ratio=4, |
| 181 | + drop_rate=0.0, |
| 182 | + drop_path_rate=0.0, |
| 183 | + attn_drop_rate=0.0, |
| 184 | + qkv_bias=False): |
| 185 | + super().__init__() |
| 186 | + self.num_classes = num_classes |
| 187 | + |
| 188 | + self.proj_input = nn.Linear(in_dim, embed_dim) |
| 189 | + |
| 190 | + self.cls_token = self.create_parameter( |
| 191 | + shape=(1, num_classes, embed_dim), |
| 192 | + default_initializer=paddle.nn.initializer.TruncatedNormal(std=0.02)) |
| 193 | + |
| 194 | + dpr = [x for x in np.linspace(0, drop_path_rate, depth)] |
| 195 | + self.blocks = nn.LayerList([ |
| 196 | + vision_transformer.Block( |
| 197 | + dim=embed_dim, |
| 198 | + num_heads=num_heads, |
| 199 | + mlp_ratio=mlp_ratio, |
| 200 | + drop=drop_rate, |
| 201 | + drop_path=dpr[i], |
| 202 | + attn_drop=attn_drop_rate, |
| 203 | + qkv_bias=qkv_bias) for i in range(depth) |
| 204 | + ]) |
| 205 | + |
| 206 | + initializer = paddle.nn.initializer.TruncatedNormal(std=0.02) |
| 207 | + self.proj_patch = nn.Linear( |
| 208 | + embed_dim, |
| 209 | + embed_dim, |
| 210 | + weight_attr=paddle.ParamAttr(initializer=initializer), |
| 211 | + bias_attr=False) |
| 212 | + self.proj_class = nn.Linear( |
| 213 | + embed_dim, |
| 214 | + embed_dim, |
| 215 | + weight_attr=paddle.ParamAttr(initializer=initializer), |
| 216 | + bias_attr=False) |
| 217 | + |
| 218 | + self.decoder_norm = nn.LayerNorm(embed_dim) |
| 219 | + self.mask_norm = nn.LayerNorm(num_classes) |
| 220 | + |
| 221 | + self.apply(transformer_utils.init_weights) |
| 222 | + |
| 223 | + def forward(self, x, patch_embed_size): |
| 224 | + """ Forward function. |
| 225 | + Args: |
| 226 | + x (Tensor): Input tensor of decoder. |
| 227 | + patch_embed_size (Tensor): The height and width of the patch embed tensor. |
| 228 | + Returns: |
| 229 | + list[Tensor]: Segmentation results. |
| 230 | + """ |
| 231 | + x = self.proj_input(x) |
| 232 | + |
| 233 | + cls_token = self.cls_token.expand((paddle.shape(x)[0], -1, -1)) |
| 234 | + x = paddle.concat([x, cls_token], axis=1) |
| 235 | + |
| 236 | + for block in self.blocks: |
| 237 | + x = block(x) |
| 238 | + x = self.decoder_norm(x) |
| 239 | + |
| 240 | + patches, masks = x[:, :-self.num_classes], x[:, -self.num_classes:] |
| 241 | + patches = self.proj_patch(patches) |
| 242 | + masks = self.proj_class(masks) |
| 243 | + patches = patches / paddle.norm(patches, axis=-1, keepdim=True) |
| 244 | + masks = masks / paddle.norm(masks, axis=-1, keepdim=True) |
| 245 | + |
| 246 | + masks = patches @ masks.transpose((0, 2, 1)) |
| 247 | + masks = masks.reshape((0, 0, |
| 248 | + self.num_classes)) # For export inference model |
| 249 | + masks = self.mask_norm(masks) |
| 250 | + |
| 251 | + #[b, (h w), c] -> [b, c, h, w] |
| 252 | + h, w = patch_embed_size[0], patch_embed_size[1] |
| 253 | + masks = masks.reshape((0, h, w, paddle.shape(masks)[-1])) |
| 254 | + masks = masks.transpose((0, 3, 1, 2)) |
| 255 | + |
| 256 | + return [masks] |
0 commit comments