|
4 | 4 | from __future__ import print_function
|
5 | 5 |
|
6 | 6 | from data_utils import load_task, vectorize_data
|
7 |
| -from sklearn import cross_validation, metrics |
| 7 | +from sklearn import metrics |
| 8 | +from sklearn.model_selection import train_test_split |
8 | 9 | from memn2n_kv import MemN2N_KV
|
9 | 10 | from itertools import chain
|
10 | 11 | from six.moves import range, reduce
|
|
36 | 37 | tf.flags.DEFINE_string("reader", "bow", "Reader for the model")
|
37 | 38 | FLAGS = tf.flags.FLAGS
|
38 | 39 |
|
39 |
| -FLAGS._parse_flags() |
40 | 40 | print("\nParameters:")
|
41 | 41 | with open(FLAGS.param_output_file, 'w') as f:
|
42 | 42 | for attr, value in sorted(FLAGS.__flags.items()):
|
|
80 | 80 | valA = []
|
81 | 81 | for task in train:
|
82 | 82 | S, Q, A = vectorize_data(task, word_idx, sentence_size, memory_size)
|
83 |
| - ts, vs, tq, vq, ta, va = cross_validation.train_test_split(S, Q, A, test_size=0.1, random_state=FLAGS.random_state) |
| 83 | + ts, vs, tq, vq, ta, va = train_test_split(S, Q, A, test_size=0.1, random_state=FLAGS.random_state) |
84 | 84 | trainS.append(ts)
|
85 | 85 | trainQ.append(tq)
|
86 | 86 | trainA.append(ta)
|
|
130 | 130 |
|
131 | 131 | model = MemN2N_KV(batch_size=batch_size, vocab_size=vocab_size,
|
132 | 132 | query_size=sentence_size, story_size=sentence_size, memory_key_size=memory_size,
|
133 |
| - feature_size=FLAGS.feature_size, memory_value_size=memory_size, embedding_size=FLAGS.embedding_size, hops=FLAGS.hops, reader=FLAGS.reader, l2_lambda=FLAGS.l2_lambda) |
| 133 | + feature_size=FLAGS.feature_size, memory_value_size=memory_size, |
| 134 | + embedding_size=FLAGS.embedding_size, hops=FLAGS.hops, reader=FLAGS.reader, |
| 135 | + l2_lambda=FLAGS.l2_lambda) |
134 | 136 | grads_and_vars = optimizer.compute_gradients(model.loss_op)
|
135 | 137 |
|
136 | 138 | grads_and_vars = [(tf.clip_by_norm(g, FLAGS.max_grad_norm), v)
|
|
144 | 146 | nil_grads_and_vars.append((g, v))
|
145 | 147 |
|
146 | 148 | train_op = optimizer.apply_gradients(nil_grads_and_vars, name="train_op", global_step=global_step)
|
147 |
| - sess.run(tf.initialize_all_variables()) |
| 149 | + sess.run(tf.global_variables_initializer()) |
148 | 150 |
|
149 | 151 | def train_step(s, q, a):
|
150 | 152 | feed_dict = {
|
|
0 commit comments