Skip to content

Commit c8ac1e4

Browse files
authored
Merge pull request #51 from zhengjxu/master
The dropout probability should differ between training and inference
2 parents 909a8b7 + 43d67e9 commit c8ac1e4

File tree

1 file changed

+10
-3
lines changed

1 file changed

+10
-3
lines changed

10 - RNN/02 - Autocomplete.py

+10-3
Original file line number | Diff line number | Diff line change
@@ -70,13 +70,16 @@ def make_batch(seq_data):
7070
# 기존처럼 one-hot 인코딩을 사용한다면 입력값의 형태는 [None, n_class] 여야합니다.
7171
Y = tf.placeholder(tf.int32, [None])
7272

73+
# dropout prob for RNN
74+
keep_prob = tf.placeholder(tf.float32, [])
75+
7376
W = tf.Variable(tf.random_normal([n_hidden, n_class]))
7477
b = tf.Variable(tf.random_normal([n_class]))
7578

7679
# RNN 셀을 생성합니다.
7780
cell1 = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
7881
# 과적합 방지를 위한 Dropout 기법을 사용합니다.
79-
cell1 = tf.nn.rnn_cell.DropoutWrapper(cell1, output_keep_prob=0.5)
82+
cell1 = tf.nn.rnn_cell.DropoutWrapper(cell1, output_keep_prob=keep_prob)
8083
# 여러개의 셀을 조합해서 사용하기 위해 셀을 추가로 생성합니다.
8184
cell2 = tf.nn.rnn_cell.BasicLSTMCell(n_hidden)
8285

@@ -108,7 +111,9 @@ def make_batch(seq_data):
108111

109112
for epoch in range(total_epoch):
110113
_, loss = sess.run([optimizer, cost],
111-
feed_dict={X: input_batch, Y: target_batch})
114+
feed_dict={X: input_batch,
115+
Y: target_batch,
116+
keep_prob: 0.5})
112117

113118
print('Epoch:', '%04d' % (epoch + 1),
114119
'cost =', '{:.6f}'.format(loss))
@@ -127,7 +132,9 @@ def make_batch(seq_data):
127132
input_batch, target_batch = make_batch(seq_data)
128133

129134
predict, accuracy_val = sess.run([prediction, accuracy],
130-
feed_dict={X: input_batch, Y: target_batch})
135+
feed_dict={X: input_batch,
136+
Y: target_batch,
137+
keep_prob:1})
131138

132139
predict_words = []
133140
for idx, val in enumerate(seq_data):

0 commit comments

Comments (0)