Skip to content

Commit 46ae7ad

Browse files
committed
migrate to TF-2.0.0-alpha; now keras model works with BS=1024 + TP=1200; use a Keras callback (CB) to perform logging and evaluation during training; move arg parsing to config; fix AttentionLayer so the mask is passed as an input; use logger in classes (instead of prints)
1 parent 365bc36 commit 46ae7ad

11 files changed

+459
-321
lines changed

code2vec.py

+11-50
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,20 @@
11
from vocabularies import VocabType
22
from config import Config
3-
from argparse import ArgumentParser
43
from interactive_predict import InteractivePredictor
54
from model_base import Code2VecModelBase
6-
import sys
75

86

97
def load_model_dynamically(config: Config) -> Code2VecModelBase:
8+
assert config.DL_FRAMEWORK in {'tensorflow', 'keras'}
109
if config.DL_FRAMEWORK == 'tensorflow':
1110
from tensorflow_model import Code2VecModel
1211
elif config.DL_FRAMEWORK == 'keras':
1312
from keras_model import Code2VecModel
14-
else:
15-
raise ValueError("config.DL_FRAMEWORK must be in {'tensorflow', 'keras'}.")
1613
return Code2VecModel(config)
1714

1815

1916
if __name__ == '__main__':
20-
# TODO: move args parsing to config!
21-
parser = ArgumentParser()
22-
parser.add_argument("-d", "--data", dest="data_path",
23-
help="path to preprocessed dataset", required=False)
24-
parser.add_argument("-te", "--test", dest="test_path",
25-
help="path to test file", metavar="FILE", required=False)
26-
27-
is_training = '--train' in sys.argv or '-tr' in sys.argv
28-
parser.add_argument("-s", "--save", dest="save_path",
29-
help="path to save the model file", metavar="FILE", required=False)
30-
parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
31-
help="path to save the tokens embeddings file", metavar="FILE", required=False)
32-
parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
33-
help="path to save the targets embeddings file", metavar="FILE", required=False)
34-
parser.add_argument("-l", "--load", dest="load_path",
35-
help="path to load the model from", metavar="FILE", required=False)
36-
parser.add_argument('--save_w2v', dest='save_w2v', required=False,
37-
help="save word (token) vectors in word2vec format")
38-
parser.add_argument('--save_t2v', dest='save_t2v', required=False,
39-
help="save target vectors in word2vec format")
40-
parser.add_argument('--export_code_vectors', action='store_true', required=False,
41-
help="export code vectors for the given examples")
42-
parser.add_argument('--release', action='store_true',
43-
help='if specified and loading a trained model, release the loaded model for a lower model '
44-
'size.')
45-
parser.add_argument('--predict', action='store_true',
46-
help='execute the interactive prediction shell')
47-
parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
48-
default='tensorflow', help="deep learning framework to use.")
49-
parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
50-
help="verbose mode (should be in {0,1,2}).")
51-
parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
52-
help="path to store logs into. if not given logs are not saved to file.")
53-
parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
54-
help='use tensorboard during training')
55-
args = parser.parse_args()
56-
57-
config = Config.get_default_config(args)
17+
config = Config(set_defaults=True, load_from_args=True, verify=True)
5818

5919
model = load_model_dynamically(config)
6020
print('Created model')
@@ -65,17 +25,18 @@ def load_model_dynamically(config: Config) -> Code2VecModelBase:
6525

