
Commit 4429600

two new networks and some refactoring
1 parent b7323bb commit 4429600

File tree: 7 files changed (+451, -99 lines)

.gitignore (+1, -2)

@@ -5,8 +5,7 @@ old*
 states/*
 logs/*
 nesterov*
-change_19_task.py
 *debug*
 *.state
 *.sh
-for5*
+*deploy*

dmn_basic.py (+35, -23)

@@ -20,8 +20,9 @@ class DMN_basic:
 
     def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size,
                 dim, mode, answer_module, input_mask_mode, memory_hops, l2,
-                normalize_attention):
-
+                normalize_attention, **kwargs):
+
+        print "==> not used params in DMN class:", kwargs.keys()
         self.vocab = {}
         self.ivocab = {}
 
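
Accepting **kwargs lets every network class be built from one shared, flat configuration dict: each class picks out the parameters it understands and merely reports the keys it ignores. A minimal standalone sketch of that pattern (the ToyNetwork class and config dict below are illustrative, not part of this repository):

class ToyNetwork(object):
    def __init__(self, dim, memory_hops, **kwargs):
        # unknown keys are reported rather than rejected
        print("==> not used params in ToyNetwork: %s" % sorted(kwargs.keys()))
        self.dim = dim
        self.memory_hops = memory_hops

# one flat config dict can be handed to several network classes unchanged
config = {"dim": 40, "memory_hops": 5, "batch_size": 10}
net = ToyNetwork(**config)   # reports that batch_size went unused
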

@@ -85,11 +86,11 @@ def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size,
         self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
 
         self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
-        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
+        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 2))
         self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
         self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
         self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))
-
+
 
         print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
         memory = [self.q_q.copy()]

@@ -147,8 +148,8 @@ def answer_step(prev_a, prev_y):
                        self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                        self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                        self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
-                       self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, #self.W_b
-                       self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
+                       self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
+                       self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
 
         if self.answer_module == 'recurrent':
             self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
@@ -157,9 +158,7 @@ def answer_step(prev_a, prev_y):
 
 
         print "==> building loss layer and computing updates"
-        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0),
-                                                       T.stack([self.answer_var]))[0]
-
+        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]
         if self.l2 > 0:
             self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
         else:
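
The loss for a single example reuses the built-in batched cross-entropy by promoting the prediction vector to a batch of one with dimshuffle('x', 0) and wrapping the scalar answer with T.stack. A standalone sketch of that pattern with toy shapes (assumes a working Theano install; variable names are illustrative):

import numpy as np
import theano
import theano.tensor as T

prediction = T.vector('prediction')   # softmax output for one example, shape (num_classes,)
answer = T.iscalar('answer')          # gold class index

# add a leading batch axis of size 1, then take element 0 of the per-example losses
loss_ce = T.nnet.categorical_crossentropy(prediction.dimshuffle('x', 0),
                                          T.stack([answer]))[0]
loss_fn = theano.function([prediction, answer], loss_ce)

probs = np.array([0.1, 0.7, 0.2], dtype=theano.config.floatX)
print(loss_fn(probs, 1))   # -log(0.7), roughly 0.357
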
@@ -172,14 +171,19 @@ def answer_step(prev_a, prev_y):
         if self.mode == 'train':
             print "==> compiling train_fn"
             self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
-                               outputs=[self.prediction, self.loss],
-                               updates=updates)
+                                            outputs=[self.prediction, self.loss],
+                                            updates=updates)
 
         print "==> compiling test_fn"
         self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
-                          outputs=[self.prediction, self.loss, self.inp_c,
-                                   self.q_q, last_mem])
+                                       outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])
+
 
+        if self.mode == 'train':
+            print "==> computing gradients (for debugging)"
+            gradient = T.grad(self.loss, self.params)
+            self.get_gradient_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var], outputs=gradient)
+
 
     def GRU_update(self, h, x, W_res_in, W_res_hid, b_res,
                    W_upd_in, W_upd_hid, b_upd,
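
get_gradient_fn is compiled without updates, so the raw gradient arrays can be inspected (for example, for NaNs) before train_fn applies a step. A minimal standalone sketch of compiling such a debugging function against a toy loss (assumes a working Theano install; not the DMN's actual graph):

import numpy as np
import theano
import theano.tensor as T

x = T.vector('x')
W = theano.shared(np.ones((2, 3), dtype=theano.config.floatX), name='W')
loss = T.sum(T.dot(W, x) ** 2)              # toy loss standing in for the DMN loss

gradient = T.grad(loss, [W])                # symbolic gradients w.r.t. the parameter list
get_gradient_fn = theano.function([x], gradient)   # inspection only, no updates

grads = get_gradient_fn(np.array([1.0, 2.0, 3.0], dtype=theano.config.floatX))
print([g.shape for g in grads])             # [(2, 3)]
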
@@ -205,12 +209,12 @@ def input_gru_step(self, x, prev_h):
         return self.GRU_update(prev_h, x, self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
                                self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                                self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid)
-
+
 
     def new_attention_step(self, ct, prev_g, mem, q_q):
-        #cWq = T.stack([T.dot(T.dot(ct, self.W_b), q_q)])
-        #cWm = T.stack([T.dot(T.dot(ct, self.W_b), mem)])
-        z = T.concatenate([ct, mem, q_q, ct * q_q, ct * mem, (ct - q_q) ** 2, (ct - mem) ** 2])#, cWq, cWm])
+        cWq = T.stack([T.dot(T.dot(ct, self.W_b), q_q)])
+        cWm = T.stack([T.dot(T.dot(ct, self.W_b), mem)])
+        z = T.concatenate([ct, mem, q_q, ct * q_q, ct * mem, T.abs_(ct - q_q), T.abs_(ct - mem), cWq, cWm])
 
         l_1 = T.dot(self.W_1, z) + self.b_1
         l_1 = T.tanh(l_1)
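
Re-enabling the bilinear terms cWq and cWm makes the gate input z a concatenation of seven dim-sized pieces plus two scalar similarities, which is exactly why W_1 above changes from 7 * self.dim + 0 to 7 * self.dim + 2 columns. A numpy sketch of the same feature construction (toy dim and random values, only to check the shape):

import numpy as np

dim = 4
ct, mem, q_q = (np.random.randn(dim) for _ in range(3))   # fact, memory, question vectors
W_b = np.random.randn(dim, dim)

cWq = np.array([ct.dot(W_b).dot(q_q)])   # scalar: fact-question bilinear similarity
cWm = np.array([ct.dot(W_b).dot(mem)])   # scalar: fact-memory bilinear similarity

z = np.concatenate([ct, mem, q_q, ct * q_q, ct * mem,
                    np.abs(ct - q_q), np.abs(ct - mem), cWq, cWm])
assert z.shape == (7 * dim + 2,)          # matches W_1's second dimension
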
@@ -318,14 +322,14 @@ def get_batches_per_epoch(self, mode):
             return len(self.test_input)
         else:
             raise Exception("unknown mode")
-
+
 
     def shuffle_train_set(self):
         print "==> Shuffling the train set"
         combined = zip(self.train_input, self.train_q, self.train_answer, self.train_input_mask)
         random.shuffle(combined)
         self.train_input, self.train_q, self.train_answer, self.train_input_mask = zip(*combined)
-
+
 
     def step(self, batch_index, mode):
         if mode == "train" and self.mode == "test":
@@ -354,18 +358,26 @@ def step(self, batch_index, mode):
         skipped = 0
         grad_norm = float('NaN')
 
+        if mode == 'train':
+            gradient_value = self.get_gradient_fn(inp, q, ans, input_mask)
+            grad_norm = np.max([utils.get_norm(x) for x in gradient_value])
+
+            if (np.isnan(grad_norm)):
+                print "==> gradient is nan at index %d." % batch_index
+                print "==> skipping"
+                skipped = 1
+
         if skipped == 0:
             ret = theano_fn(inp, q, ans, input_mask)
         else:
-            ret = [float('NaN'), float('NaN')]
+            ret = [-1, -1]
 
         param_norm = np.max([utils.get_norm(x.get_value()) for x in self.params])
 
         return {"prediction": np.array([ret[0]]),
                 "answers": np.array([ans]),
                 "current_loss": ret[1],
                 "skipped": skipped,
-                "grad_norm": grad_norm,
-                "param_norm": param_norm,
-                "log": "",
+                "log": "pn: %.3f \t gn: %.3f" % (param_norm, grad_norm)
                 }
+
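
step() now computes the gradients first, takes the largest per-parameter norm, and skips the update whenever that norm is NaN, returning [-1, -1] as a placeholder result. A small numpy sketch of the check (np.linalg.norm stands in for utils.get_norm, whose definition is not part of this diff):

import numpy as np

def should_skip(gradient_value):
    """Return (skip, grad_norm): skip the update if any gradient norm is NaN."""
    grad_norm = np.max([np.linalg.norm(g) for g in gradient_value])
    return bool(np.isnan(grad_norm)), grad_norm

good = [np.ones((2, 2)), np.zeros(3)]
bad = [np.ones((2, 2)), np.array([1.0, np.nan])]
print(should_skip(good))   # (False, 2.0)
print(should_skip(bad))    # (True, nan)
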

dmn_batch.py (+13, -18)

@@ -19,7 +19,9 @@ class DMN_batch:
 
     def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size, dim,
                 mode, answer_module, input_mask_mode, memory_hops, batch_size, l2,
-                normalize_attention):
+                normalize_attention, **kwargs):
+
+        print "==> not used params in DMN class:", kwargs.keys()
 
         self.vocab = {}
         self.ivocab = {}

@@ -186,13 +188,15 @@ def answer_step(prev_a, prev_y):
 
         if self.mode == 'train':
             print "==> compiling train_fn"
-            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.fact_count_var, self.input_mask_var],
-                               outputs=[self.prediction, self.loss],
-                               updates=updates)
+            self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
+                                                    self.fact_count_var, self.input_mask_var],
+                                            outputs=[self.prediction, self.loss],
+                                            updates=updates)
 
         print "==> compiling test_fn"
-        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.fact_count_var, self.input_mask_var],
-                          outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])
+        self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var,
+                                               self.fact_count_var, self.input_mask_var],
+                                       outputs=[self.prediction, self.loss])
 
 
 

@@ -421,22 +425,13 @@ def step(self, batch_index, mode):
         fact_count = fact_counts[start_index:start_index+self.batch_size]
         input_mask = input_masks[start_index:start_index+self.batch_size]
 
-        skipped = 0
-        grad_norm = float('NaN')
-
-        if skipped == 0:
-            ret = theano_fn(inp, q, ans, fact_count, input_mask)
-        else:
-            ret = [float('NaN'), float('NaN')]
-
+        ret = theano_fn(inp, q, ans, fact_count, input_mask)
         param_norm = np.max([utils.get_norm(x.get_value()) for x in self.params])
 
         return {"prediction": ret[0],
                 "answers": ans,
                 "current_loss": ret[1],
-                "skipped": skipped,
-                "grad_norm": grad_norm,
-                "param_norm": param_norm,
-                "log": "",
+                "skipped": 0,
+                "log": "pn: %.3f" % param_norm,
                 }
 
dmn_qa.py renamed to dmn_qa_draft.py (+3, -4)

@@ -17,8 +17,9 @@
 class DMN_qa:
 
     def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size,
-                dim, mode, input_mask_mode, memory_hops, l2, normalize_attention):
+                dim, mode, input_mask_mode, memory_hops, l2, normalize_attention, **kwargs):
 
+        print "==> not used params in DMN class:", kwargs.keys()
         self.vocab = {}
         self.ivocab = {}
 

@@ -423,9 +424,7 @@ def step(self, batch_index, mode):
                 "answers": np.array([ans]),
                 "current_loss": ret[1],
                 "skipped": skipped,
-                "grad_norm": grad_norm,
-                "param_norm": param_norm,
-                "log": "",
+                "log": "pn: %.3f \t gn: %.3f" % (param_norm, grad_norm)
                 }
 
 