Skip to content

Commit 2da01aa

Browse files
committed
Added TensorFlow training with a multiple-input API, allowing early/late fusion
1 parent 2bac683 commit 2da01aa

2 files changed

Lines changed: 166 additions & 35 deletions

File tree

ML/python_training.py

Lines changed: 149 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -25,8 +25,13 @@
2525
from tensorflow import keras
2626
from tensorflow.keras import layers
2727
from tensorflow.keras.layers.experimental import preprocessing
28-
from tensorflow.keras.models import Sequential
29-
from tensorflow.keras.layers import Dense
28+
from tensorflow.keras.layers import Input, Dense
29+
from tensorflow.keras.models import Model
30+
from tensorflow.keras.utils import plot_model
31+
from tensorflow.keras.layers import concatenate
32+
from tensorflow.keras.layers import BatchNormalization
33+
from tensorflow.keras.layers import Dropout
34+
from tensorflow.keras.layers import LayerNormalization
3035

3136
import keras2onnx
3237
import tf2onnx
@@ -81,6 +86,7 @@ def get_external_params():
8186
global SETUP_PARAMS_KEY
8287
global OUTPUT_NAME_KEY
8388
global DEBUG_MODE_KEY
89+
global STRUCTURE_KEY
8490

8591
parser = argparse.ArgumentParser()
8692
parser.add_argument("config")
@@ -92,6 +98,7 @@ def get_external_params():
9298
MODEL_KEY = obj['MODEL']
9399
INPUT_NAME_KEY = obj['INPUT_NAME']
94100
LOAD_PARAMS_KEY = obj['LOAD_PARAMS']
101+
STRUCTURE_KEY = obj['STRUCTURE']
95102
CALC_PARAMS_KEY = obj['CALC_PARAMS']
96103
OBJECT_KEY = obj['OBJECT']
97104
HYPER_PARAM_KEY = obj['HYPER_PARAM']
@@ -180,32 +187,146 @@ def create_model_TensorFlowNN():
180187
import pickle
181188
global model
182189