6626
if config.is_training:
6727
model.train()
68-
if args.save_w2v is not None:
69-
model.save_word2vec_format(args.save_w2v, VocabType.Token)
70-
print('Origin word vectors saved in word2vec text format in: %s' % args.save_w2v)
71-
if args.save_t2v is not None:
72-
model.save_word2vec_format(args.save_t2v, VocabType.Target)
73-
print('Target word vectors saved in word2vec text format in: %s' % args.save_t2v)
28+
if config.SAVE_W2V is not None:
29+
model.save_word2vec_format(config.SAVE_W2V, VocabType.Token)
30+
config.log('Origin word vectors saved in word2vec text format in: %s' % config.SAVE_W2V)
31+
if config.SAVE_T2V is not None:
32+
model.save_word2vec_format(config.SAVE_T2V, VocabType.Target)
33+
config.log('Target word vectors saved in word2vec text format in: %s' % config.SAVE_T2V)
7434
if config.is_testing and not config.is_training:
7535
eval_results = model.evaluate()
7636
if eval_results is not None:
77-
print(str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
78-
if args.predict:
37+
config.log(
38+
str(eval_results).replace('topk', 'top{}'.format(config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION)))
39+
if config.PREDICT:
7940
predictor = InteractivePredictor(config, model)
8041
predictor.predict()
8142
model.close_session()

config.py

+127-43
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,100 @@
11
from math import ceil
2+
from typing import Optional
3+
import logging
4+
from argparse import ArgumentParser
25

36

47
class Config:
5-
@staticmethod
6-
def get_default_config(args):
7-
config = Config()
8-
9-
config.NUM_TRAIN_EPOCHS = 20
10-
config.SAVE_EVERY_EPOCHS = 1
11-
config.TRAIN_BATCH_SIZE = 1024
12-
config.TEST_BATCH_SIZE = config.TRAIN_BATCH_SIZE
13-
config.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
14-
config.NUM_BATCHES_TO_LOG = 100
15-
config.READER_NUM_PARALLEL_BATCHES = 6 # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
16-
config.SHUFFLE_BUFFER_SIZE = 10000
17-
config.CSV_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB
8+
@classmethod
9+
def arguments_parser(cls) -> ArgumentParser:
10+
parser = ArgumentParser()
11+
parser.add_argument("-d", "--data", dest="data_path",
12+
help="path to preprocessed dataset", required=False)
13+
parser.add_argument("-te", "--test", dest="test_path",
14+
help="path to test file", metavar="FILE", required=False)
15+
parser.add_argument("-s", "--save", dest="save_path",
16+
help="path to save the model file", metavar="FILE", required=False)
17+
parser.add_argument("-w2v", "--save_word2v", dest="save_w2v",
18+
help="path to save the tokens embeddings file", metavar="FILE", required=False)
19+
parser.add_argument("-t2v", "--save_target2v", dest="save_t2v",
20+
help="path to save the targets embeddings file", metavar="FILE", required=False)
21+
parser.add_argument("-l", "--load", dest="load_path",
22+
help="path to load the model from", metavar="FILE", required=False)
23+
parser.add_argument('--save_w2v', dest='save_w2v', required=False,
24+
help="save word (token) vectors in word2vec format")
25+
parser.add_argument('--save_t2v', dest='save_t2v', required=False,
26+
help="save target vectors in word2vec format")
27+
parser.add_argument('--export_code_vectors', action='store_true', required=False,
28+
help="export code vectors for the given examples")
29+
parser.add_argument('--release', action='store_true',
30+
help='if specified and loading a trained model, release the loaded model for a lower model '
31+
'size.')
32+
parser.add_argument('--predict', action='store_true',
33+
help='execute the interactive prediction shell')
34+
parser.add_argument("-fw", "--framework", dest="dl_framework", choices=['keras', 'tensorflow'],
35+
default='tensorflow', help="deep learning framework to use.")
36+
parser.add_argument("-v", "--verbose", dest="verbose_mode", type=int, required=False, default=1,
37+
help="verbose mode (should be in {0,1,2}).")
38+
parser.add_argument("-lp", "--logs-path", dest="logs_path", metavar="FILE", required=False,
39+
help="path to store logs into. if not given logs are not saved to file.")
40+
parser.add_argument('-tb', '--tensorboard', dest='use_tensorboard', action='store_true',
41+
help='use tensorboard during training')
42+
return parser
43+
44+
def set_defaults(self):
45+
self.NUM_TRAIN_EPOCHS = 20
46+
self.SAVE_EVERY_EPOCHS = 1
47+
self.TRAIN_BATCH_SIZE = 1024
48+
self.TEST_BATCH_SIZE = self.TRAIN_BATCH_SIZE
49+
self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION = 10
50+
self.NUM_TRAIN_BATCHES_TO_LOG_PROGRESS = 100
51+
self.NUM_TRAIN_BATCHES_TO_EVALUATE = 1000
52+
self.READER_NUM_PARALLEL_BATCHES = 6 # cpu cores [for tf.contrib.data.map_and_batch() in the reader]
53+
self.SHUFFLE_BUFFER_SIZE = 10000
54+
self.CSV_BUFFER_SIZE = 100 * 1024 * 1024 # 100 MB
55+
self.MAX_TO_KEEP = 10
1856

1957
# model hyper-params
20-
config.MAX_CONTEXTS = 200
21-
config.MAX_TOKEN_VOCAB_SIZE = 1301136
22-
config.MAX_TARGET_VOCAB_SIZE = 261245
23-
config.MAX_PATH_VOCAB_SIZE = 911417
24-
config.DEFAULT_EMBEDDINGS_SIZE = 128
25-
config.TOKEN_EMBEDDINGS_SIZE = config.DEFAULT_EMBEDDINGS_SIZE
26-
config.PATH_EMBEDDINGS_SIZE = config.DEFAULT_EMBEDDINGS_SIZE
27-
config.CODE_VECTOR_SIZE = config.context_vector_size
28-
config.TARGET_EMBEDDINGS_SIZE = config.CODE_VECTOR_SIZE
29-
config.MAX_TO_KEEP = 10
30-
config.DROPOUT_KEEP_RATE = 0.75
31-
58+
self.MAX_CONTEXTS = 200
59+
self.MAX_TOKEN_VOCAB_SIZE = 1301136
60+
self.MAX_TARGET_VOCAB_SIZE = 261245
61+
self.MAX_PATH_VOCAB_SIZE = 911417
62+
self.DEFAULT_EMBEDDINGS_SIZE = 128
63+
self.TOKEN_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
64+
self.PATH_EMBEDDINGS_SIZE = self.DEFAULT_EMBEDDINGS_SIZE
65+
self.CODE_VECTOR_SIZE = self.context_vector_size
66+
self.TARGET_EMBEDDINGS_SIZE = self.CODE_VECTOR_SIZE
67+
self.DROPOUT_KEEP_RATE = 0.75
68+
69+
def load_from_args(self):
70+
args = self.arguments_parser().parse_args()
3271
# Automatically filled, do not edit:
33-
config.TRAIN_DATA_PATH_PREFIX = args.data_path
34-
config.TEST_DATA_PATH = args.test_path
35-
config.MODEL_SAVE_PATH = args.save_path
36-
config.MODEL_LOAD_PATH = args.load_path
37-
config.RELEASE = args.release
38-
config.EXPORT_CODE_VECTORS = args.export_code_vectors
39-
config.VERBOSE_MODE = args.verbose_mode
40-
config.LOGS_PATH = args.logs_path
41-
config.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
42-
config.USE_TENSORBOARD = args.use_tensorboard
43-
44-
return config
45-
46-
def __init__(self):
72+
self.PREDICT = args.predict
73+
self.MODEL_SAVE_PATH = args.save_path
74+
self.MODEL_LOAD_PATH = args.load_path
75+
self.TRAIN_DATA_PATH_PREFIX = args.data_path
76+
self.TEST_DATA_PATH = args.test_path
77+
self.RELEASE = args.release
78+
self.EXPORT_CODE_VECTORS = args.export_code_vectors
79+
self.SAVE_W2V = args.save_w2v
80+
self.SAVE_T2V = args.save_t2v
81+
self.VERBOSE_MODE = args.verbose_mode
82+
self.LOGS_PATH = args.logs_path
83+
self.DL_FRAMEWORK = 'tensorflow' if not args.dl_framework else args.dl_framework
84+
self.USE_TENSORBOARD = args.use_tensorboard
85+
86+
def __init__(self, set_defaults: bool = False, load_from_args: bool = False, verify: bool = False):
4787
self.NUM_TRAIN_EPOCHS: int = 0
4888
self.SAVE_EVERY_EPOCHS: int = 0
4989
self.TRAIN_BATCH_SIZE: int = 0
5090
self.TEST_BATCH_SIZE: int = 0
5191
self.TOP_K_WORDS_CONSIDERED_DURING_PREDICTION: int = 0
52-
self.NUM_BATCHES_TO_LOG: int = 0
92+
self.NUM_TRAIN_BATCHES_TO_LOG_PROGRESS: int = 0 # TODO: update README;
93+
self.NUM_TRAIN_BATCHES_TO_EVALUATE: int = 0 # TODO: update README; update tensorflow_model to use it
5394
self.READER_NUM_PARALLEL_BATCHES: int = 0
5495
self.SHUFFLE_BUFFER_SIZE: int = 0
5596
self.CSV_BUFFER_SIZE: int = 0
97+
self.MAX_TO_KEEP: int = 0
5698

5799
# model hyper-params
58100
self.MAX_CONTEXTS: int = 0
@@ -64,16 +106,18 @@ def __init__(self):
64106
self.PATH_EMBEDDINGS_SIZE: int = 0
65107
self.CODE_VECTOR_SIZE: int = 0
66108
self.TARGET_EMBEDDINGS_SIZE: int = 0
67-
self.MAX_TO_KEEP: int = 0
68109
self.DROPOUT_KEEP_RATE: float = 0
69110

70111
# Automatically filled by `args`.
112+
self.PREDICT: bool = False # TODO: update README;
71113
self.MODEL_SAVE_PATH: str = ''
72114
self.MODEL_LOAD_PATH: str = ''
73115
self.TRAIN_DATA_PATH_PREFIX: str = ''
74116
self.TEST_DATA_PATH: str = ''
75117
self.RELEASE: bool = False
76118
self.EXPORT_CODE_VECTORS: bool = False
119+
self.SAVE_W2V: Optional[str] = None # TODO: update README;
120+
self.SAVE_T2V: Optional[str] = None # TODO: update README;
77121
self.VERBOSE_MODE: int = 0
78122
self.LOGS_PATH: str = ''
79123
self.DL_FRAMEWORK: str = '' # in {'keras', 'tensorflow'}
@@ -83,6 +127,15 @@ def __init__(self):
83127
self.NUM_TRAIN_EXAMPLES: int = 0
84128
self.NUM_TEST_EXAMPLES: int = 0
85129

130+
self.__logger: Optional[logging.Logger] = None
131+
132+
if set_defaults:
133+
self.set_defaults()
134+
if load_from_args:
135+
self.load_from_args()
136+
if verify:
137+
self.verify()
138+
86139
@property
87140
def context_vector_size(self) -> int:
88141
# The context vector is actually a concatenation of the embedded
@@ -106,7 +159,7 @@ def train_steps_per_epoch(self) -> int:
106159
return ceil(self.NUM_TRAIN_EXAMPLES / self.TRAIN_BATCH_SIZE) if self.TRAIN_BATCH_SIZE else 0
107160

108161
@property
109-
def test_steps_per_epoch(self) -> int:
162+
def test_steps(self) -> int:
110163
return ceil(self.NUM_TEST_EXAMPLES / self.TEST_BATCH_SIZE) if self.TEST_BATCH_SIZE else 0
111164

112165
def data_path(self, is_evaluating: bool = False):
@@ -155,15 +208,46 @@ def model_weights_save_path(self):
155208
def verify(self):
156209
if not self.is_training and not self.is_loading:
157210
raise ValueError("Must train or load a model.")
211+
if self.DL_FRAMEWORK not in {'tensorflow', 'keras'}:
212+
raise ValueError("config.DL_FRAMEWORK must be in {'tensorflow', 'keras'}.")
158213

159214
def __iter__(self):
160215
for attr_name in dir(self):
161216
if attr_name.startswith("__"):
162217
continue
163218
try:
164-
attr_value = getattr(self, attr_name)
219+
attr_value = getattr(self, attr_name, None)
165220
except:
166221
attr_value = None
167222
if callable(attr_value):
168223
continue
169224
yield attr_name, attr_value
225+
226+
def get_logger(self) -> logging.Logger:
227+
if self.__logger is None:
228+
self.__logger = logging.getLogger('code2vec')
229+
self.__logger.setLevel(logging.INFO)
230+
old_handlers = list(self.__logger.handlers)
231+
for handler in old_handlers:
232+
self.__logger.removeHandler(handler)
233+
ch = logging.StreamHandler()
234+
ch.setLevel(logging.INFO)
235+
formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
236+
ch.setFormatter(formatter)
237+
self.__logger.addHandler(ch)
238+
if self.LOGS_PATH:
239+
# logging.basicConfig(
240+
# filename=self.LOGS_PATH,
241+
# level=logging.INFO,
242+
# format='%(asctime)s %(levelname)-8s %(message)s',
243+
# datefmt='%Y-%m-%d %H:%M:%S'
244+
# )
245+
fh = logging.FileHandler(self.LOGS_PATH)
246+
fh.setLevel(logging.INFO)
247+
fh.setFormatter(formatter)
248+
self.__logger.addHandler(fh)
249+
250+
return self.__logger
251+
252+
def log(self, msg):
253+
self.get_logger().info(msg)

keras_attention_layer.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,15 @@ class AttentionLayer(Layer):
99
def __init__(self, **kwargs):
1010
super(AttentionLayer, self).__init__(**kwargs)
1111

12-
def build(self, input_shape):
12+
def build(self, inputs_shape):
13+
inputs_shape = inputs_shape if isinstance(inputs_shape, list) else [inputs_shape]
14+
15+
if len(inputs_shape) < 1 or len(inputs_shape) > 2:
16+
raise ValueError("AttentionLayer expect one or two inputs.")
17+
18+
# The first (and required) input is the actual input to the layer
19+
input_shape = inputs_shape[0]
20+
1321
# Expected input shape consists of a triplet: (batch, input_length, input_dim)
1422
if len(input_shape) != 3:
1523
raise ValueError("Input shape for AttentionLayer should be of 3 dimension.")
@@ -26,16 +34,23 @@ def build(self, input_shape):
2634
dtype=tf.float32)
2735
super(AttentionLayer, self).build(input_shape)
2836

