-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathMnasnetEager.py
194 lines (139 loc) · 6.69 KB
/
MnasnetEager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, regularizers, activations
class Mnasnet(tf.keras.Model):
    """MnasNet-style image classifier built from Conv_BN and MBConv_idskip blocks.

    Layout: 3x3 stride-2 stem conv -> depthwise 3x3 + BN + ReLU6 -> 1x1
    projection (16 channels) -> 16 MBConv blocks (see ``_BLOCK_SPECS``) ->
    1x1 conv (1152 channels) -> global average pooling -> Dense layer.

    Args:
        num_classes: number of units in the final Dense layer. No activation
            is applied to it, so the model returns unnormalized scores.
        alpha: width multiplier applied to the channel counts. The stem,
            stem-projection and head channel counts are rounded with
            ``_make_divisible`` so fractional values still give integer,
            multiple-of-8 filter counts (the MBConv blocks round the same way).
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    # One row per stage:
    # (output filters, depthwise kernel size, stride of the stage's first block,
    #  expansion factor, number of blocks). Later blocks in a stage use stride 1.
    _BLOCK_SPECS = (
        (24, 3, 2, 3, 3),   # MBConv3 3x3
        (40, 5, 2, 3, 3),   # MBConv3 5x5
        (80, 5, 2, 6, 3),   # MBConv6 5x5
        (96, 3, 1, 6, 2),   # MBConv6 3x3
        (192, 5, 2, 6, 4),  # MBConv6 5x5
        (320, 3, 1, 6, 1),  # MBConv6 3x3
    )

    def __init__(self, num_classes, alpha=1, **kwargs):
        super(Mnasnet, self).__init__(**kwargs)
        self.blocks = []
        self.conv_bn_initial = Conv_BN(filters=_make_divisible(32 * alpha), kernel_size=3, strides=2)
        # First block (no identity shortcut): depthwise conv + BN + ReLU6, then a 1x1 projection.
        self.conv1_block1 = depthwiseConv(depth_multiplier=1, kernel_size=3, strides=1)
        self.bn1_block1 = layers.BatchNormalization(epsilon=1e-3, momentum=0.999)
        self.relu1_block1 = layers.ReLU(max_value=6)
        self.conv_bn_block_1 = Conv_BN(filters=_make_divisible(16 * alpha), kernel_size=1, strides=1)
        # Stacked MBConv blocks, created in spec order (same order as before, which
        # keeps the auto-generated layer names stable).
        in_filters = 16
        for out_filters, kernel_size, first_stride, expansion, repeats in self._BLOCK_SPECS:
            for i in range(repeats):
                self.blocks.append(MBConv_idskip(
                    input_filters=in_filters * alpha,
                    filters=out_filters,
                    kernel_size=kernel_size,
                    strides=first_stride if i == 0 else 1,
                    filters_multiplier=expansion,
                    alpha=alpha,
                ))
                in_filters = out_filters
        # Last 1x1 convolution
        self.conv_bn_last = Conv_BN(filters=_make_divisible(1152 * alpha), kernel_size=1, strides=1)
        # Pool + FC
        self.avg_pool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None, mask=None):
        """Run the forward pass.

        Args:
            inputs: 4-D image batch accepted by Conv2D (channel layout follows
                the Keras default data format).
            training: forwarded to every sublayer that contains BatchNormalization.
            mask: unused; kept for tf.keras.Model.call compatibility.

        Returns:
            The output of the final Dense layer (one score per class).
        """
        out = self.conv_bn_initial(inputs, training=training)
        out = self.conv1_block1(out)
        out = self.bn1_block1(out, training=training)
        out = self.relu1_block1(out)
        # NOTE(review): unlike the MBConv projections (called with activation=False),
        # this projection keeps its ReLU6; confirm that is intended.
        out = self.conv_bn_block_1(out, training=training)
        # Forward pass through all the MBConv blocks.
        for block in self.blocks:
            out = block(out, training=training)
        out = self.conv_bn_last(out, training=training)
        out = self.avg_pool(out)
        # Several outputs (e.g. intermediate feature maps) could be returned here instead.
        return self.fc(out)
class MBConv_idskip(tf.keras.Model):
    """Inverted-residual (MBConv) block with an optional identity shortcut.

    Structure: 1x1 expansion conv + BN + ReLU6 -> depthwise conv + BN + ReLU6
    -> 1x1 projection conv + BN (no activation). The input is added to the
    output when ``strides == 1`` and the channel counts match.

    Args:
        input_filters: nominal input channel count (in this file the caller passes
            it already multiplied by alpha). It is rounded with ``_make_divisible``
            and multiplied by ``filters_multiplier`` to size the expansion conv.
        filters: nominal output channel count before width scaling.
        kernel_size: depthwise kernel size.
        strides: depthwise stride.
        filters_multiplier: expansion factor.
        alpha: width multiplier applied to ``filters``.
    """

    def __init__(self, input_filters, filters, kernel_size, strides=1, filters_multiplier=1, alpha=1):
        super(MBConv_idskip, self).__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.filters_multiplier = filters_multiplier
        self.alpha = alpha
        self.depthwise_conv_filters = _make_divisible(input_filters)
        self.pointwise_conv_filters = _make_divisible(filters * alpha)
        # 1x1 expansion conv (+ BN + ReLU6)
        self.conv_bn1 = Conv_BN(filters=self.depthwise_conv_filters * filters_multiplier, kernel_size=1, strides=1)
        # Depthwise conv + BN + ReLU6
        self.depthwise_conv = depthwiseConv(depth_multiplier=1, kernel_size=kernel_size, strides=strides)
        self.bn = layers.BatchNormalization(epsilon=1e-3, momentum=0.999)
        self.relu = layers.ReLU(max_value=6)
        # 1x1 projection conv (+ BN; its activation is disabled in call)
        self.conv_bn2 = Conv_BN(filters=self.pointwise_conv_filters, kernel_size=1, strides=1)

    def call(self, inputs, training=None):
        """Apply the block; ``training`` is forwarded to the BatchNormalization layers."""
        x = self.conv_bn1(inputs, training=training)
        x = self.depthwise_conv(x)
        x = self.bn(x, training=training)
        x = self.relu(x)
        # Linear bottleneck: no activation after the projection.
        x = self.conv_bn2(x, training=training, activation=False)
        # Identity shortcut only when spatial size and channel count are unchanged.
        if self.strides == 1 and x.shape[-1] == inputs.shape[-1]:
            # Plain tensor addition: layers.add() would instantiate a new Add layer
            # on every forward pass.
            return inputs + x
        return x
class Conv_BN(tf.keras.Model):
    """Convolution followed by batch normalization and an optional ReLU6.

    Args:
        filters: number of output channels of the convolution.
        kernel_size: convolution kernel size.
        strides: convolution stride.
    """

    def __init__(self, filters, kernel_size, strides=1):
        super().__init__()
        # Remember the constructor arguments for later inspection.
        self.filters, self.kernel_size, self.strides = filters, kernel_size, strides
        # Sublayers are created in a fixed order: conv, batch norm, ReLU6.
        self.conv = conv(filters=filters, kernel_size=kernel_size, strides=strides)
        self.bn = layers.BatchNormalization(epsilon=1e-3, momentum=0.999)
        self.relu = layers.ReLU(max_value=6)

    def call(self, inputs, training=None, activation=True):
        """Apply conv -> BN and, when ``activation`` is true, ReLU6.

        ``training`` is forwarded to the BatchNormalization layer only.
        """
        normalized = self.bn(self.conv(inputs), training=training)
        return self.relu(normalized) if activation else normalized
# Convolution factory
def conv(filters, kernel_size, strides=1, dilation_rate=1, use_bias=False, weight_decay=0.0003):
    """Build a 'same'-padded Conv2D layer with L2 kernel regularization.

    Args:
        filters: number of output channels.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        dilation_rate: convolution dilation rate.
        use_bias: whether the convolution adds a bias term.
        weight_decay: L2 regularization factor for the kernel (defaults to the
            previously hard-coded 0.0003).

    Returns:
        A configured ``layers.Conv2D`` instance.
    """
    # The L2 factor is passed positionally: the old `l=` keyword is a deprecated
    # alias that Keras 3's L2 regularizer no longer accepts.
    return layers.Conv2D(filters, kernel_size, strides=strides, padding='same', use_bias=use_bias,
                         kernel_regularizer=regularizers.l2(weight_decay), dilation_rate=dilation_rate)
# Depthwise convolution factory
def depthwiseConv(kernel_size, strides=1, depth_multiplier=1, dilation_rate=1, use_bias=False,
                  weight_decay=0.0003):
    """Build a 'same'-padded DepthwiseConv2D layer with L2 kernel regularization.

    Args:
        kernel_size: depthwise kernel size.
        strides: convolution stride.
        depth_multiplier: output channels per input channel.
        dilation_rate: convolution dilation rate.
        use_bias: whether the convolution adds a bias term.
        weight_decay: L2 regularization factor for the depthwise kernel (defaults
            to the previously hard-coded 0.0003).

    Returns:
        A configured ``layers.DepthwiseConv2D`` instance.
    """
    # The L2 factor is passed positionally: the old `l=` keyword is a deprecated
    # alias that Keras 3's L2 regularizer no longer accepts.
    return layers.DepthwiseConv2D(kernel_size, strides=strides, depth_multiplier=depth_multiplier,
                                  padding='same', use_bias=use_bias,
                                  kernel_regularizer=regularizers.l2(weight_decay),
                                  dilation_rate=dilation_rate)
# This function is taken from the original tf repo.
# It ensures that all layers have a channel number that is divisible by 8
# It can be seen here:
# https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
def _make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v