Skip to content

Commit 840fe8f

Browse files
committed
support extract QA operator
1 parent 1c7bdc4 commit 840fe8f

File tree

4 files changed

+151
-1
lines changed

4 files changed

+151
-1
lines changed

configs/config_all.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ process:
6868
- clean_links_mapper: # remove web links from text.
6969
- clean_copyright_mapper: # remove copyright comments.
7070
- expand_macro_mapper: # expand macro definitions in Latex text.
71+
- extract_qa_mapper: # mapper to extract question and answer pairs from text.
72+
hf_model: 'alibaba-pai/pai-qwen1_5-7b-doc2qa'
7173
- fix_unicode_mapper: # fix unicode errors in text.
7274
- image_blur_mapper: # mapper to blur images.
7375
p: 0.2 # probability of the image being blurred

data_juicer/ops/mapper/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from . import (audio_ffmpeg_wrapped_mapper, chinese_convert_mapper,
33
clean_copyright_mapper, clean_email_mapper, clean_html_mapper,
44
clean_ip_mapper, clean_links_mapper, expand_macro_mapper,
5-
fix_unicode_mapper, image_blur_mapper,
5+
extract_qa_mapper, fix_unicode_mapper, image_blur_mapper,
66
image_captioning_from_gpt4v_mapper, image_captioning_mapper,
77
image_diffusion_mapper, image_face_blur_mapper,
88
nlpaug_en_mapper, nlpcda_zh_mapper,
@@ -32,6 +32,7 @@
3232
from .clean_ip_mapper import CleanIpMapper
3333
from .clean_links_mapper import CleanLinksMapper
3434
from .expand_macro_mapper import ExpandMacroMapper
35+
from .extract_qa_mapper import ExtractQAMapper
3536
from .fix_unicode_mapper import FixUnicodeMapper
3637
from .image_blur_mapper import ImageBlurMapper
3738
from .image_captioning_from_gpt4v_mapper import ImageCaptioningFromGPT4VMapper
@@ -102,6 +103,7 @@
102103
'VideoTaggingFromFramesMapper',
103104
'RemoveCommentsMapper',
104105
'ExpandMacroMapper',
106+
'ExtractQAMapper',
105107
'ImageCaptioningMapper',
106108
'RemoveWordsWithIncorrectSubstringsMapper',
107109
'VideoCaptioningFromVideoMapper',
+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import json
2+
import logging
3+
import re
4+
5+
from data_juicer.ops.base_op import OPERATORS, Mapper
6+
from data_juicer.utils.model_utils import get_model, prepare_model
7+
8+
9+
@OPERATORS.register_module('extract_qa_mapper')
class ExtractQAMapper(Mapper):
    """
    Mapper to extract question and answer pairs from text samples.

    The sample text is fed to a Hugging Face model, the generated text is
    scanned with a regular expression for ``Human: ... / Assistant: ...``
    turns, and the sample text is replaced by a JSON string that holds the
    extracted pairs.

    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]
    These recommended models are all trained with Chinese data
    and are suitable for Chinese.
    """

    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 pattern: str = None,
                 qa_format: str = 'chatml',
                 *args,
                 **kwargs):
        """
        Initialization method.

        :param hf_model: Hugging Face model id.
        :param pattern: regular expression pattern to search for within
            the model output. It must define exactly two capturing groups
            (question, answer). If None, a default ``Human:/Assistant:``
            pattern is used.
        :param qa_format: Output format of question and answer pair.
            Only 'chatml' is supported.
        :param args: extra args
        :param kwargs: extra args
        :raises ValueError: if ``qa_format`` is not supported.

        The default data format parsed by this interface is as follows
        (the example is in Chinese, matching the recommended models; the
        two input lines say "The capital of Mongolia is Ulaanbaatar" and
        "The capital of Iceland is Reykjavik"):

        Model Input:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
        Model Output:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
            Human: 请问蒙古国的首都是哪里?
            Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
            Human: 冰岛的首都是哪里呢?
            Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
            ...
        """

        super().__init__(*args, **kwargs)
        # NOTE(review): process() below handles one sample at a time;
        # confirm that flagging this op as batched matches what the base
        # Mapper expects.
        self._batched_op = True
        self._accelerator = 'cuda'

        # Fail fast on an unsupported format, before the model is prepared
        # and long before any (expensive) generation is run.
        if qa_format != 'chatml':
            raise ValueError(f'Not support {qa_format}!')

        if pattern is None:
            self.pattern = r'Human: (.*?)\nAssistant: (.*?)(?=\nHuman|$)'
        else:
            self.pattern = pattern

        self.qa_format = qa_format
        self.model_key = prepare_model(model_type='huggingface',
                                       pretrained_model_name_or_path=hf_model)

    def _extract_qa(self, output):
        """
        Extract question and answer pairs from model output response.

        :param output: decoded text generated by the model.
        :return: list of ``(question, answer)`` tuples with surrounding
            whitespace stripped; empty if nothing matches the pattern.
        """
        # DOTALL lets a question or an answer span several lines.
        matches = re.findall(self.pattern, output, flags=re.DOTALL)
        return [(user.strip(), assistant.strip())
                for user, assistant in matches]

    def process(self, sample, rank=None):
        """
        Generate QA pairs for one sample and store them as a JSON string.

        :param sample: sample dict; the text is read from and written back
            to ``sample[self.text_key]``.
        :param rank: forwarded to ``get_model``.
        :return: the same sample, whose text is now a JSON-encoded list of
            ``{'messages': [user_turn, assistant_turn]}`` entries.
        :raises ValueError: if ``self.qa_format`` is not supported.
        """
        # Re-check here in case the attribute was changed after
        # construction, so no generation is wasted on a bad format.
        if self.qa_format != 'chatml':
            raise ValueError(f'Not support {self.qa_format}!')

        self.model, self.processor = get_model(self.model_key, rank=rank)

        inputs = self.processor(sample[self.text_key],
                                return_tensors='pt').to(self.model.device)
        # NOTE(review): no generation length is passed, so the output
        # length presumably depends on the model's own generation
        # defaults -- confirm they are long enough for full QA pairs.
        response = self.model.generate(**inputs)
        output = self.processor.decode(response.cpu()[0],
                                       skip_special_tokens=True)
        qa_list = self._extract_qa(output)

        if not qa_list:
            logging.info(
                'No question and answer data was extracted from this sample!')

        dialogue_data = [{
            'messages': [{
                'role': 'user',
                'content': question
            }, {
                'role': 'assistant',
                'content': answer
            }]
        } for question, answer in qa_list]

        sample[self.text_key] = json.dumps(dialogue_data, ensure_ascii=False)

        return sample
