This repository was archived by the owner on Nov 16, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 920
/
Copy pathsentence_selection.py
131 lines (106 loc) · 4.42 KB
/
sentence_selection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# This script reuses some code from https://github.com/nlpyang/BertSum
import itertools
import re
def _get_ngrams(n, text):
"""Calcualtes n-grams.
Args:
n: which n-grams to calculate
text: An array of tokens
Returns:
A set of n-grams
"""
ngram_set = set()
text_length = len(text)
max_index_ngram_start = text_length - n
for i in range(max_index_ngram_start + 1):
ngram_set.add(tuple(text[i:i + n]))
return ngram_set
def _get_word_ngrams(n, sentences):
    """Calculate the set of word n-grams over several tokenized sentences.

    The sentences are concatenated into a single token stream before the
    n-grams are taken, so an n-gram may span two adjacent sentences.

    Args:
        n: length of each n-gram; must be positive
        sentences: a non-empty list of sentences, each an iterable of tokens

    Returns:
        The set of n-gram tuples produced by ``_get_ngrams`` for the
        concatenated tokens.

    Raises:
        AssertionError: if ``sentences`` is empty or ``n`` is not positive.
    """
    assert len(sentences) > 0
    assert n > 0

    # Flatten in one linear pass. sum(sentences, []) copies the growing
    # result list once per sentence, which is quadratic in the token count.
    words = list(itertools.chain.from_iterable(sentences))
    return _get_ngrams(n, words)
def cal_rouge(evaluated_ngrams, reference_ngrams):
    """Score a candidate n-gram set against a reference n-gram set.

    Args:
        evaluated_ngrams: set of n-grams taken from the candidate text
        reference_ngrams: n-grams taken from the reference text

    Returns:
        A dict with the keys "f", "p" and "r":
        "p" is the share of candidate n-grams also found in the reference,
        "r" is the share of reference n-grams also found in the candidate,
        and "f" combines the two as 2 * p * r / (p + r + 1e-8).
        "p" (resp. "r") is 0.0 when the candidate (resp. reference) is empty.
    """
    n_evaluated = len(evaluated_ngrams)
    n_reference = len(reference_ngrams)
    n_shared = len(evaluated_ngrams.intersection(reference_ngrams))

    precision = n_shared / n_evaluated if n_evaluated != 0 else 0.0
    recall = n_shared / n_reference if n_reference != 0 else 0.0

    # The 1e-8 term keeps the division defined when precision and recall
    # are both zero.
    f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
    return {"f": f1_score, "p": precision, "r": recall}
def combination_selection(doc_sent_list, abstract_sent_list, summary_size):
    """Exhaustively pick the sentence combination that best matches the abstract.

    Every combination of 1 to ``summary_size`` document sentences is scored
    as the sum of its unigram and bigram F-scores (see ``cal_rouge``) against
    the abstract, and the highest-scoring combination wins. Ties keep the
    combination that was enumerated first.

    Args:
        doc_sent_list: document sentences, each a list of token strings
        abstract_sent_list: abstract sentences, each a list of token strings
        summary_size: maximum number of sentences to select

    Returns:
        A sorted list of indices into ``doc_sent_list``. The list is empty
        when no combination scores above zero (including an empty document
        or a non-positive ``summary_size``).
    """
    def _rouge_clean(s):
        # Keep only ASCII letters, digits and spaces before tokenizing.
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    abstract_tokens = itertools.chain.from_iterable(abstract_sent_list)
    abstract = _rouge_clean(' '.join(abstract_tokens)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]

    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    max_rouge = 0.0
    # Start from "nothing selected": a placeholder such as (0, 0) would leak
    # out as a duplicated, possibly invalid index when nothing scores.
    max_idx = ()
    # Sentences that score zero on their own; they are left out of every
    # larger combination to shrink the search space.
    impossible_sents = set()

    # Combination sizes run from 1 up to and including summary_size.
    for size in range(1, summary_size + 1):
        usable = [i for i in range(len(sents)) if i not in impossible_sents]
        for c in itertools.combinations(usable, size):
            candidates_1 = set().union(*(evaluated_1grams[idx] for idx in c))
            candidates_2 = set().union(*(evaluated_2grams[idx] for idx in c))
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if size == 1 and rouge_score == 0:
                impossible_sents.add(c[0])
            if rouge_score > max_rouge:
                max_idx = c
                max_rouge = rouge_score
    return sorted(max_idx)
def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
    """Greedily pick document sentences that best match the abstract.

    In each of at most ``summary_size`` rounds, the sentence whose addition
    yields the highest combined score (unigram F-score plus bigram F-score
    from ``cal_rouge``) is added to the selection. Selection stops early as
    soon as no remaining sentence strictly improves the score.

    Args:
        doc_sent_list: document sentences, each a list of token strings
        abstract_sent_list: abstract sentences, each a list of token strings
        summary_size: maximum number of sentences to select

    Returns:
        A sorted list of indices into ``doc_sent_list``; empty when no
        sentence scores above zero.
    """
    def _rouge_clean(s):
        # Keep only ASCII letters, digits and spaces before tokenizing.
        return re.sub(r'[^a-zA-Z0-9 ]', '', s)

    abstract_tokens = itertools.chain.from_iterable(abstract_sent_list)
    abstract = _rouge_clean(' '.join(abstract_tokens)).split()
    sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]

    evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
    reference_1grams = _get_word_ngrams(1, [abstract])
    evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
    reference_2grams = _get_word_ngrams(2, [abstract])

    max_rouge = 0.0
    selected = []
    # Running unions of the n-grams of the sentences chosen so far, so each
    # candidate costs one set union instead of rebuilding the whole union.
    selected_1grams = set()
    selected_2grams = set()

    for _ in range(summary_size):
        cur_max_rouge = max_rouge
        cur_id = -1
        for i in range(len(sents)):
            if i in selected:
                continue
            candidates_1 = selected_1grams | evaluated_1grams[i]
            candidates_2 = selected_2grams | evaluated_2grams[i]
            rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
            rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
            rouge_score = rouge_1 + rouge_2
            if rouge_score > cur_max_rouge:
                cur_max_rouge = rouge_score
                cur_id = i
        if cur_id == -1:
            # No sentence improves the score. Leave the loop so this path
            # returns sorted indices too, not selection order.
            break
        selected.append(cur_id)
        selected_1grams |= evaluated_1grams[cur_id]
        selected_2grams |= evaluated_2grams[cur_id]
        max_rouge = cur_max_rouge

    return sorted(selected)