
Commit 314beff

init

1 parent 22b1a9d commit 314beff

7 files changed: +513 -5 lines changed

README.md

Lines changed: 36 additions & 1 deletion
@@ -1,2 +1,37 @@
 # CodeGen4Libs
-We will release the complete project soon.
+
+### Benchmark Format
+The benchmark is structured and saved in the Hugging Face DatasetDict format, available at [Dataset and Models of CodeGen4Libs](https://zenodo.org/record/7920906#.ZFyPm-xByDV). The data fields for each task are listed below (a minimal loading sketch follows the list):
+
+- id
+- method
+- clean_method
+- doc
+- comment
+- method_name
+- extra
+- license
+- path
+- repo_name
+- size
+- imports_info
+- libraries_info
+
+- input_str
+- attention_mask
+- input_ids
+- tokenized_input_str
+- input_token_length
+- labels
+- tokenized_labels_str
+- labels_token_length
+
+- retrieved_imports_info
+- generated_imports_info
+- union_gen_ret_imports_info
+- intersection_gen_ret_imports_info
+- similar_code
+
+- decoded_labels
+- predictions
+- decoded_preds
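
A minimal loading sketch (assuming the Zenodo archive has been downloaded and extracted locally; the `data/github-code-java-libs` path is illustrative, not part of the commit):

```python
from datasets import DatasetDict

# Hypothetical local path: point this at the extracted Zenodo archive.
dataset = DatasetDict.load_from_disk("data/github-code-java-libs")

print(dataset)                    # train / validation / test splits with the fields above
sample = dataset["test"][0]
print(sample["comment"])          # the NL description of the target method
print(sample["libraries_info"])   # the libraries the method should use
print(sample["clean_method"])     # the cleaned reference Java method
```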

data/README.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+https://zenodo.org/record/7920906#.ZFyPm-xByDV

dataset/README.md

Lines changed: 0 additions & 1 deletion
This file was deleted.

build_corpus.py renamed to generation/corpus_builder.py

Lines changed: 1 addition & 3 deletions
@@ -2,14 +2,12 @@
 import jsonlines
 import pickle
 import string
-from collections import defaultdict
 import random
-
 import re
+from collections import defaultdict
 from tqdm import tqdm

 from script.extract_method import ProjectMethodExtractor
-
 from sckg.util.path_util import PathUtil
 from sckg.util.log_util import LogUtil

generation/dataset_filter.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+import pickle
+from tqdm import tqdm
+from datasets import Dataset, DatasetDict
+from project.util.path_util import PathUtil
+from project.util.logs_util import LogsUtil
+
+logger = LogsUtil.get_logs_util()
+
+
+def filter_with_upper(version: str, upper: int):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
+    data = []
+    other_data = []
+    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
+        count_lib = pickle.load(f)
+    chose_libs = count_lib.keys()
+    count_lib = {_: 0 for _ in count_lib.copy().keys()}
+    for row in tqdm(dataset):
+        lib_size = len(row["libraries"])
+        is_append = False
+        if all(
+            any(lib.startswith(_) for _ in ("java.", "javax.", "android.", "androidx.")) for lib in row["libraries"]
+        ):
+            other_data.append(row)
+            continue
+        if any(lib not in chose_libs for lib in row["libraries"]):
+            continue
+        for lib in row["libraries"]:
+            if count_lib[lib] >= upper:
+                continue
+            # Filter JDK & SDK libraries by priority
+            if lib == "jdk" and lib_size > 1 or lib == "sdk" and lib_size > 2:
+                continue
+            if not any(lib.startswith(_) for _ in ("java.", "javax.", "android.", "androidx.")):
+                count_lib[lib] += 1
+                is_append = True
+        if is_append:
+            data.append(row)
+        else:
+            other_data.append(row)
+    dataset = DatasetDict()
+    dataset["train"] = Dataset.from_list(data)
+    dataset.save_to_disk(PathUtil.datasets(f"{version}_{upper}/filter-github-code-java-libs"))
+
+    other_dataset = DatasetDict()
+    other_dataset["train"] = Dataset.from_list(other_data)
+    other_dataset.save_to_disk(PathUtil.datasets(f"{version}_{upper}/other-github-code-java-libs"))
+    with open(PathUtil.datasets(f"{version}_{upper}/train-github-code-java-libs.txt"), "w") as file:
+        for lib, count in count_lib.items():
+            file.write(lib + ", " + str(count) + "\n")
+    with open(PathUtil.datasets(f"{version}_{upper}/count_lib.bin"), "wb") as file:
+        pickle.dump({lib: count for lib, count in count_lib.items()}, file)
+
+
+def split_data(version: str, ration: float = 0.02, test_size: int = None):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
+    if test_size is None:
+        test_size = len(dataset) * ration
+    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
+        count_lib = pickle.load(f)
+    validation_dataset, train_dataset, test_dataset = [], [], []
+    lib_count_4_validation = {_: 0 for _ in count_lib.copy().keys()}
+    lib_count_4_test = lib_count_4_validation.copy()
+    nl_set_4_validation, nl_set_4_test = set(), set()
+    for row in tqdm(dataset):
+        lib_size = len(row["libraries"])
+        # Split the dataset by library
+        is_append_validation = False
+        for lib in row["libraries"]:
+            if lib not in count_lib:
+                continue
+            if lib_count_4_validation[lib] >= count_lib[lib] * ration:
+                break
+            # Filter JDK & SDK libraries by priority
+            if lib == "jdk" and lib_size > 1 or lib == "sdk" and lib_size > 2:
+                continue
+            lib_count_4_validation[lib] += 1
+            is_append_validation = True
+        if is_append_validation:
+            validation_dataset.append(row)
+            nl_set_4_validation.add(row["comment"] + row["libraries_info"])
+            continue
+        is_append_test = False
+        for lib in row["libraries"]:
+            if lib not in count_lib:
+                continue
+            if lib_count_4_test[lib] >= count_lib[lib] * ration:
+                break
+            # Filter JDK & SDK libraries by priority
+            if lib == "jdk" and lib_size > 1 or lib == "sdk" and lib_size > 2:
+                continue
+            lib_count_4_test[lib] += 1
+            is_append_test = True
+        if is_append_test:
+            test_dataset.append(row)
+            nl_set_4_test.add(row["comment"] + row["libraries_info"])
+            continue
+        # Drop rows whose NL (comment + libraries_info) already appears in validation/test
+        if (
+            row["comment"] + row["libraries_info"] in nl_set_4_validation
+            or row["comment"] + row["libraries_info"] in nl_set_4_test
+        ):
+            logger.info(row["comment"] + row["libraries_info"])
+            continue
+        train_dataset.append(row)
+    dataset = DatasetDict()
+    dataset["train"] = Dataset.from_list(train_dataset)
+    dataset["validation"] = Dataset.from_list(validation_dataset)
+    dataset["test"] = Dataset.from_list(test_dataset)
+    dataset.save_to_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))
+
+
+def slim_data(version: str):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))
+    def chunk_examples(examples):
+        return {
+            "input_ids": examples["input_ids"],
+            "attention_mask": examples["attention_mask"],
+            "labels": examples["labels"],
+        }
+
+    dataset = dataset.map(chunk_examples, batched=True)
+    dataset = dataset.map(chunk_examples, batched=True, remove_columns=dataset["train"].column_names)
+    dataset.save_to_disk(PathUtil.datasets(f"{version}/slim-github-code-java-libs"))
+
+
+if __name__ == "__main__":
+    # with open(PathUtil.datasets(f"latest_0,800000_5000/count_lib.bin"), "rb") as file:
+    #     data = pickle.load(file)
+    # print(data)
+
+    # filter_with_upper("latest_400000,600000", 5000)
+
+    split_data("latest_0,400000_5000", ration=0.02)

generation/dataset_generator.py

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@
+import argparse
+import pickle
+from tqdm import tqdm
+from collections import defaultdict
+from datasets import Dataset, DatasetDict, load_dataset
+from project.dataset.process import DataProcessUtil
+from project.util.path_util import PathUtil
+from project.util.logs_util import LogsUtil
+# filter_with_upper lives in generation/dataset_filter.py (import assumes running from the repo root)
+from generation.dataset_filter import filter_with_upper
+
+logger = LogsUtil.get_logs_util()
+
+DATA_TYPE_TRAIN = "train"
+DATA_TYPE_VALID = "validation"
+DATA_TYPE_TEST = "test"
+
+
+def convert_data(version: str):
+    dataset = load_dataset("json", data_files=PathUtil.datasets(f"github-code-java-libs-{version}.json"))
+    dataset.save_to_disk(PathUtil.datasets(f"{version}/raw-github-code-java-libs"))
+
+
+def process_data(version: str, input_r: str, label_r: str):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/raw-github-code-java-libs"))
+    tokenized_dataset = dataset.map(
+        lambda x: DataProcessUtil.preprocess_function_with_connect(x, input_r=input_r, label_r=label_r),
+        batched=True,
+        load_from_cache_file=False,
+    )
+    tokenized_dataset.save_to_disk(PathUtil.datasets(f"{version}/processed-github-code-java-libs"))
+
+
+def filter_with_token_length(version: str, max_length: int = 384):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/processed-github-code-java-libs"))
+    filter_dataset = dataset.filter(
+        lambda x: x["labels_token_length"] <= max_length and x["input_token_length"] <= max_length
+    )
+    filter_dataset.save_to_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))
+    analyse_data(version, "train", filter_dataset, is_limit=True)
+
+
+def analyse_data(version: str, typ: str, dataset, is_limit=False):
+    analysis = defaultdict(int)
+    jdk_sdk_analysis = defaultdict(int)
+    for item in tqdm(dataset[typ]):
+        libs = item["libraries"]
+        for lib in libs:
+            if any(lib.startswith(_) for _ in ("java", "javax", "android", "androidx")):
+                jdk_sdk_analysis[lib] += 1
+                continue
+            analysis[lib] += 1
+    analysis = sorted(analysis.items(), key=lambda x: x[1], reverse=True)
+    jdk_sdk_analysis = sorted(jdk_sdk_analysis.items(), key=lambda x: x[1], reverse=True)
+    analysis += jdk_sdk_analysis
+    if is_limit:
+        with open(PathUtil.datasets(f"{version}/count_lib.bin"), "wb") as file:
+            pickle.dump({lib: count for lib, count in analysis}, file)
+    with open(PathUtil.datasets(f"{version}/{typ}-github-code-java-libs.txt"), "w") as file:
+        for lib, count in analysis:
+            file.write(lib + ", " + str(count) + "\n")
+
+
+def check_data(args, console_only: bool = False, do_analyse: bool = False):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{args.version}/{args.filename}"))
+    print(dataset)
+    if do_analyse:
+        # Dataset analysis: third-party library frequency statistics
+        analyse_data(args.version, DATA_TYPE_TRAIN, dataset)
+        analyse_data(args.version, DATA_TYPE_VALID, dataset)
+        analyse_data(args.version, DATA_TYPE_TEST, dataset)
+    for i in range(args.check_size):
+        if console_only:
+            continue
+        logger.info("comment=" + dataset[DATA_TYPE_TEST][i]["comment"])
+        logger.info("decoded_labels=" + dataset[DATA_TYPE_TEST][i]["decoded_labels"])
+        logger.info("decoded_preds=" + dataset[DATA_TYPE_TEST][i]["decoded_preds"])
+
+
+def split_data(version: str, ration: float = 0.02, test_size: int = None):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{version}/filter-github-code-java-libs"))["train"]
+    # dataset = dataset.train_test_split(test_size=100)["test"]
+    if test_size is None:
+        test_size = len(dataset) * ration
+    with open(PathUtil.datasets(f"{version}/count_lib.bin"), "rb") as f:
+        count_lib = pickle.load(f)
+    validation_dataset, train_dataset, test_dataset = [], [], []
+    lib_count_4_validation = {_: 0 for _ in count_lib.copy().keys()}
+    lib_count_4_test = lib_count_4_validation.copy()
+    nl_set_4_validation, nl_set_4_test = set(), set()
+    for row in tqdm(dataset):
+        lib_size = len(row["libraries"])
+        # Split the dataset by library
+        is_append_validation = False
+        for lib in row["libraries"]:
+            if lib_count_4_validation[lib] >= count_lib[lib] * ration:
+                break
+            # Filter JDK & SDK libraries by priority
+            if lib == "jdk" and lib_size > 1 or lib == "sdk" and lib_size > 2:
+                continue
+            lib_count_4_validation[lib] += 1
+            is_append_validation = True
+        if is_append_validation:
+            validation_dataset.append(row)
+            nl_set_4_validation.add(row["comment"])  # record the NL so duplicates stay out of train
+            continue
+        is_append_test = False
+        for lib in row["libraries"]:
+            if lib_count_4_test[lib] >= count_lib[lib] * ration:
+                break
+            # Filter JDK & SDK libraries by priority
+            if lib == "jdk" and lib_size > 1 or lib == "sdk" and lib_size > 2:
+                continue
+            lib_count_4_test[lib] += 1
+            is_append_test = True
+        if is_append_test:
+            test_dataset.append(row)
+            nl_set_4_test.add(row["comment"])  # record the NL so duplicates stay out of train
+            continue
+        # Drop rows whose NL comment already appears in validation/test
+        if row["comment"] in nl_set_4_validation or row["comment"] in nl_set_4_test:
+            logger.info(row["comment"])
+            continue
+        train_dataset.append(row)
+    dataset = DatasetDict()
+    dataset["train"] = Dataset.from_list(train_dataset)
+    dataset["validation"] = Dataset.from_list(validation_dataset)
+    dataset["test"] = Dataset.from_list(test_dataset)
+    dataset.save_to_disk(PathUtil.datasets(f"{version}/github-code-java-libs"))
+
+
+def postprocess_data(args, saved_version: str, input_r: str, label_r: str):
+    dataset = DatasetDict.load_from_disk(PathUtil.datasets(f"{args.version}/{args.filename}"))
+    tokenized_dataset = dataset.map(
+        lambda x: DataProcessUtil.preprocess_function_with_connect(x, input_r=input_r, label_r=label_r), batched=True
+    )
+    tokenized_dataset.save_to_disk(PathUtil.datasets(f"{saved_version}/github-code-java-libs"))
+
+
+def add_args(parser):
+    parser.add_argument(
+        "--task", type=str, required=True, choices=["convert", "process", "filter", "split", "check", "postprocess"]
+    )
+    parser.add_argument('--version', type=str, help="The version of the datasets.")
+    parser.add_argument('--filename', type=str, help="The filename of the dataset.")
+    parser.add_argument('--check_size', type=int, help="Number of samples to check.")
+    parser.add_argument('--input', default=None, type=str, help="Type of input.")
+    parser.add_argument('--label', default=None, type=str, help="Type of label.")
+
+    parser.add_argument('--upper', default=None, type=int, help="Upper bound on the per-library sample count.")
+    parser.add_argument('--test_size', default=4000, type=int, help="Size of the test split.")
+    parser.add_argument('--saved_version', default=None, type=str, help="The version of the postprocessed datasets.")
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    versions = {
+        "l0": ("nl", "code"),
+        "l1": ("nl+libs", "code"),
+        "l2": ("nl+libs+codeRet", "code"),
+        "l3": ("nl+libs+importsGen", "code"),
+        "l4": ("nl+libs+importsGen+codeRet", "code"),
+        "l5": ("nl+libs", "imports"),
+        "l6": ("nl+libs+importsRet", "imports"),
+    }
+
+    parser = argparse.ArgumentParser()
+    args = add_args(parser)
+    logger.info(args)
+
+    if args.task == "convert":
+        # step 1
+        convert_data(version=args.version)
+    elif args.task == "process":
+        # step 2
+        process_data(version=args.version, input_r=args.input, label_r=args.label)
+    elif args.task == "filter" and args.upper is not None:
+        # step 4 (checked before step 3 so this branch is reachable)
+        filter_with_upper(version=args.version, upper=args.upper)
+    elif args.task == "filter":
+        # step 3
+        filter_with_token_length(version=args.version)
+    elif args.task == "split":
+        # step 5
+        split_data(version=args.version, test_size=args.test_size)
+    elif args.task == "check":
+        # step 6
+        check_data(args=args)
+    elif args.task == "postprocess":
+        # other step
+        input_label = versions[args.saved_version[:2]]
+        postprocess_data(args, saved_version=args.saved_version, input_r=input_label[0], label_r=input_label[1])
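
For reference, the step numbers in `__main__` imply the following end-to-end order; a sketch that drives the functions directly rather than through the CLI (the `"l1"` input/label pairing `("nl+libs", "code")` is one entry from the `versions` table above, and the version tag is hypothetical):

```python
from generation.dataset_generator import (
    convert_data,
    process_data,
    filter_with_token_length,
    split_data,
)

VERSION = "latest"  # hypothetical version tag

convert_data(version=VERSION)                                     # step 1: JSON -> raw DatasetDict
process_data(version=VERSION, input_r="nl+libs", label_r="code")  # step 2: tokenize (the "l1" config)
filter_with_token_length(version=VERSION, max_length=384)         # step 3: drop over-length samples
split_data(version=VERSION, ration=0.02)                          # step 5: train/validation/test split
```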