183-
normalize = preprocessing.Normalization()
184-
185-
model = tf.keras.Sequential([normalize,
186-
layers.Dense(x_train.shape[1], use_bias=False),
187-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
188-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
189-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
190-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
191-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
192-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
193-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
194-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
195-
layers.Dense(x_train.shape[1], activation='sigmoid',use_bias=False),
196-
layers.Dense(1, activation='sigmoid')])
197-
198-
model.compile(loss='binary_crossentropy', optimizer=tf.optimizers.Adam(learning_rate=0.005), metrics=['accuracy'])
199-
200-
training_history = model.fit(x_train, y_train,
201-
epochs=1500,
202-
batch_size=1024,
203-
verbose=1,
204-
validation_data=(x_valid, y_valid))
190+
METRICS = [
191+
keras.metrics.BinaryAccuracy(name='accuracy'),
192+
keras.metrics.Precision(name='precision'),
193+
keras.metrics.Recall(name='recall'),
194+
keras.metrics.AUC(name='auc'),
195+
keras.metrics.AUC(name='prc', curve='PR')]
196+
197+
modelA=STRUCTURE_KEY['modelA']
198+
modelB=STRUCTURE_KEY['modelB']
199+
modelM=STRUCTURE_KEY['modelM']
200+
201+
modelA=np.array(modelA)
202+
modelB=np.array(modelB)
203+
modelM=np.array(modelM)
204+
205+
if (modelM.shape[0] != 0):
206+
MultipleInput = 1
207+
else:
208+
MultipleInput = 0
209+
210+
if (MultipleInput == 0):
211+
modelA_inputShape = int(modelA[0]);
212+
modelA_Normalization = int(modelA[1]);
213+
A1 = Input(shape = (modelA_inputShape,))
214+
lst = np.size(modelA);
215+
it= lst/3;
216+
if (modelA_Normalization == 1):
217+
A = LayerNormalization(axis=-1)(A1)
218+
A = Dense(int(modelA[3]), use_bias=False)(A)
219+
else:
220+
A = Dense(int(modelA[3]), use_bias=False)(A1)
221+
modelA_Shape = modelA[3::3]
222+
modelA_Normalization= modelA[4::3]
223+
modelA_Dropout = modelA[5::3]
224+
for i in range(0,int(it)-1):
225+
if ((i != 0) or (int(modelA_Shape[i]) != 0)):
226+
A = Dense(int(modelA_Shape[i]), activation='sigmoid',use_bias=False)(A)
227+
if (int(modelA_Normalization[i]) == 1):
228+
A = LayerNormalization(axis=-1)(A)
229+
if (modelA_Dropout[i] != 0.0):
230+
A = Dropout(modelA_Dropout[i])(A)
231+
A = Model(inputs=A1, outputs=A)
232+
233+
#print(A.summary())
234+
#plot_model(A,to_file="ModelStructA.png",show_shapes=True)
235+
236+
A.compile(loss='binary_crossentropy', optimizer=tf.optimizers.Adam(learning_rate=HYPER_PARAM_KEY['LEARN_RATE']), metrics=METRICS)
237+
training_history = A.fit(x_train, y_train,
238+
epochs=HYPER_PARAM_KEY['EPOCHS'],
239+
batch_size=HYPER_PARAM_KEY['BATCH_SIZE'],
240+
verbose=HYPER_PARAM_KEY['VERBOSE'],
241+
validation_split=HYPER_PARAM_KEY['VALIDATION_SPLIT'])
242+
243+
with open(OUTPUT_NAME_KEY['ORIGIN'],"wb") as f:
244+
A.save('tf_model')
245+
246+
else:
247+
modelA_inputShape = int(modelA[0]);
248+
modelA_Normalization = int(modelA[1]);
249+
A1 = Input(shape = (modelA_inputShape,))
250+
lst = np.size(modelA);
251+
it= lst/3;
252+
if (modelA_Normalization == 1):
253+
A = LayerNormalization(axis=-1)(A1)
254+
A = Dense(int(modelA[3]), use_bias=False)(A)
255+
else:
256+
A = Dense(int(modelA[3]), use_bias=False)(A1)
257+
modelA_Shape = modelA[3::3]
258+
modelA_Normalization= modelA[4::3]
259+
modelA_Dropout = modelA[5::3]
260+
261+
for i in range(0,int(it)-1):
262+
if ((i != 0) or (int(modelA_Shape[i]) != 0)):
263+
A = Dense(int(modelA_Shape[i]), activation='sigmoid',use_bias=False)(A)
264+
if (int(modelA_Normalization[i]) == 1):
265+
A = LayerNormalization(axis=-1)(A)
266+
if (modelA_Dropout[i] != 0.0):
267+
A = Dropout(modelA_Dropout[i])(A)
268+
A = Model(inputs=A1, outputs=A)
269+
270+
modelB_inputShape = int(modelB[0]);
271+
modelB_Normalization = int(modelB[1]);
272+
B1 = Input(shape = (modelB_inputShape,))
273+
lst = np.size(modelB);
274+
it= lst/3;
275+
if (modelB_Normalization == 1):
276+
B = LayerNormalization(axis=-1)(B1)
277+
B = Dense(int(modelB[3]), use_bias=False)(B)
278+
else:
279+
B = Dense(int(modelB[3]), use_bias=False)(B1)
280+
281+
modelB_Shape = modelB[3::3]
282+
modelB_Normalization= modelB[4::3]
283+
modelB_Dropout = modelB[5::3]
284+
for i in range(0,int(it)-1):
285+
if ((i != 0) or (int(modelB_Shape[i]) != 0)):
286+
B = Dense(int(modelB_Shape[i]), activation='sigmoid',use_bias=False)(B)
287+
if (int(modelB_Normalization[i]) == 1):
288+
B = LayerNormalization(axis=-1)(B)
289+
if (modelB_Dropout[i] != 0.0):
290+
B = Dropout(modelB_Dropout[i])(B)
291+
B = Model(inputs=B1, outputs=B)
292+
293+
M1 = concatenate([A.output,B.output])
294+
modelM_Normalization = int(modelM[1]);
295+
lst = np.size(modelM);
296+
it= lst/3;
297+
if (modelM_Normalization == 1):
298+
M = LayerNormalization(axis=-1)(M1)
299+
M = Dense(int(modelM[3]), use_bias=False)(M)
300+
else:
301+
M = Dense(int(modelM[3]), use_bias=False)(M1)
302+
modelM_Shape = modelM[3::3]
303+
modelM_Normalization= modelM[4::3]
304+
modelM_Dropout = modelM[5::3]
305+
for i in range(0,int(it)-1):
306+
if ((i != 0) or (int(modelM_Shape[i]) != 0)):
307+
M = Dense(int(modelM_Shape[i]), activation='sigmoid',use_bias=False)(M)
308+
if (int(modelM_Normalization[i]) == 1):
309+
M = LayerNormalization(axis=-1)(M)
310+
if (modelM_Dropout[i] != 0.0):
311+
M = Dropout(modelM_Dropout[i])(M)
312+
M = Dense(1, activation='sigmoid',use_bias=False)(M)
313+
full_mod = Model(inputs=[A.input, B.input], outputs=M)
314+
315+
#print(full_mod.summary())
316+
#plot_model(full_mod,to_file="ModelStruct.png",show_shapes=True)
317+
318+
x1_train, x2_train = np.hsplit(x_train, [int(modelA[0])]);
319+
320+
full_mod.compile(loss='binary_crossentropy', optimizer=tf.optimizers.Adam(learning_rate=HYPER_PARAM_KEY['LEARN_RATE']), metrics=METRICS)
321+
training_history = full_mod.fit(x=[x1_train,x2_train], y=y_train,
322+
epochs=HYPER_PARAM_KEY['EPOCHS'],
323+
batch_size=HYPER_PARAM_KEY['BATCH_SIZE'],
324+
verbose=HYPER_PARAM_KEY['VERBOSE'],
325+
validation_split=HYPER_PARAM_KEY['VALIDATION_SPLIT'])
326+
327+
with open(OUTPUT_NAME_KEY['ORIGIN'],"wb") as f:
328+
full_mod.save('tf_model')
205329

