Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Commit 165645b

Browse files
committed
move sentence selection out of bertsum package
1 parent 6c6af56 commit 165645b

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed
+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT License.
3+
4+
# This script reuses some code from https://github.com/nlpyang/BertSum
5+
6+
7+
import itertools
8+
import re
9+
10+
11+
def _get_ngrams(n, text):
12+
"""Calcualtes n-grams.
13+
Args:
14+
n: which n-grams to calculate
15+
text: An array of tokens
16+
Returns:
17+
A set of n-grams
18+
"""
19+
ngram_set = set()
20+
text_length = len(text)
21+
max_index_ngram_start = text_length - n
22+
for i in range(max_index_ngram_start + 1):
23+
ngram_set.add(tuple(text[i:i + n]))
24+
return ngram_set
25+
26+
27+
def _get_word_ngrams(n, sentences):
28+
"""Calculates word n-grams for multiple sentences.
29+
"""
30+
assert len(sentences) > 0
31+
assert n > 0
32+
33+
# words = _split_into_words(sentences)
34+
35+
words = sum(sentences, [])
36+
# words = [w for w in words if w not in stopwords]
37+
return _get_ngrams(n, words)
38+
39+
40+
def cal_rouge(evaluated_ngrams, reference_ngrams):
41+
reference_count = len(reference_ngrams)
42+
evaluated_count = len(evaluated_ngrams)
43+
44+
overlapping_ngrams = evaluated_ngrams.intersection(reference_ngrams)
45+
overlapping_count = len(overlapping_ngrams)
46+
47+
if evaluated_count == 0:
48+
precision = 0.0
49+
else:
50+
precision = overlapping_count / evaluated_count
51+
52+
if reference_count == 0:
53+
recall = 0.0
54+
else:
55+
recall = overlapping_count / reference_count
56+
57+
f1_score = 2.0 * ((precision * recall) / (precision + recall + 1e-8))
58+
return {"f": f1_score, "p": precision, "r": recall}
59+
60+
61+
def combination_selection(doc_sent_list, abstract_sent_list, summary_size):
62+
def _rouge_clean(s):
63+
return re.sub(r'[^a-zA-Z0-9 ]', '', s)
64+
65+
max_rouge = 0.0
66+
max_idx = (0, 0)
67+
abstract = sum(abstract_sent_list, [])
68+
abstract = _rouge_clean(' '.join(abstract)).split()
69+
sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
70+
evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
71+
reference_1grams = _get_word_ngrams(1, [abstract])
72+
evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
73+
reference_2grams = _get_word_ngrams(2, [abstract])
74+
75+
impossible_sents = []
76+
for s in range(summary_size + 1):
77+
combinations = itertools.combinations([i for i in range(len(sents)) if i not in impossible_sents], s + 1)
78+
for c in combinations:
79+
candidates_1 = [evaluated_1grams[idx] for idx in c]
80+
candidates_1 = set.union(*map(set, candidates_1))
81+
candidates_2 = [evaluated_2grams[idx] for idx in c]
82+
candidates_2 = set.union(*map(set, candidates_2))
83+
rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
84+
rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
85+
86+
rouge_score = rouge_1 + rouge_2
87+
if (s == 0 and rouge_score == 0):
88+
impossible_sents.append(c[0])
89+
if rouge_score > max_rouge:
90+
max_idx = c
91+
max_rouge = rouge_score
92+
return sorted(list(max_idx))
93+
94+
95+
def greedy_selection(doc_sent_list, abstract_sent_list, summary_size):
96+
def _rouge_clean(s):
97+
return re.sub(r'[^a-zA-Z0-9 ]', '', s)
98+
99+
max_rouge = 0.0
100+
abstract = sum(abstract_sent_list, [])
101+
abstract = _rouge_clean(' '.join(abstract)).split()
102+
sents = [_rouge_clean(' '.join(s)).split() for s in doc_sent_list]
103+
evaluated_1grams = [_get_word_ngrams(1, [sent]) for sent in sents]
104+
reference_1grams = _get_word_ngrams(1, [abstract])
105+
evaluated_2grams = [_get_word_ngrams(2, [sent]) for sent in sents]
106+
reference_2grams = _get_word_ngrams(2, [abstract])
107+
108+
selected = []
109+
for s in range(summary_size):
110+
cur_max_rouge = max_rouge
111+
cur_id = -1
112+
for i in range(len(sents)):
113+
if (i in selected):
114+
continue
115+
c = selected + [i]
116+
candidates_1 = [evaluated_1grams[idx] for idx in c]
117+
candidates_1 = set.union(*map(set, candidates_1))
118+
candidates_2 = [evaluated_2grams[idx] for idx in c]
119+
candidates_2 = set.union(*map(set, candidates_2))
120+
rouge_1 = cal_rouge(candidates_1, reference_1grams)['f']
121+
rouge_2 = cal_rouge(candidates_2, reference_2grams)['f']
122+
rouge_score = rouge_1 + rouge_2
123+
if rouge_score > cur_max_rouge:
124+
cur_max_rouge = rouge_score
125+
cur_id = i
126+
if (cur_id == -1):
127+
return selected
128+
selected.append(cur_id)
129+
max_rouge = cur_max_rouge
130+
131+
return sorted(selected)

0 commit comments

Comments
 (0)