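"""Build and inspect the github-code-java-libs dataset.

The --task flag selects one pipeline step: convert (raw JSON to an on-disk
dataset), process (tokenize), filter (drop over-length examples), split
(train/validation/test), check (inspect a dataset), or postprocess
(re-tokenize into another input/label format).
"""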
import argparse
import pickle
from collections import defaultdict

from tqdm import tqdm
from datasets import Dataset, DatasetDict, load_dataset

from project.dataset.process import DataProcessUtil
from project.util.path_util import PathUtil
from project.util.logs_util import LogsUtil

logger = LogsUtil.get_logs_util()

DATA_TYPE_TRAIN = "train"
DATA_TYPE_VALID = "validation"
DATA_TYPE_TEST = "test"

def convert_data(version: str):
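    """Step 1: load the raw JSON export and save it to disk as a datasets directory."""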
    dataset = load_dataset("json", data_files=PathUtil.datasets(f"github-code-java-libs-{version}.json"))
    dataset.save_to_disk(PathUtil.datasets(f"{version}/raw-github-code-java-libs"))


def process_data(version: str, input_r: str, label_r: str):
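    """Step 2: tokenize the raw dataset with the given input/label configuration."""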
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/raw-github-code-java-libs"))
    tokenized_dataset = dataset.map(
        lambda x: DataProcessUtil.preprocess_function_with_connect(x, input_r=input_r, label_r=label_r),
        batched=True,
        load_from_cache_file=False,
    )
    tokenized_dataset.save_to_disk(PathUtil.datasets(f"{version}/processed-github-code-java-libs"))


def filter_with_token_length(version: str, max_length: int = 384):
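    """Step 3: keep only examples whose input and label fit within max_length tokens,
    then dump library-frequency statistics for the filtered train split.
    """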
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/processed-github-code-java-libs"))
    filter_dataset = dataset.filter(
        lambda x: x["labels_token_length"] <= max_length and x["input_token_length"] <= max_length
    )
    filter_dataset.save_to_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))
    analyse_data(version, DATA_TYPE_TRAIN, filter_dataset, is_limit=True)


def analyse_data(version: str, typ: str, dataset, is_limit=False):
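    """Count library frequencies in one split and write them to a text file,
    third-party libraries first, JDK/SDK namespaces (java*, android*) last.

    With is_limit=True the counts are also pickled to count_lib.bin, which
    split_data later reads as per-library quotas.
    """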
    analysis = defaultdict(int)
    jdk_sdk_analysis = defaultdict(int)
    for item in tqdm(dataset[typ]):
        libs = item["libraries"]
        for lib in libs:
            if lib.startswith(("java", "javax", "android", "androidx")):
                jdk_sdk_analysis[lib] += 1
                continue
            analysis[lib] += 1
    analysis = sorted(analysis.items(), key=lambda x: x[1], reverse=True)
    jdk_sdk_analysis = sorted(jdk_sdk_analysis.items(), key=lambda x: x[1], reverse=True)
    analysis += jdk_sdk_analysis
    if is_limit:
        with open(PathUtil.datasets(f"{version}/count_lib.bin"), "wb") as file:
            pickle.dump({lib: count for lib, count in analysis}, file)
    with open(PathUtil.datasets(f"{version}/{typ}-github-code-java-libs.txt"), "w") as file:
        for lib, count in analysis:
            file.write(f"{lib}, {count}\n")


def check_data(args, console_only: bool = False, do_analyse: bool = False):
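    """Step 6: print the dataset layout, optionally (do_analyse) run per-split
    library statistics, and, unless console_only is set, log a sample of test
    comments, labels, and predictions.
    """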
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{args.version}/{args.filename}"))
    print(dataset)
    if do_analyse:
        # Dataset analysis: third-party library frequency statistics per split.
        analyse_data(args.version, DATA_TYPE_TRAIN, dataset)
        analyse_data(args.version, DATA_TYPE_VALID, dataset)
        analyse_data(args.version, DATA_TYPE_TEST, dataset)
    if console_only:
        return
    for i in range(args.check_size):
        logger.info("comment=" + dataset[DATA_TYPE_TEST][i]["comment"])
        logger.info("decoded_labels=" + dataset[DATA_TYPE_TEST][i]["decoded_labels"])
        logger.info("decoded_preds=" + dataset[DATA_TYPE_TEST][i]["decoded_preds"])


def split_data(version: str, ratio: float = 0.02, test_size: int = None):
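    """Step 5: carve validation and test splits out of the filtered train data.

    Each library contributes roughly `ratio` of its occurrences to validation
    and another `ratio` to test (using the counts pickled by analyse_data);
    remaining rows go to train unless their comment (NL) already appears in
    validation or test.
    """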
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
    # dataset = dataset.train_test_split(test_size=100)["test"]
    if test_size is None:
        test_size = int(len(dataset) * ratio)  # note: currently informational only; the split below is quota-driven
    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
        count_lib = pickle.load(f)
    validation_dataset, train_dataset, test_dataset = [], [], []
    lib_count_4_validation = {lib: 0 for lib in count_lib}
    lib_count_4_test = lib_count_4_validation.copy()
    nl_set_4_validation, nl_set_4_test = set(), set()
    for row in tqdm(dataset):
        lib_size = len(row["libraries"])
        # Route each row to validation while its libraries are under quota.
        is_append_validation = False
        for lib in row["libraries"]:
            if lib_count_4_validation[lib] >= count_lib[lib] * ratio:
                break
            # Deprioritize the generic jdk/sdk labels when more specific libraries are present.
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            lib_count_4_validation[lib] += 1
            is_append_validation = True
        if is_append_validation:
            validation_dataset.append(row)
            nl_set_4_validation.add(row["comment"])  # record the NL so duplicates stay out of train
            continue
        is_append_test = False
        for lib in row["libraries"]:
            if lib_count_4_test[lib] >= count_lib[lib] * ratio:
                break
            # Deprioritize the generic jdk/sdk labels when more specific libraries are present.
            if (lib == "jdk" and lib_size > 1) or (lib == "sdk" and lib_size > 2):
                continue
            lib_count_4_test[lib] += 1
            is_append_test = True
        if is_append_test:
            test_dataset.append(row)
            nl_set_4_test.add(row["comment"])  # record the NL so duplicates stay out of train
            continue
        # Same-NL dedup: drop train rows whose comment already appears in validation or test.
        if row["comment"] in nl_set_4_validation or row["comment"] in nl_set_4_test:
            logger.info(row["comment"])
            continue
        train_dataset.append(row)
    dataset = DatasetDict()
    dataset["train"] = Dataset.from_list(train_dataset)
    dataset["validation"] = Dataset.from_list(validation_dataset)
    dataset["test"] = Dataset.from_list(test_dataset)
    dataset.save_to_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))


def postprocess_data(args, saved_version: str, input_r: str, label_r: str):
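    """Re-tokenize an already split dataset with a new input/label configuration."""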
    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{args.version}/{args.filename}"))
    tokenized_dataset = dataset.map(
        lambda x: DataProcessUtil.preprocess_function_with_connect(x, input_r=input_r, label_r=label_r), batched=True
    )
    tokenized_dataset.save_to_disk(PathUtil.datasets(f"{saved_version}/github-code-java-libs"))


def add_args(parser):
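    """Register the CLI arguments on the given parser and return the parsed namespace."""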
    parser.add_argument(
        "--task", type=str, required=True, choices=["convert", "process", "filter", "split", "check", "postprocess"]
    )
    parser.add_argument("--version", type=str, help="The version of the datasets.")
    parser.add_argument("--filename", type=str, help="The filename of the dataset.")
    parser.add_argument("--check_size", type=int, help="Number of test examples to log when checking.")
    parser.add_argument("--input", default=None, type=str, help="Type of input.")
    parser.add_argument("--label", default=None, type=str, help="Type of label.")

    parser.add_argument("--upper", default=None, type=int, help="Maximum per-library count in the test split.")
    parser.add_argument("--test_size", default=4000, type=int, help="Size of the test split.")
    parser.add_argument("--saved_version", default=None, type=str, help="The version of the post-processed datasets.")
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    versions = {
        "l0": ("nl", "code"),
        "l1": ("nl+libs", "code"),
        "l2": ("nl+libs+codeRet", "code"),
        "l3": ("nl+libs+importsGen", "code"),
        "l4": ("nl+libs+importsGen+codeRet", "code"),
        "l5": ("nl+libs", "imports"),
        "l6": ("nl+libs+importsRet", "imports"),
    }

    parser = argparse.ArgumentParser()
    args = add_args(parser)
    logger.info(args)

    if args.task == "convert":
        # step 1
        convert_data(version=args.version)
    elif args.task == "process":
        # step 2
        process_data(version=args.version, input_r=args.input, label_r=args.label)
    elif args.task == "filter" and args.upper is not None:
        # step 4 (checked before the plain filter branch so --upper is reachable)
        filter_with_upper(version=args.version, upper=args.upper)
    elif args.task == "filter":
        # step 3
        filter_with_token_length(version=args.version)
    elif args.task == "split":
        # step 5
        split_data(version=args.version, test_size=args.test_size)
    elif args.task == "check":
        # step 6
        check_data(args=args)
    elif args.task == "postprocess":
        # other step
        input_label = versions[args.saved_version[:2]]
        postprocess_data(args, saved_version=args.saved_version, input_r=input_label[0], label_r=input_label[1])