@@ -1,58 +1,100 @@
 from math import ceil
+from typing import Optional
+import logging
+from argparse import ArgumentParser


 class Config:
-    @staticmethod
-    def get_default_config(args):
-        config = Config()
-
-        config.NUM_TRAIN_EPOCHS = 20
-        config.SAVE_EVERY_EPOCHS = 1
-        config.TRAIN_BATCH_SIZE = 1024
-        config.TEST_BATCH_SIZE = config.TRAIN_BATCH_SIZE
-        config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
-        config.NUM_BATCHES_TO_LOG = 100
-        config.READER_NUM_PARALLEL_BATCHES = 6  # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
-        config.SHUFFLE_BUFFER_SIZE = 10000
-        config.CSV_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB
+    @classmethod
+    def arguments_parser(cls) -> ArgumentParser:
+        parser = ArgumentParser()
+        parser.add_argument("-d", "--data", dest="data_path",
+                            help="path to preprocessed dataset", required=False)
+        parser.add_argument("-te", "--test", dest="test_path",
+                            help="path to test file", metavar="FILE", required=False)
+        parser.add_argument("-s", "--save", dest="save_path",
+                            help="path to save the model file", metavar="FILE", required=False)
+        parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
+                            help="path to save the tokens embeddings file", metavar="FILE", required=False)
+        parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
+                            help="path to save the targets embeddings file", metavar="FILE", required=False)
+        parser.add_argument("-l", "--load", dest="load_path",
+                            help="path to load the model from", metavar="FILE", required=False)
+        parser.add_argument('--save_w2v', dest='save_w2v', required=False,
+                            help="save word (token) vectors in word2vec format")
+        parser.add_argument('--save_t2v', dest='save_t2v', required=False,
+                            help="save target vectors in word2vec format")
+        parser.add_argument('--export_code_vectors', action='store_true', required=False,
+                            help="export code vectors for the given examples")
+        parser.add_argument('--release', action='store_true',
+                            help='if specified and loading a trained model, release the loaded model for a lower model '
+                                 'size.')
+        parser.add_argument('--predict', action='store_true',
+                            help='execute the interactive prediction shell')
+        parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
+                            default='tensorflow', help="deep learning framework to use.")
+        parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
+                            help="verbose mode (should be in {0,1,2}).")
+        parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
+                            help="path to store logs into. if not given logs are not saved to file.")
+        parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
+                            help='use tensorboard during training')
+        return parser
+
+    def set_defaults(self):
+        self.NUM_TRAIN_EPOCHS = 20
+        self.SAVE_EVERY_EPOCHS = 1
+        self.TRAIN_BATCH_SIZE = 1024
+        self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE
+        self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
+        self.NUM_TRAIN_BATCHES_TO_LOG_PROGRESS = 100
+        self.NUM_TRAIN_BATCHES_TO_EVALUATE = 1000
+        self.READER_NUM_PARALLEL_BATCHES = 6  # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
+        self.SHUFFLE_BUFFER_SIZE = 10000
+        self.CSV_BUFFER_SIZE = 100 * 1024 * 1024  # 100 MB
+        self.MAX_TO_KEEP = 10

         # model hyper-params
-        config.MAX_CONTEXTS = 200
-        config.MAX_TOKEN_VOCAB_SIZE = 1301136
-        config.MAX_TARGET_VOCAB_SIZE = 261245
-        config.MAX_PATH_VOCAB_SIZE = 911417
-        config.DEFAULT_EMBEDDINGS_SIZE = 128
-        config.TOKEN_EMBEDDINGS_SIZE = config.DEFAULT_EMBEDDINGS_SIZE
-        config.PATH_EMBEDDINGS_SIZE = config.DEFAULT_EMBEDDINGS_SIZE
-        config.CODE_VECTOR_SIZE = config.context_vector_size
-        config.TARGET_EMBEDDINGS_SIZE = config.CODE_VECTOR_SIZE
-        config.MAX_TO_KEEP = 10
-        config.DROPOUT_KEEP_RATE = 0.75
-
+        self.MAX_CONTEXTS = 200
+        self.MAX_TOKEN_VOCAB_SIZE = 1301136
+        self.MAX_TARGET_VOCAB_SIZE = 261245
+        self.MAX_PATH_VOCAB_SIZE = 911417
+        self.DEFAULT_EMBEDDINGS_SIZE = 128
+        self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
+        self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
+        self.CODE_VECTOR_SIZE = self.context_vector_size
+        self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
+        self.DROPOUT_KEEP_RATE = 0.75
+
+    def load_from_args(self):
+        args = self.arguments_parser().parse_args()
         # Automatically filled, do not edit:
-        config.TRAIN_DATA_PATH_PREFIX = args.data_path
-        config.TEST_DATA_PATH = args.test_path
-        config.MODEL_SAVE_PATH = args.save_path
-        config.MODEL_LOAD_PATH = args.load_path
-        config.RELEASE = args.release
-        config.EXPORT_CODE_VECTORS = args.export_code_vectors
-        config.VERBOSE_MODE = args.verbose_mode
-        config.LOGS_PATH = args.logs_path
-        config.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
-        config.USE_TENSORBOARD = args.use_tensorboard
-
-        return config
-
-    def __init__(self):
+        self.PREDICT = args.predict
+        self.MODEL_SAVE_PATH = args.save_path
+        self.MODEL_LOAD_PATH = args.load_path
+        self.TRAIN_DATA_PATH_PREFIX = args.data_path
+        self.TEST_DATA_PATH = args.test_path
+        self.RELEASE = args.release
+        self.EXPORT_CODE_VECTORS = args.export_code_vectors
+        self.SAVE_W2V = args.save_w2v
+        self.SAVE_T2V = args.save_t2v
+        self.VERBOSE_MODE = args.verbose_mode
+        self.LOGS_PATH = args.logs_path
+        self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
+        self.USE_TENSORBOARD = args.use_tensorboard
+
+    def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False):
         self.NUM_TRAIN_EPOCHS: int = 0
         self.SAVE_EVERY_EPOCHS: int = 0
         self.TRAIN_BATCH_SIZE: int = 0
         self.TEST_BATCH_SIZE: int = 0
         self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0
-        self.NUM_BATCHES_TO_LOG: int = 0
+        self.NUM_TRAIN_BATCHES_TO_LOG_PROGRESS: int = 0  # TODO: update README;
+        self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0  # TODO: update README; update tensorflow_model to use it
         self.READER_NUM_PARALLEL_BATCHES: int = 0
         self.SHUFFLE_BUFFER_SIZE: int = 0
         self.CSV_BUFFER_SIZE: int = 0
+        self.MAX_TO_KEEP: int = 0

         # model hyper-params
         self.MAX_CONTEXTS: int = 0
@@ -64,16 +106,18 @@ def __init__(self):
         self.PATH_EMBEDDINGS_SIZE: int = 0
         self.CODE_VECTOR_SIZE: int = 0
         self.TARGET_EMBEDDINGS_SIZE: int = 0
-        self.MAX_TO_KEEP: int = 0
         self.DROPOUT_KEEP_RATE: float = 0

         # Automatically filled by `args`.
+        self.PREDICT: bool = False  # TODO: update README;
         self.MODEL_SAVE_PATH: str = ''
         self.MODEL_LOAD_PATH: str = ''
         self.TRAIN_DATA_PATH_PREFIX: str = ''
         self.TEST_DATA_PATH: str = ''
         self.RELEASE: bool = False
         self.EXPORT_CODE_VECTORS: bool = False
+        self.SAVE_W2V: Optional[str] = None  # TODO: update README;
+        self.SAVE_T2V: Optional[str] = None  # TODO: update README;
         self.VERBOSE_MODE: int = 0
         self.LOGS_PATH: str = ''
         self.DL_FRAMEWORK: str = ''  # in {'keras', 'tensorflow'}
@@ -83,6 +127,15 @@ def __init__(self):
         self.NUM_TRAIN_EXAMPLES: int = 0
         self.NUM_TEST_EXAMPLES: int = 0

+        self.__logger: Optional[logging.Logger] = None
+
+        if set_defaults:
+            self.set_defaults()
+        if load_from_args:
+            self.load_from_args()
+        if verify:
+            self.verify()
+
     @property
     def context_vector_size(self) -> int:
         # The context vector is actually a concatenation of the embedded
@@ -106,7 +159,7 @@ def train_steps_per_epoch(self) -> int:
         return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0

     @property
-    def test_steps_per_epoch(self) -> int:
+    def test_steps(self) -> int:
         return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0

     def data_path(self, is_evaluating: bool = False):
@@ -155,15 +208,46 @@ def model_weights_save_path(self):
     def verify(self):
         if not self.is_training and not self.is_loading:
             raise ValueError("Must train or load a model.")
+        if self.DL_FRAMEWORK not in {'tensorflow', 'keras'}:
+            raise ValueError("config.DL_FRAMEWORK must be in {'tensorflow', 'keras'}.")

     def __iter__(self):
         for attr_name in dir(self):
             if attr_name.startswith("__"):
                 continue
             try:
-                attr_value = getattr(self, attr_name)
+                attr_value = getattr(self, attr_name, None)
             except:
                 attr_value = None
             if callable(attr_value):
                 continue
             yield attr_name, attr_value
+
+    def get_logger(self) -> logging.Logger:
+        if self.__logger is None:
+            self.__logger = logging.getLogger('code2vec')
+            self.__logger.setLevel(logging.INFO)
+            old_handlers = list(self.__logger.handlers)
+            for handler in old_handlers:
+                self.__logger.removeHandler(handler)
+            ch = logging.StreamHandler()
+            ch.setLevel(logging.INFO)
+            formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
+            ch.setFormatter(formatter)
+            self.__logger.addHandler(ch)
+            if self.LOGS_PATH:
+                # logging.basicConfig(
+                #     filename=self.LOGS_PATH,
+                #     level=logging.INFO,
+                #     format='%(asctime)s %(levelname)-8s %(message)s',
+                #     datefmt='%Y-%m-%d %H:%M:%S'
+                # )
+                fh = logging.FileHandler(self.LOGS_PATH)
+                fh.setLevel(logging.INFO)
+                fh.setFormatter(formatter)
+                self.__logger.addHandler(fh)
+
+        return self.__logger
+
+    def log(self, msg):
+        self.get_logger().info(msg)
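For reference, a minimal usage sketch of the refactored Config after this change. The calling code below is illustrative only and not part of the commit; it assumes the class lives in a `config` module and that either `--data` or `--load` is passed on the command line so that `verify()` does not raise.

from config import Config

# Fill defaults, overlay the parsed CLI arguments, then validate the result.
config = Config(set_defaults=True, load_from_args=True, verify=True)

# log() lazily builds the 'code2vec' logger, adding a FileHandler only if --logs-path was given.
config.log('Using framework: {}'.format(config.DL_FRAMEWORK))
config.log('Train batch size: {}'.format(config.TRAIN_BATCH_SIZE))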