Skip to content

Commit e3edb43

Browse files
committed
minor update
1 parent e08df6f commit e3edb43

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

keras_text_summarization/demo/recursive_rnn_v3_train.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def main():
3131
summarizer = RecursiveRNN3(config)
3232

3333
if LOAD_EXISTING_WEIGHTS:
34-
weight_file_path = RecursiveRNN2.get_weight_file_path(model_dir_path=model_dir_path)
34+
weight_file_path = RecursiveRNN3.get_weight_file_path(model_dir_path=model_dir_path)
3535
summarizer.load_weights(weight_file_path=weight_file_path)
3636

3737
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, test_size=0.2, random_state=42)
@@ -42,7 +42,7 @@ def main():
4242
print('start fitting ...')
4343
history = summarizer.fit(Xtrain, Ytrain, Xtest, Ytest, epochs=20, batch_size=256)
4444

45-
history_plot_file_path = report_dir_path + '/' + RecursiveRNN2.model_name + '-history.png'
45+
history_plot_file_path = report_dir_path + '/' + RecursiveRNN3.model_name + '-history.png'
4646
if LOAD_EXISTING_WEIGHTS:
4747
history_plot_file_path = report_dir_path + '/' + RecursiveRNN2.model_name + '-history-v' + str(summarizer.version) + '.png'
4848
plot_and_save_history(history, summarizer.model_name, history_plot_file_path, metrics={'loss', 'acc'})

0 commit comments

Comments
 (0)