@@ -20,8 +20,9 @@ class DMN_basic:
 
     def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size,
                  dim, mode, answer_module, input_mask_mode, memory_hops, l2,
-                 normalize_attention):
-
+                 normalize_attention, **kwargs):
+
+        print "==> not used params in DMN class:", kwargs.keys()
         self.vocab = {}
         self.ivocab = {}
 
@@ -85,11 +86,11 @@ def __init__(self, babi_train_raw, babi_test_raw, word2vec, word_vector_size,
         self.b_mem_hid = nn_utils.constant_param(value=0.0, shape=(self.dim,))
 
         self.W_b = nn_utils.normal_param(std=0.1, shape=(self.dim, self.dim))
-        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 0))
+        self.W_1 = nn_utils.normal_param(std=0.1, shape=(self.dim, 7 * self.dim + 2))
         self.W_2 = nn_utils.normal_param(std=0.1, shape=(1, self.dim))
         self.b_1 = nn_utils.constant_param(value=0.0, shape=(self.dim,))
         self.b_2 = nn_utils.constant_param(value=0.0, shape=(1,))
-
+
 
         print "==> building episodic memory module (fixed number of steps: %d)" % self.memory_hops
         memory = [self.q_q.copy()]
@@ -147,8 +148,8 @@ def answer_step(prev_a, prev_y):
                        self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid,
                        self.W_mem_res_in, self.W_mem_res_hid, self.b_mem_res,
                        self.W_mem_upd_in, self.W_mem_upd_hid, self.b_mem_upd,
-                       self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid, #self.W_b
-                       self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
+                       self.W_mem_hid_in, self.W_mem_hid_hid, self.b_mem_hid,
+                       self.W_b, self.W_1, self.W_2, self.b_1, self.b_2, self.W_a]
 
         if self.answer_module == 'recurrent':
             self.params = self.params + [self.W_ans_res_in, self.W_ans_res_hid, self.b_ans_res,
@@ -157,9 +158,7 @@ def answer_step(prev_a, prev_y):
 
 
         print "==> building loss layer and computing updates"
-        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0),
-                                                       T.stack([self.answer_var]))[0]
-
+        self.loss_ce = T.nnet.categorical_crossentropy(self.prediction.dimshuffle('x', 0), T.stack([self.answer_var]))[0]
         if self.l2 > 0:
             self.loss_l2 = self.l2 * nn_utils.l2_reg(self.params)
         else:
@@ -172,14 +171,19 @@ def answer_step(prev_a, prev_y):
         if self.mode == 'train':
             print "==> compiling train_fn"
             self.train_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
-                                            outputs=[self.prediction, self.loss],
-                                            updates=updates)
+                                        outputs=[self.prediction, self.loss],
+                                        updates=updates)
 
         print "==> compiling test_fn"
         self.test_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var],
-                                       outputs=[self.prediction, self.loss, self.inp_c,
-                                                self.q_q, last_mem])
+                                       outputs=[self.prediction, self.loss, self.inp_c, self.q_q, last_mem])
+
 
 
+        if self.mode == 'train':
+            print "==> computing gradients (for debugging)"
+            gradient = T.grad(self.loss, self.params)
+            self.get_gradient_fn = theano.function(inputs=[self.input_var, self.q_var, self.answer_var, self.input_mask_var], outputs=gradient)
+
     def GRU_update(self, h, x, W_res_in, W_res_hid, b_res,
                    W_upd_in, W_upd_hid, b_upd,
@@ -205,12 +209,12 @@ def input_gru_step(self, x, prev_h):
         return self.GRU_update(prev_h, x, self.W_inp_res_in, self.W_inp_res_hid, self.b_inp_res,
                                self.W_inp_upd_in, self.W_inp_upd_hid, self.b_inp_upd,
                                self.W_inp_hid_in, self.W_inp_hid_hid, self.b_inp_hid)
-
+
 
     def new_attention_step(self, ct, prev_g, mem, q_q):
-        # cWq = T.stack([T.dot(T.dot(ct, self.W_b), q_q)])
-        # cWm = T.stack([T.dot(T.dot(ct, self.W_b), mem)])
-        z = T.concatenate([ct, mem, q_q, ct * q_q, ct * mem, (ct - q_q) ** 2, (ct - mem) ** 2])  # , cWq, cWm])
+        cWq = T.stack([T.dot(T.dot(ct, self.W_b), q_q)])
+        cWm = T.stack([T.dot(T.dot(ct, self.W_b), mem)])
+        z = T.concatenate([ct, mem, q_q, ct * q_q, ct * mem, T.abs_(ct - q_q), T.abs_(ct - mem), cWq, cWm])
 
         l_1 = T.dot(self.W_1, z) + self.b_1
         l_1 = T.tanh(l_1)
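
Note: with cWq and cWm re-enabled, the attention feature vector z gains two scalar entries on top of its seven dim-sized pieces, which is why W_1 grows from shape (dim, 7 * dim + 0) to (dim, 7 * dim + 2) in the earlier hunk. A minimal NumPy sketch of that shape bookkeeping (illustration only, not the Theano graph; dim = 40 is an arbitrary example size):

    import numpy as np

    dim = 40                                # example hidden size
    ct, mem, q_q = (np.random.randn(dim) for _ in range(3))
    W_b = np.random.randn(dim, dim)

    cWq = np.array([ct.dot(W_b).dot(q_q)])  # bilinear similarity, length-1 vector
    cWm = np.array([ct.dot(W_b).dot(mem)])
    z = np.concatenate([ct, mem, q_q, ct * q_q, ct * mem,
                        np.abs(ct - q_q), np.abs(ct - mem), cWq, cWm])
    assert z.shape == (7 * dim + 2,)        # matches the new second dimension of W_1
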
@@ -318,14 +322,14 @@ def get_batches_per_epoch(self, mode):
             return len(self.test_input)
         else:
             raise Exception("unknown mode")
-
+
 
     def shuffle_train_set(self):
         print "==> Shuffling the train set"
         combined = zip(self.train_input, self.train_q, self.train_answer, self.train_input_mask)
         random.shuffle(combined)
         self.train_input, self.train_q, self.train_answer, self.train_input_mask = zip(*combined)
-
+
 
     def step(self, batch_index, mode):
         if mode == "train" and self.mode == "test":
@@ -354,18 +358,26 @@ def step(self, batch_index, mode):
         skipped = 0
         grad_norm = float('NaN')
 
+        if mode == 'train':
+            gradient_value = self.get_gradient_fn(inp, q, ans, input_mask)
+            grad_norm = np.max([utils.get_norm(x) for x in gradient_value])
+
+            if (np.isnan(grad_norm)):
+                print "==> gradient is nan at index %d." % batch_index
+                print "==> skipping"
+                skipped = 1
+
         if skipped == 0:
             ret = theano_fn(inp, q, ans, input_mask)
         else:
-            ret = [float('NaN'), float('NaN')]
+            ret = [-1, -1]
 
         param_norm = np.max([utils.get_norm(x.get_value()) for x in self.params])
 
         return {"prediction": np.array([ret[0]]),
                 "answers": np.array([ans]),
                 "current_loss": ret[1],
                 "skipped": skipped,
-                "grad_norm": grad_norm,
-                "param_norm": param_norm,
-                "log": "",
+                "log": "pn: %.3f \t gn: %.3f" % (param_norm, grad_norm)
                 }
+
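
Note: the block added above guards the update step against unstable training by taking the maximum per-parameter gradient norm and skipping the batch when it comes out NaN. A small self-contained sketch of that check (illustration only; get_norm here is a stand-in for utils.get_norm, assumed to be the Euclidean norm of the flattened array):

    import numpy as np

    def get_norm(x):
        # stand-in for utils.get_norm (assumption): L2 norm of the flattened array
        return float(np.linalg.norm(np.asarray(x).ravel()))

    def should_skip(gradient_value):
        grad_norm = np.max([get_norm(g) for g in gradient_value])
        return bool(np.isnan(grad_norm)), grad_norm

    skip, grad_norm = should_skip([np.ones((3, 3)), np.array([np.nan])])
    assert skip  # one NaN gradient is enough to skip the whole batch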