206-
with open(OUTPUT_NAME_KEY['ORIGIN'],"wb") as f:
207-
model.save('tf_model')
208-
#pickle.dump(model,f)
209330

210331
def convert_onnx_TensorFlowNN():
211332

@@ -461,6 +582,7 @@ def main():
461582
global SETUP_PARAMS_KEY
462583
global OUTPUT_NAME_KEY
463584
global DEBUG_MODE_KEY
585+
global STRUCTURE_KEY
464586

465587
global df
466588
global df_train

ML/training_configure_tfNN.yaml

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,20 @@ SETUP_PARAMS:
2020
HYPER_PARAM:
2121
TYPE: gbdt
2222
OBJECTIVE: binary
23-
LEARN_RATE: 0.01
24-
MAX_DEPTH: 5
25-
N_ESTIMATOR: 1000
26-
MAX_BIN: 20000
23+
LEARN_RATE: 0.0005
2724
METRIC: auc
25+
EPOCHS: 500
26+
BATCH_SIZE: 32
27+
VERBOSE: 1
28+
VALIDATION_SPLIT: 0.2
29+
30+
STRUCTURE:
31+
modelA:
32+
[]
33+
modelB:
34+
[]
35+
modelM:
36+
[]
2837

2938
LOAD_PARAMS:
3039
- MFT_X
@@ -61,14 +70,14 @@ CALC_PARAMS:
6170

6271
TRAIN_PARAMS:
6372
- MFT_X
64-
- MFT_Y
65-
- MFT_Phi
66-
- MFT_Tanl
67-
- MFT_InvQPt
6873
- MCH_X
74+
- MFT_Y
6975
- MCH_Y
76+
- MFT_Phi
7077
- MCH_Phi
78+
- MFT_Tanl
7179
- MCH_Tanl
80+
- MFT_InvQPt
7281
- MCH_InvQPt
7382
- MFT_TrackChi2
7483
- Delta_X

0 commit comments

Comments
 (0)