@@ -70,13 +70,16 @@ def make_batch(seq_data):
70
70
# 기존처럼 one-hot 인코딩을 사용한다면 입력값의 형태는 [None, n_class] 여야합니다.
71
71
Y = tf .placeholder (tf .int32 , [None ])
72
72
73
+ # dropout prob for RNN
74
+ keep_prob = tf .placeholder (tf .float32 , [])
75
+
73
76
W = tf .Variable (tf .random_normal ([n_hidden , n_class ]))
74
77
b = tf .Variable (tf .random_normal ([n_class ]))
75
78
76
79
# RNN 셀을 생성합니다.
77
80
cell1 = tf .nn .rnn_cell .BasicLSTMCell (n_hidden )
78
81
# 과적합 방지를 위한 Dropout 기법을 사용합니다.
79
- cell1 = tf .nn .rnn_cell .DropoutWrapper (cell1 , output_keep_prob = 0.5 )
82
+ cell1 = tf .nn .rnn_cell .DropoutWrapper (cell1 , output_keep_prob = keep_prob )
80
83
# 여러개의 셀을 조합해서 사용하기 위해 셀을 추가로 생성합니다.
81
84
cell2 = tf .nn .rnn_cell .BasicLSTMCell (n_hidden )
82
85
@@ -108,7 +111,9 @@ def make_batch(seq_data):
108
111
109
112
for epoch in range (total_epoch ):
110
113
_ , loss = sess .run ([optimizer , cost ],
111
- feed_dict = {X : input_batch , Y : target_batch })
114
+ feed_dict = {X : input_batch ,
115
+ Y : target_batch ,
116
+ keep_prob : 0.5 })
112
117
113
118
print ('Epoch:' , '%04d' % (epoch + 1 ),
114
119
'cost =' , '{:.6f}' .format (loss ))
@@ -127,7 +132,9 @@ def make_batch(seq_data):
127
132
input_batch , target_batch = make_batch (seq_data )
128
133
129
134
predict , accuracy_val = sess .run ([prediction , accuracy ],
130
- feed_dict = {X : input_batch , Y : target_batch })
135
+ feed_dict = {X : input_batch ,
136
+ Y : target_batch ,
137
+ keep_prob :1 })
131
138
132
139
predict_words = []
133
140
for idx , val in enumerate (seq_data ):
0 commit comments