29-
def call(self, inputs, mask: Optional[tf.Tensor] = None, **kwargs):
37+
def call(self, inputs, **kwargs):
38+
inputs = inputs if isinstance(inputs, list) else [inputs]
39+
40+
if len(inputs) < 1 or len(inputs) > 2:
41+
raise ValueError("AttentionLayer expect one or two inputs.")
42+
43+
actual_input = inputs[0]
44+
mask = inputs[1] if len(inputs) > 1 else None
3045
if mask is not None and not (((len(mask.shape) == 3 and mask.shape[2] == 1) or len(mask.shape) == 2)
3146
and mask.shape[1] == self.input_length):
3247
raise ValueError("`mask` should be of shape (batch, input_length) or (batch, input_length, 1) "
3348
"when calling an AttentionLayer.")
3449

35-
assert inputs.shape[-1] == self.attention_param.shape[0]
50+
assert actual_input.shape[-1] == self.attention_param.shape[0]
3651

3752
# (batch, input_length, input_dim) * (input_dim, 1) ==> (batch, input_length, 1)
38-
attention_weights = K.dot(inputs, self.attention_param)
53+
attention_weights = K.dot(actual_input, self.attention_param)
3954

4055
if mask is not None:
4156
if len(mask.shape) == 2:
@@ -44,7 +59,7 @@ def call(self, inputs, mask: Optional[tf.Tensor] = None, **kwargs):
4459
attention_weights += mask
4560

4661
attention_weights = K.softmax(attention_weights, axis=1) # (batch, input_length, 1)
47-
result = K.sum(inputs * attention_weights, axis=1) # (batch, input_length) [multiplication uses broadcast]
62+
result = K.sum(actual_input * attention_weights, axis=1) # (batch, input_length) [multiplication uses broadcast]
4863
return result, attention_weights
4964

5065
def compute_output_shape(self, input_shape):

0 commit comments

Comments (0)