import json
import logging
import re
from typing import Optional

from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.utils.model_utils import get_model, prepare_model


@OPERATORS.register_module('extract_qa_mapper')
class ExtractQAMapper(Mapper):
    """
    Mapper to extract question and answer pairs from text samples.
    Recommended model list: [
        'alibaba-pai/pai-llama3-8b-doc2qa',
        'alibaba-pai/pai-baichuan2-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-4b-doc2qa',
        'alibaba-pai/pai-qwen1_5-7b-doc2qa',
        'alibaba-pai/pai-qwen1_5-1b8-doc2qa',
        'alibaba-pai/pai-qwen1_5-0b5-doc2qa'
    ]
    These recommended models are all trained on Chinese data
    and are suitable for Chinese text.
    """

    def __init__(self,
                 hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa',
                 pattern: Optional[str] = None,
                 qa_format: str = 'chatml',
                 *args,
                 **kwargs):
        """
        Initialization method.
        :param hf_model: Hugging Face model ID.
        :param pattern: regular expression pattern to search for within text.
        :param qa_format: Output format of the question and answer pairs.
        :param args: extra args
        :param kwargs: extra args

        The default data format parsed by this interface is as follows:
        Model Input:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
        Model Output:
            蒙古国的首都是乌兰巴托(Ulaanbaatar)
            冰岛的首都是雷克雅未克(Reykjavik)
            Human: 请问蒙古国的首都是哪里?
            Assistant: 你好,根据提供的信息,蒙古国的首都是乌兰巴托(Ulaanbaatar)。
            Human: 冰岛的首都是哪里呢?
            Assistant: 冰岛的首都是雷克雅未克(Reykjavik)。
            ...
        """

        super().__init__(*args, **kwargs)
        self._batched_op = True
        self._accelerator = 'cuda'

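        # The default pattern matches each 'Human: ...\nAssistant: ...' turn;
        # it is compiled with re.DOTALL in _extract_qa, so '.' also spans
        # newlines, and the lookahead ends each answer at the next 'Human'
        # turn or at the end of the output.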
        if pattern is None:
            self.pattern = r'Human: (.*?)\nAssistant: (.*?)(?=\nHuman|$)'
        else:
            self.pattern = pattern

        self.qa_format = qa_format
        self.model_key = prepare_model(model_type='huggingface',
                                       pretrained_model_name_or_path=hf_model)

    def _extract_qa(self, output):
        """Extract question and answer pairs from the model output response."""
        qa_list = []

        pat = re.compile(self.pattern, re.DOTALL)
        qa_pairs = pat.findall(output)
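        # For example (an illustration, not output from the original source),
        # with the default pattern a response containing
        # 'Human: Q1?\nAssistant: A1.\nHuman: Q2?\nAssistant: A2.'
        # yields the pairs [('Q1?', 'A1.'), ('Q2?', 'A2.')] after stripping.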

        for user, assistant in qa_pairs:
            qa_list.append((user.strip(), assistant.strip()))

        return qa_list

    def process(self, sample, rank=None):
        self.model, self.processor = get_model(self.model_key, rank=rank)

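        # Tokenize the raw text, generate with the model's default generation
        # settings, and decode the full output sequence (the echoed input plus
        # the generated QA turns), which _extract_qa then parses.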
        inputs = self.processor(sample[self.text_key],
                                return_tensors='pt').to(self.model.device)
        response = self.model.generate(**inputs)
        output = self.processor.decode(response.cpu()[0],
                                       skip_special_tokens=True)
        qa_list = self._extract_qa(output)

        if not qa_list:
            logging.info(
                'No question and answer data was extracted from this sample!')

        dialogue_data = []
        if self.qa_format == 'chatml':
            for qa in qa_list:
                dialogue_data.append({
                    'messages': [{
                        'role': 'user',
                        'content': qa[0]
                    }, {
                        'role': 'assistant',
                        'content': qa[1]
                    }]
                })
        else:
            raise ValueError(f'Unsupported qa_format: {self.qa_format}!')
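        # Replace the sample's text in place with the serialized dialogue
        # list, e.g. '[{"messages": [{"role": "user", ...}, ...]}]'.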
        sample[self.text_key] = json.dumps(dialogue_data, ensure_ascii=False)

        return sample
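

# A minimal usage sketch (added for illustration, not part of the original
# op). It assumes the sample's text key is 'text' (the Data-Juicer default),
# that a CUDA device is available, and that the doc2qa checkpoint can be
# fetched from the Hugging Face Hub:
if __name__ == '__main__':
    op = ExtractQAMapper(hf_model='alibaba-pai/pai-qwen1_5-0b5-doc2qa')
    sample = {'text': '蒙古国的首都是乌兰巴托(Ulaanbaatar)\n'
                      '冰岛的首都是雷克雅未克(Reykjavik)'}
    print(op.process(sample)['text'])  # a JSON list of ChatML dialogues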