test_dataset.py
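"""Build instruction-tuning datasets from JSON files with "instruct"/"query"/"answer"
fields and collate padded batches for supervised fine-tuning. The tokenization assumes
a ChatGLM-style tokenizer (it relies on `gmask_token_id`), matching the tokenizer
loaded in the __main__ block below."""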
import logging
import os
from dataclasses import dataclass
from typing import Dict, Sequence, Union, List

import datasets
import torch
import transformers
from datasets import load_dataset, concatenate_datasets

IGNORE_INDEX = -100

logger = logging.getLogger(__name__)
PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response: "
)
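# For reference, rendering the template with a hypothetical instruction:
#   PROMPT_TEMPLATE.format_map({"instruction": "Summarize the following passage."})
# wraps the text between "### Instruction:" and "### Response: ". Note that the
# tokenization below currently bypasses this template (the call is commented out).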
def buid_instruction_dataset(data_path: Union[List[str], str],
                             tokenizer: transformers.PreTrainedTokenizer,
                             max_seq_length: int,
                             data_cache_dir=None,
                             preprocessing_num_workers=None,
                             ):

    def tokenization(examples):
        sources = []
        targets = []
        # prompt = PROMPT_TEMPLATE
        for instruction, input, output in zip(examples['instruct'], examples['query'], examples['answer']):
            # Fold the optional query text into the instruction.
            if input is not None and input != "":
                instruction = instruction + '\n' + input
            # source = prompt.format_map({'instruction': instruction})
            source = instruction
            target = f"{tokenizer.bos_token}{output}{tokenizer.eos_token}"
            sources.append(source)
            targets.append(target)

        tokenized_sources = tokenizer(sources, return_attention_mask=False, add_special_tokens=False)
        tokenized_targets = tokenizer(targets, return_attention_mask=False, add_special_tokens=False)
        # print(tokenized_targets)

        all_input_ids = []
        all_labels = []
        for s, t in zip(tokenized_sources['input_ids'], tokenized_targets['input_ids']):
            # Append the ChatGLM [gMASK] token to the end of the source.
            s = s + [tokenizer.gmask_token_id]
            input_ids = torch.LongTensor(s + t)[:max_seq_length]
            # Mask the source positions so the loss is only computed on the response tokens.
            labels = torch.LongTensor([IGNORE_INDEX] * len(s) + t)[:max_seq_length]
            assert len(input_ids) == len(labels)
            all_input_ids.append(input_ids)
            all_labels.append(labels)

        results = {'input_ids': all_input_ids, 'labels': all_labels}
        return results
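
    # Each entry in data_path is expected to be a JSON / JSON-lines file whose records
    # carry "instruct", "query" and "answer" fields; illustrative record (the values
    # are made up, not taken from the repo's data):
    #   {"instruct": "Extract the named entities.", "query": "...", "answer": "..."}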
    logging.warning("building dataset...")
    all_datasets = []

    if not isinstance(data_path, (list, tuple)):
        data_path = [data_path]
    for file in data_path:
        if data_cache_dir is None:
            data_cache_dir = str(os.path.dirname(file))
        cache_path = os.path.join(data_cache_dir, os.path.basename(file).split('.')[0])
        os.makedirs(cache_path, exist_ok=True)
        try:
            # Reuse a previously tokenized dataset if one was cached on disk.
            processed_dataset = datasets.load_from_disk(cache_path)
            logger.info(f'training datasets-{file} has been loaded from disk')
        except Exception:
            logger.info(f'tokenizing {file}')
            raw_dataset = load_dataset("json", data_files=file, cache_dir=cache_path)
            logger.info(raw_dataset)
            tokenization_func = tokenization
            tokenized_dataset = raw_dataset.map(
                tokenization_func,
                batched=True,
                num_proc=preprocessing_num_workers,
                remove_columns=["instruct", "query", "answer"],
                keep_in_memory=False,
                desc="preprocessing on dataset",
            )
            processed_dataset = tokenized_dataset
            processed_dataset.save_to_disk(cache_path)
        processed_dataset.set_format('torch')
        all_datasets.append(processed_dataset['train'])

    all_datasets = concatenate_datasets(all_datasets)
    return all_datasets
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Gather the per-example tensors, then pad each field to the longest
        # sequence in the batch.
        input_ids = [instance["input_ids"] for instance in instances]
        labels = [instance["labels"] for instance in instances]
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
        )
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("model_hub/chatglm-6b", trust_remote_code=True)
    all_datasets = buid_instruction_dataset(["data/msra/train.txt"], tokenizer, max_seq_length=256)
    print(all_datasets[0])

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    # The collator expects a sequence of per-example dicts, e.g. individual dataset rows.
    data = data_collator([all_datasets[i] for i in range(2)])
    print(data)
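
    # A minimal sketch of driving the collator from a torch DataLoader; the batch size
    # and shuffle flag are arbitrary illustration values, not settings taken from this file.
    from torch.utils.data import DataLoader

    loader = DataLoader(all_datasets, batch_size=2, shuffle=True, collate_fn=data_collator)
    batch = next(iter(loader))
    print(batch["input_ids"].shape, batch["labels"].shape)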