main.py
import os
import pickle
import argparse
import shutil
import math
import sys

import numpy as np  # used for set_printoptions in the __main__ block

import model
import ggtnn_train
import ggtnn_graph_parse
from ggtnn_graph_parse import MetadataList, PreppedStory
from util import *
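
# Train a graph memory network (GGT-NN) model.
# Loads preprocessed story buckets and metadata from a parsed task directory,
# builds (or loads) a model, and then trains, evaluates, or visualizes it
# depending on the command-line flags defined below.
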
def helper_trim(bucketed, desired_total):
    """Trim bucketed fairly so that it has desired_total things total"""
    cur_total = sum(len(b) for b in bucketed)
    keep_frac = desired_total/cur_total
    if keep_frac > 1.0:
        print("WARNING: Asked to trim to {} items, but was already only {} items. Keeping original length.".format(desired_total, cur_total))
        return bucketed
    keep_amts = [math.floor(len(b) * keep_frac) for b in bucketed]
    tmp_total = sum(keep_amts)
    addtl_to_add = desired_total - tmp_total
    assert addtl_to_add >= 0
    keep_amts = [x + (1 if i < addtl_to_add else 0) for i,x in enumerate(keep_amts)]
    assert sum(keep_amts) == desired_total
    trimmed_bucketed = [b[:amt] for b,amt in zip(bucketed, keep_amts)]
    return trimmed_bucketed
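
# Worked example (hypothetical buckets): helper_trim([[1, 2, 3], [4, 5, 6]], 4)
# gives keep_frac = 4/6, keep_amts = [2, 2], and returns [[1, 2], [4, 5]].
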
def main(task_dir, output_format_str, state_width, process_repr_size, dynamic_nodes, mutable_nodes, wipe_node_state, direct_reference, propagate_intermediate, sequence_aggregate_repr, old_aggregate, train_with_graph, train_with_query, outputdir, num_updates, batch_size, learning_rate, dropout_keep, resume, resume_auto, visualize, visualize_snap, visualization_test, validation, validation_interval, evaluate_accuracy, check_mode, stop_at_accuracy, stop_at_loss, stop_at_overfitting, restrict_dataset, train_save_params, batch_adjust, set_exit_status, just_compile, autopickle, pickle_model, unpickle_model, interrupt_file):
    output_format = model.ModelOutputFormat[output_format_str]
    with open(os.path.join(task_dir,'metadata.p'),'rb') as f:
        metadata = pickle.load(f)
    with open(os.path.join(task_dir,'file_list.p'),'rb') as f:
        bucketed = pickle.load(f)
    bucketed = [[os.path.join(task_dir,x) for x in b] for b in bucketed]
    if restrict_dataset is not None:
        bucketed = helper_trim(bucketed, restrict_dataset)
    sentence_length, new_nodes_per_iter, bucket_sizes, wordlist, anslist, graph_node_list, graph_edge_list = metadata
    eff_anslist = ggtnn_train.get_effective_answer_words(anslist, output_format)
    if validation is None:
        validation_buckets = None
        validation_bucket_sizes = None
    else:
        with open(os.path.join(validation,'metadata.p'),'rb') as f:
            validation_metadata = pickle.load(f)
        with open(os.path.join(validation,'file_list.p'),'rb') as f:
            validation_buckets = pickle.load(f)
        validation_buckets = [[os.path.join(validation,x) for x in b] for b in validation_buckets]
        validation_bucket_sizes = validation_metadata[2]
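    # With --direct-reference, map each input word index to the index of the
    # graph node with the same name, so the model can refer to nodes directly
    # by their words; otherwise the mapping is left empty.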
    if direct_reference:
        word_node_mapping = {wi:ni for wi,word in enumerate(wordlist)
                                   for ni,node in enumerate(graph_node_list)
                                   if word == node}
    else:
        word_node_mapping = {}
    model_kwargs = dict(num_input_words=len(wordlist),
                        num_output_words=len(eff_anslist),
                        num_node_ids=len(graph_node_list),
                        node_state_size=state_width,
                        num_edge_types=len(graph_edge_list),
                        input_repr_size=100,
                        output_repr_size=100,
                        propose_repr_size=process_repr_size,
                        propagate_repr_size=process_repr_size,
                        new_nodes_per_iter=new_nodes_per_iter,
                        output_format=output_format,
                        final_propagate=5,
                        word_node_mapping=word_node_mapping,
                        dynamic_nodes=dynamic_nodes,
                        nodes_mutable=mutable_nodes,
                        wipe_node_state=wipe_node_state,
                        intermediate_propagate=(5 if propagate_intermediate else 0),
                        sequence_representation=sequence_aggregate_repr,
                        dropout_keep=dropout_keep,
                        use_old_aggregate=old_aggregate,
                        best_node_match_only=True,
                        train_with_graph=train_with_graph,
                        train_with_query=train_with_query,
                        setup=True,
                        check_mode=check_mode)
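    # Build or load the compiled model. With --autopickle, compiled models are
    # cached on disk keyed by a hash of their kwargs; with --unpickle-model, a
    # previously pickled model is loaded instead of being compiled from scratch.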
    model_kwargs = get_compatible_kwargs(model.Model, model_kwargs)
    if autopickle is not None:
        if not os.path.exists(autopickle):
            os.makedirs(autopickle)
        model_hash = object_hash(model_kwargs)
        model_filename = os.path.join(autopickle, "model_{}.p".format(model_hash))
        print("Looking for cached model at {}".format(model_filename))
        if os.path.isfile(model_filename):
            print("Loading model from cache")
            m, stored_kwargs = pickle.load(open(model_filename, 'rb'))
            assert model_kwargs == stored_kwargs, "Hash collision between models!\nCurrent: {}\nStored: {}".format(model_kwargs,stored_kwargs)
        else:
            print("Building model from scratch")
            m = model.Model(**model_kwargs)
            print("Saving model to cache")
            sys.setrecursionlimit(100000)
            pickle.dump((m,model_kwargs), open(model_filename,'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    elif unpickle_model is not None:
        print("Unpickling model...")
        m = pickle.load(open(unpickle_model, 'rb'))
    else:
        m = model.Model(**model_kwargs)
        if pickle_model is not None:
            sys.setrecursionlimit(100000)
            print("Pickling model...")
            pickle.dump(m, open(pickle_model,'wb'), protocol=pickle.HIGHEST_PROTOCOL)
    if just_compile:
        return
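    # Training setup: optionally override the learning rate, create the output
    # directory, and resume from previously saved parameters if requested.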
    if learning_rate is not None:
        m.set_learning_rate(learning_rate)
    if not os.path.exists(outputdir):
        os.makedirs(outputdir)
    if resume_auto:
        result = find_recent_params(outputdir)
        if result is not None:
            start_idx, paramfile = result
            print("Automatically resuming from {} after iteration {}.".format(paramfile, start_idx))
            resume = result
        else:
            print("Didn't find anything to resume. Starting from the beginning...")
    if resume is not None:
        start_idx, paramfile = resume
        start_idx = int(start_idx)
        load_params(m.params, open(paramfile, "rb"))
    else:
        start_idx = 0
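    # Dispatch on mode: visualization, accuracy evaluation, visualization test
    # (using the correct graphs), or training.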
    if visualize is not False:
        if visualize is True:
            source = bucketed
        else:
            bucket, story = visualize
            source = [[bucketed[bucket][story]]]
        print("Starting to visualize...")
        ggtnn_train.visualize(m, source, wordlist, eff_anslist, output_format, outputdir, snap=visualize_snap)
        print("Wrote visualization files to {}.".format(outputdir))
    elif evaluate_accuracy:
        print("Evaluating accuracy...")
        acc = ggtnn_train.test_accuracy(m, bucketed, bucket_sizes, len(eff_anslist), output_format, batch_size, batch_adjust, (not train_with_query))
        print("Obtained accuracy of {}".format(acc))
    elif visualization_test:
        print("Starting visualization test...")
        ggtnn_train.visualize(m, bucketed, wordlist, eff_anslist, output_format, outputdir, debugmode=True)
        print("Wrote visualization files to {}.".format(outputdir))
    else:
        print("Starting to train...")
        status = ggtnn_train.train(m, bucketed, bucket_sizes, len(eff_anslist), output_format, num_updates, outputdir, start_idx, batch_size, validation_buckets, validation_bucket_sizes, stop_at_accuracy, stop_at_loss, stop_at_overfitting, train_save_params, validation_interval, batch_adjust, interrupt_file)
        if set_exit_status:
            sys.exit(status.value)

parser = argparse.ArgumentParser(description='Train a graph memory network model.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('task_dir', help="Parsed directory for the task to load")
parser.add_argument('output_format_str', choices=[x.name for x in model.ModelOutputFormat], help="Output format for the task")
parser.add_argument('state_width', type=int, help="Width of node state")
parser.add_argument('--process-repr-size', type=int, default=50, help="Width of intermediate representations")
parser.add_argument('--mutable-nodes', action="store_true", help="Make nodes mutable")
parser.add_argument('--wipe-node-state', action="store_true", help="Wipe node state before the query")
parser.add_argument('--direct-reference', action="store_true", help="Use direct reference for input, based on node names")
parser.add_argument('--dynamic-nodes', action="store_true", help="Create nodes after each sentence. (Otherwise, create unique nodes at the beginning)")
parser.add_argument('--propagate-intermediate', action="store_true", help="Run a propagation step after each sentence")
parser.add_argument('--sequence-aggregate-repr', action="store_true", help="Compute the query aggregate representation from the sequence of graphs instead of just the last one")
parser.add_argument('--old-aggregate', action="store_true", help="Use the old, incorrect aggregate function")
parser.add_argument('--no-graph', dest='train_with_graph', action="store_false", help="Don't train using graph supervision")
parser.add_argument('--no-query', dest='train_with_query', action="store_false", help="Don't train using query supervision")
parser.add_argument('--outputdir', default="output", help="Directory to save output in")
parser.add_argument('--num-updates', default=10000, type=int, help="How many iterations to train")
parser.add_argument('--batch-size', default=10, type=int, help="Batch size to use")
parser.add_argument('--learning-rate', type=float, default=None, help="Use this learning rate")
parser.add_argument('--dropout-keep', default=1, type=float, help="Use dropout, with this keep chance")
parser.add_argument('--restrict-dataset', metavar="NUM_STORIES", type=int, default=None, help="Restrict size of dataset to this")
parser.add_argument('--save-params-interval', type=int, default=1000, dest="train_save_params", help="Save parameters after this many iterations")
parser.add_argument('--final-params-only', action="store_const", const=None, dest="train_save_params", help="Don't save parameters while training, only at the end.")
parser.add_argument('--validation', metavar="VALIDATION_DIR", default=None, help="Parsed directory of validation tasks")
parser.add_argument('--validation-interval', type=int, default=1000, help="Check validation after this many iterations")
parser.add_argument('--check-nan', dest="check_mode", action="store_const", const="nan", help="Check for NaN. Slows execution")
parser.add_argument('--check-debug', dest="check_mode", action="store_const", const="debug", help="Debug mode. Slows execution")
parser.add_argument('--visualize', nargs="?", const=True, default=False, metavar="BUCKET,STORY", type=lambda s:[int(x) for x in s.split(',')], help="Visualize current state instead of training. Optional parameter selects a particular story to visualize, and should be of the form bucketnum,index")
parser.add_argument('--visualize-snap', action="store_true", help="In visualization mode, snap to best option at each timestep")
parser.add_argument('--visualization-test', action="store_true", help="Like visualize, but use the correct graph instead of the model's graph")
parser.add_argument('--evaluate-accuracy', action="store_true", help="Evaluate accuracy of model")
parser.add_argument('--stop-at-accuracy', type=float, default=None, help="Stop training once it reaches this accuracy on validation set")
parser.add_argument('--stop-at-loss', type=float, default=None, help="Stop training once it reaches this loss on validation set")
parser.add_argument('--stop-at-overfitting', type=float, default=None, help="Stop training once validation loss is this many times higher than train loss")
parser.add_argument('--batch-adjust', type=int, default=None, help="If set, ensure that size of edge matrix does not exceed this")
parser.add_argument('--set-exit-status', action="store_true", help="Give info about training status in the exit status")
parser.add_argument('--just-compile', action="store_true", help="Don't run the model, just compile it")
parser.add_argument('--autopickle', metavar="PICKLEDIR", default=None, help="Automatically cache model in this directory")
parser.add_argument('--pickle-model', metavar="MODELFILE", default=None, help="Save the compiled model to a file")
parser.add_argument('--unpickle-model', metavar="MODELFILE", default=None, help="Load the model from a file instead of compiling it from scratch")
parser.add_argument('--interrupt-file', default=None, help="Interrupt training if this file appears")
resume_group = parser.add_mutually_exclusive_group()
resume_group.add_argument('--resume', nargs=2, metavar=('TIMESTEP', 'PARAMFILE'), default=None, help='Where to restore from: timestep, and file to load')
resume_group.add_argument('--resume-auto', action='store_true', help='Automatically restore from a previous run using output directory')
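
# Example invocation (a sketch; the positional values are placeholders, and the
# valid output formats are whatever names model.ModelOutputFormat defines):
#   python main.py <parsed_task_dir> <output_format> 100 \
#       --dynamic-nodes --direct-reference --mutable-nodes --outputdir output
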
if __name__ == '__main__':
    np.set_printoptions(linewidth=shutil.get_terminal_size((80, 20)).columns)
    args = vars(parser.parse_args())
    main(**args)