# model.py
# DeepLabV3+-style segmentation model builder (ResNet50 encoder, ASPP and
# squeeze-and-excite blocks) written with tf.keras.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.layers import AveragePooling2D, GlobalAveragePooling2D, UpSampling2D, Reshape, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.applications import ResNet50
import tensorflow as tf
def SqueezeAndExcite(inputs, ratio=8):
    """Re-weight the channels of a feature map with a squeeze-and-excite gate.

    The input is globally average-pooled, reshaped to ``(1, 1, filters)``,
    passed through two bias-free Dense layers (ReLU, then sigmoid), and the
    resulting per-channel gate is multiplied back onto the input.

    Args:
        inputs: Keras tensor; its last axis is taken as the channel axis
            (assumes channels-last data format -- confirm against the Keras
            backend configuration).
        ratio: Reduction factor for the hidden Dense layer. The hidden layer
            has ``filters // ratio`` units, but never fewer than one.

    Returns:
        The result of ``inputs * gate``, where ``gate`` is the sigmoid output
        reshaped to ``(1, 1, filters)`` per sample.

    Raises:
        ValueError: If the channel dimension of ``inputs`` is not statically
            known, since the Dense layer widths are derived from it.
    """
    filters = inputs.shape[-1]
    if filters is None:
        raise ValueError(
            "SqueezeAndExcite requires a statically known channel dimension."
        )

    # Clamp to one unit: with fewer than `ratio` channels the integer division
    # yields 0, which is not a valid Dense layer width.
    reduced_units = max(filters // ratio, 1)

    se = GlobalAveragePooling2D()(inputs)
    # Restore singleton spatial axes so the gate broadcasts over height/width.
    se = Reshape((1, 1, filters))(se)
    se = Dense(
        reduced_units,
        activation='relu',
        kernel_initializer='he_normal',
        use_bias=False,
    )(se)
    se = Dense(
        filters,
        activation='sigmoid',
        kernel_initializer='he_normal',
        use_bias=False,
    )(se)
    return inputs * se
def ASPP(inputs):
    """Build five parallel branches over ``inputs`` and fuse them to 256 channels.

    Branches, in order: a pooling branch (average pool over the full spatial
    extent, 1x1 conv, bilinear upsample back), a 1x1 conv, and three 3x3
    convs with dilation rates 6, 12 and 18. Every conv has 256 filters, no
    bias, and is followed by batch normalization and ReLU. The branches are
    concatenated and projected by one more 1x1 conv block.

    Args:
        inputs: Keras tensor. ``inputs.shape[1]`` and ``inputs.shape[2]`` are
            used directly as the pooling window and the upsampling factor.

    Returns:
        The fused Keras tensor produced by the final 1x1 conv block.
    """
    height, width = inputs.shape[1], inputs.shape[2]

    def conv_bn_relu(tensor, kernel_size, **conv_kwargs):
        # Shared Conv2D(256) -> BatchNormalization -> ReLU unit.
        tensor = Conv2D(
            256, kernel_size, padding="same", use_bias=False, **conv_kwargs
        )(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation("relu")(tensor)

    # Image pooling branch: pool over the whole map, then scale back up.
    pooled = AveragePooling2D(pool_size=(height, width))(inputs)
    pooled = conv_bn_relu(pooled, 1)
    pooled = UpSampling2D((height, width), interpolation="bilinear")(pooled)

    # 1x1 conv branch.
    pointwise = conv_bn_relu(inputs, 1)

    # 3x3 conv branches with dilation rates 6, 12 and 18.
    dilated = [
        conv_bn_relu(inputs, 3, dilation_rate=rate) for rate in (6, 12, 18)
    ]

    # Fuse every branch and project back down to 256 channels.
    fused = Concatenate()([pooled, pointwise] + dilated)
    return conv_bn_relu(fused, 1)
def deeplabv3_plus(shape):
    """Assemble the segmentation model on top of a ResNet50 encoder.

    The encoder is ResNet50 with ImageNet weights and no classification head.
    The output of layer ``conv4_block6_out`` goes through ASPP and a 4x
    bilinear upsample; the output of ``conv2_block2_out`` is reduced to 48
    channels. Both are concatenated, refined by squeeze-and-excite, two 3x3
    conv blocks and a second squeeze-and-excite, upsampled 4x again, and
    mapped to a single sigmoid-activated channel.

    Args:
        shape: Input shape passed to ``Input`` (for example ``(512, 512, 3)``).

    Returns:
        A Keras ``Model`` from the input tensor to the sigmoid output.
    """

    def conv_bn_relu(tensor, filters, kernel_size):
        # Bias-free conv followed by batch normalization and ReLU.
        tensor = Conv2D(
            filters=filters,
            kernel_size=kernel_size,
            padding='same',
            use_bias=False,
        )(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    # Input and encoder.
    inputs = Input(shape)
    backbone = ResNet50(
        weights="imagenet", include_top=False, input_tensor=inputs
    )

    # Deep features: ASPP, then 4x bilinear upsample.
    deep = ASPP(backbone.get_layer("conv4_block6_out").output)
    deep = UpSampling2D((4, 4), interpolation="bilinear")(deep)

    # Shallow features: reduce to 48 channels before merging.
    shallow = backbone.get_layer("conv2_block2_out").output
    shallow = conv_bn_relu(shallow, 48, 1)

    # Decoder: merge, gate, refine twice, gate again.
    decoded = Concatenate()([deep, shallow])
    decoded = SqueezeAndExcite(decoded)
    for _ in range(2):
        decoded = conv_bn_relu(decoded, 256, 3)
    decoded = SqueezeAndExcite(decoded)

    # Final 4x upsample and single-channel sigmoid head.
    decoded = UpSampling2D((4, 4), interpolation="bilinear")(decoded)
    logits = Conv2D(1, 1)(decoded)
    outputs = Activation("sigmoid")(logits)

    return Model(inputs, outputs)
if __name__ == "__main__":
    # Build the model once for a 512x512, 3-channel input when run as a script.
    input_shape = (512, 512, 3)
    model = deeplabv3_plus(input_shape)