Skip to content

Commit 1632da5

Browse files
committed
predict script for rnn 3
1 parent e3edb43 commit 1632da5

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import print_function
2+
3+
import pandas as pd
4+
from keras_text_summarization.library.rnn import RecursiveRNN3
5+
import numpy as np
6+
7+
8+
def main():
9+
np.random.seed(42)
10+
data_dir_path = './data'
11+
model_dir_path = './models'
12+
13+
print('loading csv file ...')
14+
df = pd.read_csv(data_dir_path + "/fake_or_real_news.csv")
15+
# df = df.loc[df.index < 1000]
16+
X = df['text']
17+
Y = df.title
18+
19+
config = np.load(RecursiveRNN3.get_config_file_path(model_dir_path=model_dir_path)).item()
20+
21+
summarizer = RecursiveRNN3(config)
22+
summarizer.load_weights(weight_file_path=RecursiveRNN3.get_weight_file_path(model_dir_path=model_dir_path))
23+
24+
print('start predicting ...')
25+
for i in np.random.permutation(np.arange(len(X)))[0:20]:
26+
x = X[i]
27+
actual_headline = Y[i]
28+
headline = summarizer.summarize(x)
29+
# print('Article: ', x)
30+
print('Generated Headline: ', headline)
31+
print('Original Headline: ', actual_headline)
32+
33+
34+
if __name__ == '__main__':
35+
main()

keras_text_summarization/demo/recursive_rnn_v3_train.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def main():
4444

4545
history_plot_file_path = report_dir_path + '/' + RecursiveRNN3.model_name + '-history.png'
4646
if LOAD_EXISTING_WEIGHTS:
47-
history_plot_file_path = report_dir_path + '/' + RecursiveRNN2.model_name + '-history-v' + str(summarizer.version) + '.png'
47+
history_plot_file_path = report_dir_path + '/' + RecursiveRNN3.model_name + '-history-v' + str(summarizer.version) + '.png'
4848
plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'})
4949

5050

0 commit comments

Comments
 (0)