+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
import json
3+
from data_juicer.ops.mapper.extract_qa_mapper import ExtractQAMapper
4+
from data_juicer.utils.unittest_utils import (SKIPPED_TESTS,
5+
DataJuicerTestCaseBase)
6+
7+
# Skip tests for this OP in the GitHub actions due to disk space limitation.
8+
# These tests have been tested locally.
9+
@SKIPPED_TESTS.register_module()
class ExtractQAMapperTest(DataJuicerTestCaseBase):
    """Smoke test for ExtractQAMapper with the 'chatml' output format."""

    text_key = 'text'

    def _run_extract_qa(self, samples):
        """
        Run the mapper on each sample and check the first extracted pair.

        :param samples: list of sample dicts holding text under
            ``self.text_key``.
        """
        op = ExtractQAMapper(
            hf_model='alibaba-pai/pai-qwen1_5-7b-doc2qa',
            qa_format='chatml'
        )
        for sample in samples:
            result = op.process(sample)
            out_text = json.loads(result[self.text_key])

            # The mapper only logs when nothing is extracted, so guard the
            # indexing below to get a clear failure instead of IndexError.
            self.assertTrue(out_text, 'no QA pair was extracted')

            # test one output qa sample
            qa_sample = out_text[0]
            user_msg, assistant_msg = qa_sample['messages']
            for message in (user_msg, assistant_msg):
                self.assertIn('role', message)
                self.assertIn('content', message)
            self.assertEqual(user_msg['role'], 'user')
            self.assertEqual(assistant_msg['role'], 'assistant')

    def test_extract_qa(self):
        """Extract QA pairs from a short two-sentence Chinese document."""
        samples = [
            {
                self.text_key: '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n冰岛的首都是雷克雅未克(Reykjavik)\n'
            }]
        self._run_extract_qa(samples)
33+
34+
35+
if __name__ == '__main__':
36+
unittest.main()

0 commit comments

Comments
 (0)