1
+ import numpy as np
2
+ from keras .layers import Dense , Flatten , Dropout , GlobalAveragePooling2D , Input , Conv2D
3
+ from keras .layers import Activation , Multiply , Lambda , AveragePooling2D , MaxPooling2D , BatchNormalization
4
+ from keras .models import Model
5
+ from keras .utils import plot_model
6
+ from keras import backend as K
7
+
8
+ class AgenderSSRNet (Model ):
9
+ def __init__ (self , image_size ,stage_num ,lambda_local ,lambda_d ):
10
+ self .input_size = image_size
11
+ if K .image_dim_ordering () == "th" :
12
+ self .__channel_axis = 1
13
+ self .__input_shape = (3 , image_size , image_size )
14
+ else :
15
+ self .__channel_axis = - 1
16
+ self .__input_shape = (image_size , image_size , 3 )
17
+
18
+ self .__stage_num = stage_num
19
+ self .__lambda_local = lambda_local
20
+ self .__lambda_d = lambda_d
21
+
22
+ self .__x_layer1 = None
23
+ self .__x_layer2 = None
24
+ self .__x_layer3 = None
25
+ self .__x = None
26
+
27
+ self .__s_layer1 = None
28
+ self .__s_layer2 = None
29
+ self .__s_layer3 = None
30
+ self .__s = None
31
+
32
+ inputs = Input (shape = self .__input_shape )
33
+ self .__extraction_block (inputs )
34
+
35
+ pred_gender = self .__classifier_block (1 , 'gender' )
36
+ pred_age = self .__classifier_block (101 , 'age' )
37
+
38
+ super ().__init__ (inputs = inputs , outputs = [pred_gender , pred_age ], name = 'SSR_Net' )
39
+
40
+ def __extraction_block (self , inputs ):
41
+ x = Conv2D (32 ,(3 ,3 ))(inputs )
42
+ x = BatchNormalization (axis = self .__channel_axis )(x )
43
+ x = Activation ('relu' )(x )
44
+ self .__x_layer1 = AveragePooling2D (2 ,2 )(x )
45
+ x = Conv2D (32 ,(3 ,3 ))(self .__x_layer1 )
46
+ x = BatchNormalization (axis = self .__channel_axis )(x )
47
+ x = Activation ('relu' )(x )
48
+ self .__x_layer2 = AveragePooling2D (2 ,2 )(x )
49
+ x = Conv2D (32 ,(3 ,3 ))(self .__x_layer2 )
50
+ x = BatchNormalization (axis = self .__channel_axis )(x )
51
+ x = Activation ('relu' )(x )
52
+ self .__x_layer3 = AveragePooling2D (2 ,2 )(x )
53
+ x = Conv2D (32 ,(3 ,3 ))(self .__x_layer3 )
54
+ x = BatchNormalization (axis = self .__channel_axis )(x )
55
+ self .__x = Activation ('relu' )(x )
56
+ #-------------------------------------------------------------------------------------------------------------------------
57
+ s = Conv2D (16 ,(3 ,3 ))(inputs )
58
+ s = BatchNormalization (axis = self .__channel_axis )(s )
59
+ s = Activation ('tanh' )(s )
60
+ self .__s_layer1 = MaxPooling2D (2 ,2 )(s )
61
+ s = Conv2D (16 ,(3 ,3 ))(self .__s_layer1 )
62
+ s = BatchNormalization (axis = self .__channel_axis )(s )
63
+ s = Activation ('tanh' )(s )
64
+ self .__s_layer2 = MaxPooling2D (2 ,2 )(s )
65
+ s = Conv2D (16 ,(3 ,3 ))(self .__s_layer2 )
66
+ s = BatchNormalization (axis = self .__channel_axis )(s )
67
+ s = Activation ('tanh' )(s )
68
+ self .__s_layer3 = MaxPooling2D (2 ,2 )(s )
69
+ s = Conv2D (16 ,(3 ,3 ))(self .__s_layer3 )
70
+ s = BatchNormalization (axis = self .__channel_axis )(s )
71
+ self .__s = Activation ('tanh' )(s )
72
+
73
+ def __classifier_block (self , V , name ):
74
+ s_layer4 = Conv2D (10 ,(1 ,1 ),activation = 'relu' )(self .__s )
75
+ s_layer4 = Flatten ()(s_layer4 )
76
+ s_layer4_mix = Dropout (0.2 )(s_layer4 )
77
+ s_layer4_mix = Dense (units = self .__stage_num [0 ], activation = "relu" )(s_layer4_mix )
78
+
79
+ x_layer4 = Conv2D (10 ,(1 ,1 ),activation = 'relu' )(self .__x )
80
+ x_layer4 = Flatten ()(x_layer4 )
81
+ x_layer4_mix = Dropout (0.2 )(x_layer4 )
82
+ x_layer4_mix = Dense (units = self .__stage_num [0 ], activation = "relu" )(x_layer4_mix )
83
+
84
+ feat_s1_pre = Multiply ()([s_layer4 ,x_layer4 ])
85
+ delta_s1 = Dense (1 ,activation = 'tanh' ,name = name + '_delta_s1' )(feat_s1_pre )
86
+
87
+ feat_s1 = Multiply ()([s_layer4_mix ,x_layer4_mix ])
88
+ feat_s1 = Dense (2 * self .__stage_num [0 ],activation = 'relu' )(feat_s1 )
89
+ pred_s1 = Dense (units = self .__stage_num [0 ], activation = "relu" ,name = name + '_pred_stage1' )(feat_s1 )
90
+ local_s1 = Dense (units = self .__stage_num [0 ], activation = 'tanh' , name = name + '_local_delta_stage1' )(feat_s1 )
91
+ #-------------------------------------------------------------------------------------------------------------------------
92
+ s_layer2 = Conv2D (10 ,(1 ,1 ),activation = 'relu' )(self .__s_layer2 )
93
+ s_layer2 = MaxPooling2D (4 ,4 )(s_layer2 )
94
+ s_layer2 = Flatten ()(s_layer2 )
95
+ s_layer2_mix = Dropout (0.2 )(s_layer2 )
96
+ s_layer2_mix = Dense (self .__stage_num [1 ],activation = 'relu' )(s_layer2_mix )
97
+
98
+ x_layer2 = Conv2D (10 ,(1 ,1 ),activation = 'relu' )(self .__x_layer2 )
99
+ x_layer2 = AveragePooling2D (4 ,4 )(x_layer2 )
100
+ x_layer2 = Flatten ()(x_layer2 )
101
+ x_layer2_mix = Dropout (0.2 )(x_layer2 )
102
+ x_layer2_mix = Dense (self .__stage_num [1 ],activation = 'relu' )(x_layer2_mix )
103
+
104
+ feat_s2_pre = Multiply ()([s_layer2 ,x_layer2 ])
105
+ delta_s2 = Dense (1 ,activation = 'tanh' ,name = name + '_delta_s2' )(feat_s2_pre )
106
+
107
+ feat_s2 = Multiply ()([s_layer2_mix ,x_layer2_mix ])
108
+ feat_s2 = Dense (2 * self .__stage_num [1 ],activation = 'relu' )(feat_s2 )
109
+ pred_s2 = Dense (units = self .__stage_num [1 ], activation = "relu" ,name = name + '_pred_stage2' )(feat_s2 )
110
+ local_s2 = Dense (units = self .__stage_num [1 ], activation = 'tanh' , name = name + '_local_delta_stage2' )(feat_s2 )
111
+ #-------------------------------------------------------------------------------------------------------------------------
112
+ s_layer1 = Conv2D (10 ,(1 ,1 ),activation = 'relu' )(self .__s_layer1 )
113
+ s_layer1 = MaxPooling2D (8 ,8 )(s_layer1 )
114
+ s_layer1 = Flatten ()(s_layer1 )
115
+ s_layer1_mix = Dropout (0.2 )(s_layer1 )
116
+ s_layer1_mix = Dense (self .__stage_num [2 ],activation = 'relu' )(s_layer1_mix )
117
+
118
+ x_layer1 = Conv2D (10 ,(1 ,1 ),activation = 'relu' )(self .__x_layer1 )
119
+ x_layer1 = AveragePooling2D (8 ,8 )(x_layer1 )
120
+ x_layer1 = Flatten ()(x_layer1 )
121
+ x_layer1_mix = Dropout (0.2 )(x_layer1 )
122
+ x_layer1_mix = Dense (self .__stage_num [2 ],activation = 'relu' )(x_layer1_mix )
123
+
124
+ feat_s3_pre = Multiply ()([s_layer1 ,x_layer1 ])
125
+ delta_s3 = Dense (1 ,activation = 'tanh' ,name = name + '_delta_s3' )(feat_s3_pre )
126
+
127
+ feat_s3 = Multiply ()([s_layer1_mix ,x_layer1_mix ])
128
+ feat_s3 = Dense (2 * self .__stage_num [2 ],activation = 'relu' )(feat_s3 )
129
+ pred_s3 = Dense (units = self .__stage_num [2 ], activation = "relu" ,name = name + '_pred_stage3' )(feat_s3 )
130
+ local_s3 = Dense (units = self .__stage_num [2 ], activation = 'tanh' , name = name + '_local_delta_stage3' )(feat_s3 )
131
+ #-------------------------------------------------------------------------------------------------------------------------
132
+
133
+ def SSR_module (x ,s1 ,s2 ,s3 ,lambda_local ,lambda_d , V ):
134
+ a = x [0 ][:,0 ]* 0
135
+ b = x [0 ][:,0 ]* 0
136
+ c = x [0 ][:,0 ]* 0
137
+
138
+ for i in range (0 ,s1 ):
139
+ a = a + (i + lambda_local * x [6 ][:,i ])* x [0 ][:,i ]
140
+ a = K .expand_dims (a ,- 1 )
141
+ a = a / (s1 * (1 + lambda_d * x [3 ]))
142
+
143
+ for j in range (0 ,s2 ):
144
+ b = b + (j + lambda_local * x [7 ][:,j ])* x [1 ][:,j ]
145
+ b = K .expand_dims (b ,- 1 )
146
+ b = b / (s1 * (1 + lambda_d * x [3 ]))/ (s2 * (1 + lambda_d * x [4 ]))
147
+
148
+ for k in range (0 ,s3 ):
149
+ c = c + (k + lambda_local * x [8 ][:,k ])* x [2 ][:,k ]
150
+ c = K .expand_dims (c ,- 1 )
151
+ c = c / (s1 * (1 + lambda_d * x [3 ]))/ (s2 * (1 + lambda_d * x [4 ]))/ (s3 * (1 + lambda_d * x [5 ]))
152
+
153
+
154
+ out = (a + b + c )* V
155
+ return out
156
+
157
+ pred = Lambda (SSR_module ,
158
+ arguments = {'s1' :self .__stage_num [0 ],
159
+ 's2' :self .__stage_num [1 ],
160
+ 's3' :self .__stage_num [2 ],
161
+ 'lambda_local' :self .__lambda_local ,
162
+ 'lambda_d' :self .__lambda_d ,
163
+ 'V' :V },
164
+ name = name + '_prediction' )([pred_s1 ,pred_s2 ,pred_s3 ,delta_s1 ,delta_s2 ,delta_s3 , local_s1 , local_s2 , local_s3 ])
165
+ return pred
166
+
167
+ def prep_phase1 (self ):
168
+ pass
169
+
170
+ def prep_phase2 (self ):
171
+ pass
172
+
173
+ @staticmethod
174
+ def decode_prediction (prediction ):
175
+ gender_predicted = np .around (prediction [0 ]).astype ('int' ).squeeze ()
176
+ age_predicted = prediction [1 ].squeeze ()
177
+ return gender_predicted , age_predicted
178
+
179
+ @staticmethod
180
+ def prep_image (data ):
181
+ data = data .astype ('float16' )
182
+ return data
183
+
184
+ if __name__ == '__main__' :
185
+ model = AgenderSSRNet (64 , [3 ,3 ,3 ], 1.0 , 1.0 )
186
+ print (model .summary ())
0 commit comments