Commit ab87623

Authored Oct 29, 2021
[Feature] Add segmenter (PaddlePaddle#1469)
1 parent 02688dd commit ab87623

10 files changed: +462 -40 lines

configs/segmenter/README.md (+16, new file)

@@ -0,0 +1,16 @@
# Segmenter: Transformer for Semantic Segmentation

## Reference

> Strudel, Robin, Ricardo Garcia, Ivan Laptev, and Cordelia Schmid. "Segmenter: Transformer for Semantic Segmentation." In Proceedings of the IEEE International Conference on Computer Vision, pp. 7262-7272. 2021.

## Performance

### ADE20K

| Model | Backbone | Head | Patch Size | Resolution | Training Iters | mIoU (slice) | mIoU (flip) | Links |
| :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: | :-: |
| Segmenter | ViT small | Linear | 16 | 512x512 | 160000 | 45.48 | 45.69 | [model](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_small_linear_ade20k_512x512_160k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_small_linear_ade20k_512x512_160k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=4dc954a9b774e4807c07c511c04ce0f6) |
| Segmenter | ViT small | Mask | 16 | 512x512 | 160000 | 45.15 | 45.41 | [model](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_small_mask_ade20k_512x512_160k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_small_mask_ade20k_512x512_160k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=0fdd5191ecec56bbdf08259cc6c32a21) |
| Segmenter | ViT base | Linear | 16 | 512x512 | 160000 | 48.13 | 48.31 | [model](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_base_linear_ade20k_512x512_160k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_base_linear_ade20k_512x512_160k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/index?id=992f38b3f937de87dc74a888d217f53e) |
| Segmenter | ViT base | Mask | 16 | 512x512 | 160000 | 48.49 | 48.61 | [model](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_base_mask_ade20k_512x512_160k/model.pdparams) \| [log](https://paddleseg.bj.bcebos.com/dygraph/ade20k/segmenter_vit_base_mask_ade20k_512x512_160k/train.log) \| [vdl](https://www.paddlepaddle.org.cn/paddle/visualdl/service/app/scalar?id=16a7380069b6435bdf6e566dcc7f4a6b) |
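
As a quick smoke test of the new configs, the sketch below builds a model straight from one of them. It is a minimal sketch assuming PaddleSeg's `Config` API and the repository root as the working directory; constructing `cfg.model` downloads the pretrained ViT backbone referenced in the YAML.

```python
import paddle
from paddleseg.cvlibs import Config

# Load the ViT-small + linear-head config added in this commit.
cfg = Config('configs/segmenter/segmenter_vit_small_linear_ade20k_512x512_160k.yml')
model = cfg.model  # LinearSegmenter; fetches the pretrained backbone on first use
model.eval()

x = paddle.randn([1, 3, 512, 512])
with paddle.no_grad():
    logits = model(x)  # a list of logits, upsampled to the input size
print(logits[0].shape)  # [1, 150, 512, 512] -- ADE20K has 150 classes
```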

configs/segmenter/segmenter_vit_base_linear_ade20k_512x512_160k.yml (+38, new file)

@@ -0,0 +1,38 @@
_base_: '../_base_/ade20k.yml'

batch_size: 2
iters: 160000

model:
  type: LinearSegmenter
  backbone:
    type: VisionTransformer
    img_size: 512
    patch_size: 16
    embed_dim: 768
    depth: 12
    num_heads: 12
    mlp_ratio: 4
    qkv_bias: True
    drop_rate: 0.0
    drop_path_rate: 0.1
    final_norm: True
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/pretrained_models/vit_base_patch16_384_augreg.tar.gz

val_dataset:
  transforms:
    - type: ResizeByShort
      short_size: 512
    - type: Normalize

optimizer:
  weight_decay: 0.0

lr_scheduler:
  learning_rate: 0.001
  end_lr: 1.0e-05

test_config:
  is_slide: True
  crop_size: [512, 512]
  stride: [512, 512]
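
The `test_config` block turns on sliding-window evaluation: 512x512 crops are taken at a 512-pixel stride and the per-crop logits are stitched back together. The helper below is only an illustrative sketch of that tiling arithmetic, not PaddleSeg's actual slide-inference implementation:

```python
import math

def slide_windows(h, w, crop=(512, 512), stride=(512, 512)):
    """Yield (y0, y1, x0, x1) crop boxes covering an h x w image."""
    rows = max(math.ceil((h - crop[0]) / stride[0]) + 1, 1)
    cols = max(math.ceil((w - crop[1]) / stride[1]) + 1, 1)
    for r in range(rows):
        for c in range(cols):
            y0 = min(r * stride[0], max(h - crop[0], 0))
            x0 = min(c * stride[1], max(w - crop[1], 0))
            yield (y0, min(y0 + crop[0], h), x0, min(x0 + crop[1], w))

# After ResizeByShort(512), a 512x683 ADE20K image needs two overlapping crops:
print(list(slide_windows(512, 683)))  # [(0, 512, 0, 512), (0, 512, 171, 683)]
```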

configs/segmenter/segmenter_vit_base_mask_ade20k_512x512_160k.yml (+10, new file)

@@ -0,0 +1,10 @@
_base_: './segmenter_vit_base_linear_ade20k_512x512_160k.yml'

model:
  type: MaskSegmenter
  h_embed_dim: 768
  h_depth: 2
  h_num_heads: 12
  h_mlp_ratio: 4
  h_drop_rate: 0.0
  h_drop_path_rate: 0.1

configs/segmenter/segmenter_vit_small_linear_ade20k_512x512_160k.yml (+17, new file)

@@ -0,0 +1,17 @@
_base_: './segmenter_vit_base_linear_ade20k_512x512_160k.yml'

model:
  type: LinearSegmenter
  backbone:
    type: VisionTransformer
    img_size: 512
    patch_size: 16
    embed_dim: 384
    depth: 12
    num_heads: 6
    mlp_ratio: 4
    qkv_bias: True
    drop_rate: 0.0
    drop_path_rate: 0.1
    final_norm: True
    pretrained: https://bj.bcebos.com/paddleseg/dygraph/pretrained_models/vit_small_patch16_384_augreg.tar.gz

configs/segmenter/segmenter_vit_small_mask_ade20k_512x512_160k.yml (+10, new file)

@@ -0,0 +1,10 @@
_base_: './segmenter_vit_small_linear_ade20k_512x512_160k.yml'

model:
  type: MaskSegmenter
  h_embed_dim: 384
  h_depth: 2
  h_num_heads: 6
  h_mlp_ratio: 4
  h_drop_rate: 0.0
  h_drop_path_rate: 0.1
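
All four configs lean on `_base_` inheritance: keys in a child file are merged recursively into its parent, so the mask variants only swap `model.type` and add the `h_*` head options while the backbone, dataset, and optimizer settings carry over. A toy sketch of that merge rule (illustrative; the real logic lives in PaddleSeg's config loader):

```python
def merge(base: dict, child: dict) -> dict:
    """Recursively overlay child keys onto a base config dict."""
    out = dict(base)
    for k, v in child.items():
        if isinstance(v, dict) and isinstance(out.get(k), dict):
            out[k] = merge(out[k], v)
        else:
            out[k] = v
    return out

base = {'model': {'type': 'LinearSegmenter', 'backbone': {'embed_dim': 384}}}
child = {'model': {'type': 'MaskSegmenter', 'h_embed_dim': 384, 'h_depth': 2}}
print(merge(base, child))
# {'model': {'type': 'MaskSegmenter', 'backbone': {'embed_dim': 384},
#            'h_embed_dim': 384, 'h_depth': 2}}
```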

paddleseg/models/__init__.py (+1)

@@ -44,4 +44,5 @@
 from .segformer import SegFormer
 from .pointrend import PointRend
 from .ginet import GINet
+from .segmenter import *
 from .segnet import SegNet

paddleseg/models/backbones/transformer_utils.py (+19 -1)

@@ -17,7 +17,8 @@
 import paddle.nn.initializer as paddle_init
 
 __all__ = [
-    'to_2tuple', 'DropPath', 'Identity', 'trunc_normal_', 'zeros_', 'ones_'
+    'to_2tuple', 'DropPath', 'Identity', 'trunc_normal_', 'zeros_', 'ones_',
+    'init_weights'
 ]
 
 
@@ -63,3 +64,20 @@ def forward(self, input):
 trunc_normal_ = paddle_init.TruncatedNormal(std=.02)
 zeros_ = paddle_init.Constant(value=0.)
 ones_ = paddle_init.Constant(value=1.)
+
+
+def init_weights(layer):
+    """
+    Init the weights of transformer.
+    Args:
+        layer (nn.Layer): The layer to init weights.
+    Returns:
+        None
+    """
+    if isinstance(layer, nn.Linear):
+        trunc_normal_(layer.weight)
+        if layer.bias is not None:
+            zeros_(layer.bias)
+    elif isinstance(layer, nn.LayerNorm):
+        zeros_(layer.bias)
+        ones_(layer.weight)
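
A usage sketch for the new helper: `init_weights` is meant to be handed to `paddle.nn.Layer.apply`, which calls it on every sublayer. The Segmenter heads added in `paddleseg/models/segmenter.py` below use it exactly this way.

```python
import paddle.nn as nn
from paddleseg.models.backbones.transformer_utils import init_weights

# Truncated-normal init for Linear weights, zeroed biases, and
# ones/zeros for LayerNorm scale/bias, applied recursively.
head = nn.Sequential(nn.Linear(384, 384), nn.LayerNorm(384))
head.apply(init_weights)
```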

paddleseg/models/backbones/vision_transformer.py (+65 -21)

@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
+import math
+
 import paddle
 import paddle.nn as nn
 import paddle.nn.functional as F
 import numpy as np
 
 from paddleseg.cvlibs import manager
-from paddleseg.utils import utils
-from paddleseg.models.backbones.transformer_utils import *
+from paddleseg.utils import utils, logger
+from paddleseg.models.backbones.transformer_utils import to_2tuple, DropPath, Identity
 
 
 class Mlp(nn.Layer):
@@ -145,6 +148,7 @@ def forward(self, x):
         return x
 
 
+@manager.BACKBONES.add_component
 class VisionTransformer(nn.Layer):
     """ Vision Transformer with support for patch input
     """
@@ -164,10 +168,11 @@ def __init__(self,
                  drop_path_rate=0.,
                  norm_layer='nn.LayerNorm',
                  epsilon=1e-5,
+                 final_norm=False,
                  pretrained=None,
                  **args):
         super().__init__()
-        self.depth = depth
+        self.img_size = img_size
         self.embed_dim = embed_dim
 
         self.patch_embed = PatchEmbed(
@@ -180,12 +185,10 @@ def __init__(self,
 
         self.pos_embed = self.create_parameter(
             shape=(1, self.pos_w * self.pos_h + 1, embed_dim),
-            default_initializer=paddle.nn.initializer.Constant(value=0.))
-        self.add_parameter("pos_embed", self.pos_embed)
+            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
         self.cls_token = self.create_parameter(
             shape=(1, 1, embed_dim),
             default_initializer=paddle.nn.initializer.Constant(value=0.))
-        self.add_parameter("cls_token", self.cls_token)
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = np.linspace(0, drop_path_rate, depth)
@@ -204,40 +207,81 @@ def __init__(self,
                 epsilon=epsilon) for i in range(depth)
         ])
 
-        self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
+        self.final_norm = final_norm
+        if self.final_norm:
+            self.norm = eval(norm_layer)(embed_dim, epsilon=epsilon)
         self.pretrained = pretrained
         self.init_weight()
 
     def init_weight(self):
         utils.load_pretrained_model(self, self.pretrained)
 
-    def forward_features(self, x):
-        x = self.patch_embed(x)
-        x_shape = paddle.shape(x)
-        pos_embed = self.pos_embed[:, 1:, :]
-        cls_pos_embed = self.pos_embed[:, :1, :]
-        cls_tokens = self.cls_token.expand((x_shape[0], -1, -1))
+        # load and resize pos_embed
+        model_path = self.pretrained
+        if not os.path.exists(model_path):
+            model_path = utils.download_pretrained_model(model_path)
+
+        load_state_dict = paddle.load(model_path)
+        model_state_dict = self.state_dict()
+        pos_embed_name = "pos_embed"
+        if pos_embed_name in load_state_dict.keys():
+            load_pos_embed = paddle.to_tensor(
+                load_state_dict[pos_embed_name], dtype="float32")
+            if self.pos_embed.shape != load_pos_embed.shape:
+                pos_size = int(math.sqrt(load_pos_embed.shape[1] - 1))
+                model_state_dict[pos_embed_name] = self.resize_pos_embed(
+                    load_pos_embed, (pos_size, pos_size),
+                    (self.pos_h, self.pos_w))
+                self.set_dict(model_state_dict)
+                logger.info(
+                    "Load pos_embed and resize it from {} to {} .".format(
+                        load_pos_embed.shape, self.pos_embed.shape))
+
+    def resize_pos_embed(self, pos_embed, old_hw, new_hw):
+        """
+        Resize pos_embed weight.
+        Args:
+            pos_embed (Tensor): the pos_embed weight
+            old_hw (list[int]): the height and width of old pos_embed
+            new_hw (list[int]): the height and width of new pos_embed
+        Returns:
+            Tensor: the resized pos_embed weight
+        """
+        cls_pos_embed = pos_embed[:, :1, :]
+        pos_embed = pos_embed[:, 1:, :]
 
         pos_embed = pos_embed.transpose([0, 2, 1])
-        pos_embed = pos_embed.reshape([1, -1, self.pos_h, self.pos_w])
+        pos_embed = pos_embed.reshape([1, -1, old_hw[0], old_hw[1]])
         pos_embed = F.interpolate(
-            pos_embed, x_shape[2:], mode='bilinear', align_corners=False)
-
+            pos_embed, new_hw, mode='bicubic', align_corners=False)
         pos_embed = pos_embed.flatten(2).transpose([0, 2, 1])
         pos_embed = paddle.concat([cls_pos_embed, pos_embed], axis=1)
-        x = x.flatten(2).transpose([0, 2, 1])
+
+        return pos_embed
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x_shape = paddle.shape(x)  # b * c * h * w
+
+        cls_tokens = self.cls_token.expand((x_shape[0], -1, -1))
+        x = x.flatten(2).transpose([0, 2, 1])  # b * hw * c
         x = paddle.concat([cls_tokens, x], axis=1)
-        x = x + pos_embed
 
+        if paddle.shape(x)[1] == self.pos_embed.shape[1]:
+            x = x + self.pos_embed
+        else:
+            x = x + self.resize_pos_embed(self.pos_embed,
+                                          (self.pos_h, self.pos_w), x_shape[2:])
         x = self.pos_drop(x)
+
         res = []
         for idx, blk in enumerate(self.blocks):
             x = blk(x)
+            if self.final_norm and idx == len(self.blocks) - 1:
+                x = self.norm(x)
             res.append(x[:, 1:, :])
-        return res, x_shape
 
-    def forward(self, x):
-        return self.forward_features(x)
+        return res, x_shape
 
 
 @manager.BACKBONES.add_component
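
A standalone sketch of what `resize_pos_embed` computes, with illustrative shapes for a ViT-small checkpoint pretrained at 384x384 (a 24x24 patch grid) adapted to 512x512 inputs (a 32x32 grid): split off the class-token embedding, bicubically resize the grid part in 2-D, then flatten and re-concatenate.

```python
import math
import paddle
import paddle.nn.functional as F

pos_embed = paddle.randn([1, 1 + 24 * 24, 384])  # [1, 577, 384] checkpoint shape
old = int(math.sqrt(pos_embed.shape[1] - 1))     # 24
new_h, new_w = 32, 32                            # 512 / patch_size 16

cls_tok, grid = pos_embed[:, :1], pos_embed[:, 1:]
grid = grid.transpose([0, 2, 1]).reshape([1, -1, old, old])
grid = F.interpolate(grid, (new_h, new_w), mode='bicubic', align_corners=False)
grid = grid.flatten(2).transpose([0, 2, 1])
resized = paddle.concat([cls_tok, grid], axis=1)
print(resized.shape)  # [1, 1025, 384]
```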

paddleseg/models/segmenter.py (+256, new file)

@@ -0,0 +1,256 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

from paddleseg.utils import utils
from paddleseg.cvlibs import manager, param_init
from paddleseg.models.backbones import vision_transformer, transformer_utils

__all__ = ['LinearSegmenter', 'MaskSegmenter']


@manager.MODELS.add_component
class LinearSegmenter(nn.Layer):
    '''
    The implementation of Segmenter with a linear head based on PaddlePaddle.

    The original article refers to Strudel, Robin, et al. "Segmenter: Transformer
    for Semantic Segmentation." arXiv preprint arXiv:2105.05633 (2021).

    Args:
        num_classes (int): The unique number of target classes.
        backbone (nn.Layer): The backbone transformer network.
        pretrained (str, optional): The path or url of the pretrained model. Default: None.
    '''

    def __init__(self, num_classes, backbone, pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.head = SegmenterLinearHead(num_classes, backbone.embed_dim)
        self.pretrained = pretrained
        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def forward(self, x):
        x_shape = paddle.shape(x)

        feats, shape = self.backbone(x)
        logits = self.head(feats[-1], shape[2:])

        logit_list = [
            F.interpolate(logit, x_shape[2:], mode='bilinear')
            for logit in logits
        ]

        return logit_list


@manager.MODELS.add_component
class MaskSegmenter(nn.Layer):
    '''
    The implementation of Segmenter with a mask head based on PaddlePaddle.

    The original article refers to Strudel, Robin, et al. "Segmenter: Transformer
    for Semantic Segmentation." arXiv preprint arXiv:2105.05633 (2021).

    Args:
        num_classes (int): The unique number of target classes.
        backbone (nn.Layer): The backbone transformer network.
        h_embed_dim (int): The embedding dim in the mask head.
        h_depth (int): The num of layers in the mask head.
        h_num_heads (int): The num of heads of MSA in the mask head.
        h_mlp_ratio (int, optional): Ratio of the MLP dim in the mask head. Default: 4.
        h_drop_rate (float, optional): Drop rate of the MLP in the mask head. Default: 0.0.
        h_drop_path_rate (float, optional): Drop path rate in the mask head. Default: 0.0.
        h_attn_drop_rate (float, optional): Attention drop rate in the mask head. Default: 0.0.
        h_qkv_bias (bool, optional): Whether to add bias in the qkv linear layers of the mask head. Default: False.
        pretrained (str, optional): The path or url of the pretrained model. Default: None.
    '''

    def __init__(self,
                 num_classes,
                 backbone,
                 h_embed_dim,
                 h_depth,
                 h_num_heads,
                 h_mlp_ratio=4,
                 h_drop_rate=0.0,
                 h_drop_path_rate=0.0,
                 h_attn_drop_rate=0.0,
                 h_qkv_bias=False,
                 pretrained=None):
        super().__init__()
        self.backbone = backbone
        self.head = SegmenterMaskHead(
            num_classes, backbone.embed_dim, h_embed_dim, h_depth, h_num_heads,
            h_mlp_ratio, h_drop_rate, h_drop_path_rate, h_attn_drop_rate,
            h_qkv_bias)
        self.pretrained = pretrained
        self.init_weight()

    def init_weight(self):
        if self.pretrained is not None:
            utils.load_entire_model(self, self.pretrained)

    def forward(self, x):
        x_shape = paddle.shape(x)

        feats, shape = self.backbone(x)
        logits = self.head(feats[-1], shape[2:])

        logit_list = [
            F.interpolate(logit, x_shape[2:], mode='bilinear')
            for logit in logits
        ]

        return logit_list


class SegmenterLinearHead(nn.Layer):
    '''
    The linear head of Segmenter.
    Args:
        num_classes (int): The unique number of target classes.
        in_dim (int): The embed dim of the input.
    '''

    def __init__(self, num_classes, in_dim):
        super().__init__()
        self.head = nn.Linear(in_dim, num_classes)
        self.apply(transformer_utils.init_weights)

    def forward(self, x, patch_embed_size):
        """ Forward function.
        Args:
            x (Tensor): Input tensor of the decoder.
            patch_embed_size (Tensor): The height and width of the patch embed tensor.
        Returns:
            list[Tensor]: Segmentation results.
        """
        masks = self.head(x)

        # [b, (h w), c] -> [b, c, h, w]
        h, w = patch_embed_size[0], patch_embed_size[1]
        masks = masks.reshape((0, h, w, paddle.shape(masks)[-1]))
        masks = masks.transpose((0, 3, 1, 2))

        return [masks]


class SegmenterMaskHead(nn.Layer):
    '''
    The mask head of Segmenter.
    Args:
        num_classes (int): The unique number of target classes.
        in_dim (int): The embed dim of the input.
        embed_dim (int): Embedding dim of the mask transformer.
        depth (int): The num of layers in the mask transformer.
        num_heads (int): The num of heads in MSA.
        mlp_ratio (int, optional): Ratio of the MLP dim. Default: 4.
        drop_rate (float, optional): Drop rate of the MLP in MSA. Default: 0.0.
        drop_path_rate (float, optional): Drop path rate in MSA. Default: 0.0.
        attn_drop_rate (float, optional): Attention drop rate in MSA. Default: 0.0.
        qkv_bias (bool, optional): Whether to add bias in the qkv linear layers. Default: False.
    '''

    def __init__(self,
                 num_classes,
                 in_dim,
                 embed_dim,
                 depth,
                 num_heads,
                 mlp_ratio=4,
                 drop_rate=0.0,
                 drop_path_rate=0.0,
                 attn_drop_rate=0.0,
                 qkv_bias=False):
        super().__init__()
        self.num_classes = num_classes

        self.proj_input = nn.Linear(in_dim, embed_dim)

        self.cls_token = self.create_parameter(
            shape=(1, num_classes, embed_dim),
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))

        dpr = [x for x in np.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.LayerList([
            vision_transformer.Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[i],
                attn_drop=attn_drop_rate,
                qkv_bias=qkv_bias) for i in range(depth)
        ])

        initializer = paddle.nn.initializer.TruncatedNormal(std=0.02)
        self.proj_patch = nn.Linear(
            embed_dim,
            embed_dim,
            weight_attr=paddle.ParamAttr(initializer=initializer),
            bias_attr=False)
        self.proj_class = nn.Linear(
            embed_dim,
            embed_dim,
            weight_attr=paddle.ParamAttr(initializer=initializer),
            bias_attr=False)

        self.decoder_norm = nn.LayerNorm(embed_dim)
        self.mask_norm = nn.LayerNorm(num_classes)

        self.apply(transformer_utils.init_weights)

    def forward(self, x, patch_embed_size):
        """ Forward function.
        Args:
            x (Tensor): Input tensor of the decoder.
            patch_embed_size (Tensor): The height and width of the patch embed tensor.
        Returns:
            list[Tensor]: Segmentation results.
        """
        x = self.proj_input(x)

        cls_token = self.cls_token.expand((paddle.shape(x)[0], -1, -1))
        x = paddle.concat([x, cls_token], axis=1)

        for block in self.blocks:
            x = block(x)
        x = self.decoder_norm(x)

        patches, masks = x[:, :-self.num_classes], x[:, -self.num_classes:]
        patches = self.proj_patch(patches)
        masks = self.proj_class(masks)
        patches = patches / paddle.norm(patches, axis=-1, keepdim=True)
        masks = masks / paddle.norm(masks, axis=-1, keepdim=True)

        masks = patches @ masks.transpose((0, 2, 1))
        masks = masks.reshape(
            (0, 0, self.num_classes))  # For export inference model
        masks = self.mask_norm(masks)

        # [b, (h w), c] -> [b, c, h, w]
        h, w = patch_embed_size[0], patch_embed_size[1]
        masks = masks.reshape((0, h, w, paddle.shape(masks)[-1]))
        masks = masks.transpose((0, 3, 1, 2))

        return [masks]
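
The decisive step in `SegmenterMaskHead.forward` is the similarity product: after L2-normalization, `patches @ masks.transpose(...)` holds the cosine similarity between every patch token and every class embedding, which `mask_norm` then normalizes across classes. A toy sketch with illustrative shapes:

```python
import paddle

patches = paddle.randn([1, 4, 8])  # 4 projected patch tokens, embed dim 8
classes = paddle.randn([1, 3, 8])  # 3 projected class tokens

patches = patches / paddle.norm(patches, axis=-1, keepdim=True)
classes = classes / paddle.norm(classes, axis=-1, keepdim=True)

masks = patches @ classes.transpose((0, 2, 1))
print(masks.shape)  # [1, 4, 3]: each patch scores every class
```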

paddleseg/utils/utils.py (+30 -18)

@@ -41,29 +41,41 @@ def load_entire_model(model, pretrained):
         logger.warning('Not all pretrained params of {} are loaded, ' \
                        'training from scratch or a pretrained backbone.'.format(model.__class__.__name__))
 
+def download_pretrained_model(pretrained_model):
+    """
+    Download pretrained model from url.
+    Args:
+        pretrained_model (str): the url of pretrained weight
+    Returns:
+        str: the path of pretrained weight
+    """
+    assert urlparse(pretrained_model).netloc, "The url is not valid."
+
+    pretrained_model = unquote(pretrained_model)
+    savename = pretrained_model.split('/')[-1]
+    if not savename.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
+        savename = pretrained_model.split('/')[-2]
+    else:
+        savename = savename.split('.')[0]
+
+    with generate_tempdir() as _dir:
+        with filelock.FileLock(
+                os.path.join(seg_env.TMP_HOME, savename)):
+            pretrained_model = download_file_and_uncompress(
+                pretrained_model,
+                savepath=_dir,
+                extrapath=seg_env.PRETRAINED_MODEL_HOME,
+                extraname=savename)
+            pretrained_model = os.path.join(pretrained_model,
+                                            'model.pdparams')
+    return pretrained_model
 
 def load_pretrained_model(model, pretrained_model):
     if pretrained_model is not None:
         logger.info('Loading pretrained model from {}'.format(pretrained_model))
-        # download pretrained model from url
+
         if urlparse(pretrained_model).netloc:
-            pretrained_model = unquote(pretrained_model)
-            savename = pretrained_model.split('/')[-1]
-            if not savename.endswith(('tgz', 'tar.gz', 'tar', 'zip')):
-                savename = pretrained_model.split('/')[-2]
-            else:
-                savename = savename.split('.')[0]
-            with generate_tempdir() as _dir:
-                with filelock.FileLock(
-                        os.path.join(seg_env.TMP_HOME, savename)):
-                    pretrained_model = download_file_and_uncompress(
-                        pretrained_model,
-                        savepath=_dir,
-                        extrapath=seg_env.PRETRAINED_MODEL_HOME,
-                        extraname=savename)
-                    pretrained_model = os.path.join(pretrained_model,
-                                                    'model.pdparams')
+            pretrained_model = download_pretrained_model(pretrained_model)
 
         if os.path.exists(pretrained_model):
             para_state_dict = paddle.load(pretrained_model)
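
A usage sketch for the factored-out helper (the returned path depends on `seg_env.PRETRAINED_MODEL_HOME`; the URL is the ViT-small weight referenced by the new configs):

```python
from paddleseg.utils import utils

url = ('https://bj.bcebos.com/paddleseg/dygraph/pretrained_models/'
       'vit_small_patch16_384_augreg.tar.gz')

# Downloads and uncompresses the archive, then returns the path
# to the extracted model.pdparams file.
local_path = utils.download_pretrained_model(url)
print(local_path)
```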
