Commit 479dead: Add eval method
1 parent be244d0

14 files changed: +387, -98 lines

README.md

+14
@@ -85,6 +85,8 @@ Code-Prediction-Transformer is CC-BY-NC 4.0 (Attr Non-Commercial Inter.) (e.g.,
 ## Vanilla
 
+### Preprocessing
+
 1. generate_new_trees (nodes only have type/value): `python generate_new_trees.py -i PY150 -o NEW_TREES.json`
 2. generate_data (splitting and preorder traversal): `python models/trav_trans/generate_data.py -a NEW_TREES.json -o DPS.TXT`
 3. generate_vocab (generate vocab files): `python generate_vocab.py -i NEW_TREES.json -o VOCAB.pkl -t ast`
@@ -93,6 +95,18 @@ Code-Prediction-Transformer is CC-BY-NC 4.0 (Attr Non-Commercial Inter.) (e.g.,
 5. Use torch.utils.data.DataLoader to pull batches from the Dataset, using its collate function: `dataloader = torch.utils.data.DataLoader(dataset, batch_size=X, collate_fn=lambda b: dataset.collate(b, setup.vocab.pad_idx))`
 6. Iterate through the batches and feed them to the model (see the sketch below)
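A minimal sketch of steps 5-6 above, mirroring how `evaluate.py` (added in this commit) wires up `Setup`, `Dataset.collate`, and the batch keys; the file paths and the loop body are illustrative, not code from the repo:

```python
import torch
from models.trav_trans import dataset

# Paths are hypothetical, following the output/ naming used by preprocess.py;
# evaluate.py passes mode="test", the default mode is assumed to load training data.
setup = dataset.Setup("output", "output/train_dps.txt", "output/train_ids.txt")

dataloader = torch.utils.data.DataLoader(
    setup.dataset,
    batch_size=4,
    collate_fn=lambda b: dataset.Dataset.collate(b, setup.vocab.pad_idx),
)

for batch in dataloader:
    x = batch["input_seq"]   # batch keys as read in evaluate.py
    y = batch["target_seq"]
    # ... feed x and y to the model, e.g. loss = model(x, y, ..., return_loss=True)
```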

+### Evaluation
+
+1. Iterate through the test dataset
+2. For each batch, get the leaf_ids from "ids.txt"/"leaf_ids", which lists every type node that has a value leaf node
+3. Make a model prediction at id-1 to predict the type, then at id to predict the value
+4. Check for "special" nodes, e.g. type "attr" belongs to the special type "attribute access" rather than plain leaf-node prediction
+   - Attribute access: `attr`
+   - Numeric constant: an arithmetic expression (`expr`) that is a numeric constant, `Const`
+   - Name (variable, module): `NameLoad`/`NameStore`
+   - Function parameter name: `NameParam`
+5. Calculate the MRR for all predictions, overall and broken down by the four special types
 
 ## HuggingFace
 
 1. generate_new_trees (nodes only have type/value)
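For step 5, a minimal sketch of the reciprocal-rank bookkeeping (hypothetical helpers, not code from this commit): each prediction contributes 1/rank of the true token within the top-k list, or 0 if it is absent, and the MRR is the mean of those values.

```python
def reciprocal_rank(true_idx, top_k_indices):
    # 1-based rank of true_idx within the top-k predictions, as 1/rank
    for rank, idx in enumerate(top_k_indices, start=1):
        if idx == true_idx:
            return 1.0 / rank
    return 0.0  # true token not in the top k

def mrr(reciprocal_ranks):
    # mean reciprocal rank over a list of per-prediction values
    return sum(reciprocal_ranks) / len(reciprocal_ranks) if reciprocal_ranks else 0.0
```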

__pycache__/model.cpython-38.pyc

0 Bytes
Binary file not shown.

evaluate.py

+77
@@ -0,0 +1,77 @@

import argparse

import torch
from tqdm import tqdm

import model
from models.trav_trans import dataset


def generate_test(model, context, device, depth=2, top_k=10):
    # Feed the context through the model and return the top-k
    # (values, indices) of the next-token distribution.
    model.eval()
    with torch.no_grad():
        context = torch.tensor(context).to(device)
        output = model(context, None)[-1]
        top_k_values, top_k_indices = torch.topk(output, top_k)
    return top_k_values, top_k_indices


def main():
    parser = argparse.ArgumentParser(description="Evaluate GPT2 Model")
    parser.add_argument("--model", help="Specify the model file")
    parser.add_argument("--dps", help="Specify the data file (dps) on which the model should be tested")
    parser.add_argument("--ids", help="Specify the data file (ids) on which the model should be tested")
    parser.add_argument("--vocab", help="Specify the vocab file")
    parser.add_argument("--batch_size", default=1, type=int, help="Specify the batch size")

    args = parser.parse_args()

    setup = dataset.Setup("output", args.dps, args.ids, mode="test")

    m = model.from_file(args.model, setup.vocab)

    dataloader = torch.utils.data.DataLoader(
        setup.dataset,
        batch_size=args.batch_size,
        collate_fn=lambda b: dataset.Dataset.collate(b, setup.vocab.pad_idx),
    )

    eval(m, dataloader)


def eval(model, dataloader):
    print("Evaluating {} batches".format(len(dataloader)))
    # Reciprocal ranks, overall and per special type; only the overall
    # bucket is filled so far, the per-type buckets are still unused.
    reciprocal_rank = {
        "all_leaf_tokens": [],
        "attribute_access": [],
        "numeric_constant": [],
        "variable_name": [],
        "function_parameter_name": [],
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()
    for i, batch in tqdm(enumerate(dataloader)):
        if i % 100 == 0:
            print("Batch {}".format(i))
        x = batch["input_seq"][0]
        y = batch["target_seq"][0]
        ids = batch["ids"]["leaf_ids"]

        for id in ids:
            if id > 0:
                y_type = x[id].item()
                y_value = y[id].item()

                with torch.no_grad():
                    # The context up to id-1 predicts the type token at
                    # position id; the context up to id predicts the value.
                    y_type_pred = generate_test(model, x[:id].tolist(), device)
                    y_value_pred = generate_test(model, x[: id + 1].tolist(), device)

                # Reciprocal rank: 1 / (1-based position of the true token
                # in the top-k predictions), 0 if it is not in the top k.
                type_rank = 0
                value_rank = 0
                if y_type in y_type_pred[1]:
                    type_rank = 1 / ((y_type_pred[1] == y_type).nonzero(as_tuple=True)[0].item() + 1)
                if y_value in y_value_pred[1]:
                    value_rank = 1 / ((y_value_pred[1] == y_value).nonzero(as_tuple=True)[0].item() + 1)
                reciprocal_rank["all_leaf_tokens"].append((type_rank + value_rank) / 2)


if __name__ == "__main__":
    main()
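A plausible invocation of the script (file names hypothetical, following the output/ naming used by the preprocessing scripts; model-8.pt appears in the original code):

```
python evaluate.py --model output/model-8.pt --dps output/test_dps.txt --ids output/test_ids.txt --batch_size 1
```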

generate.py

-13
This file was deleted.

generate_graph.py

+2-2
@@ -10,9 +10,9 @@ def addChildren(i, data, graph):
         if "value" in data[c] and "type" in data[c]:
             graph.add_node(pydot.Node(c, label=data[c]["type"] + "\n{}".format(data[c]["value"])))
         elif "value" in data[c] and not "type" in data[c]:
-            graph.add_node(pydot.Node(c, label=data[c]["value"]))
+            graph.add_node(pydot.Node(c, label="{}\n".format(c) + data[c]["value"]))
         elif "value" not in data[c] and "type" in data[c]:
-            graph.add_node(pydot.Node(c, label=data[c]["type"]))
+            graph.add_node(pydot.Node(c, label="{}\n".format(c) + data[c]["type"]))
         graph.add_edge(pydot.Edge(i, c, color="blue"))
         addChildren(c, data, graph)
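With this change, value-only and type-only nodes are labeled with their node index in addition to their text, so identical labels no longer collapse visually. A small sketch of how `addChildren` might be driven (the three-node AST is made up, and it assumes `addChildren` walks `data[i]["children"]`; `pydot.Dot`/`write_png` are the standard pydot calls):

```python
import pydot
from generate_graph import addChildren  # assumed importable

# Hypothetical py150-style AST: a list of node dicts with type/value/children.
data = [
    {"type": "Module", "children": [1]},
    {"type": "NameStore", "children": [2]},
    {"value": "x", "children": []},
]

graph = pydot.Dot(graph_type="digraph")
graph.add_node(pydot.Node(0, label=data[0]["type"]))
addChildren(0, data, graph)
graph.write_png("ast.png")
```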

generate_new_trees.py

+11
@@ -47,6 +47,17 @@ def convert(ast):
     assert len(children) == len(set(children))
     return new_dp
 
+def external(file_path, suffix):
+    outfile = "output/{}_new_trees.json".format(suffix)
+    if os.path.exists(outfile):
+        os.remove(outfile)
+    logging.info("Loading asts from: {}".format(file_path))
+    with open(file_path, "r") as f, open(outfile, "w") as fout:
+        for line in file_tqdm(f):
+            dp = json.loads(line.strip())
+            print(json.dumps(convert(dp)), file=fout)
+    logging.info("Wrote dps to: {}".format(outfile))
+
 
 def main():
     parser = argparse.ArgumentParser(description="Generate datapoints from AST")

generate_vocab.py

+30
@@ -29,6 +29,36 @@ def get_value(line, input_type):
     elif input_type == "source_code":
         return line[0]
 
+def external(file_path, n_vocab):
+    outfile = "output/vocab.pkl"
+    logging.info("Reading from: {}".format(file_path))
+    vocab = Counter()
+    with open(file_path, "r") as f:
+        for line in file_tqdm(f):
+            vocab.update(get_value(json.loads(line.strip()), "ast"))
+    vocab_to_keep = [i[0] for i in vocab.most_common(n_vocab)]
+    top_total = sum(i[1] for i in vocab.most_common(n_vocab))
+    total = sum(vocab.values())
+
+    logging.info("Total # of vocab: {}".format(len(vocab)))
+    logging.info(
+        "Using {} top vocab covers: {:.2f}% of the entire dataset".format(
+            n_vocab, 100 * top_total / total
+        )
+    )
+    logging.info("Top 10 most common vocab:")
+    for v, i in vocab.most_common(10):
+        print(v, i)
+
+    # add unk and pad tokens
+    vocab_to_keep.append(UNK)
+    vocab_to_keep.append(PAD)
+    logging.info("Added {} and {}".format(UNK, PAD))
+
+    # dump vocab to file
+    with open(outfile, "wb") as fout:
+        pickle.dump(vocab_to_keep, fout)
+    logging.info("Wrote {} vocab to: {}".format(len(vocab_to_keep), outfile))
 
 
 def main():
     parser = argparse.ArgumentParser(description="Create vocab for py150 dataset")

generator.py

-30
This file was deleted.
Binary file not shown.

models/trav_trans/generate_ast_ids.py

+26
@@ -70,6 +70,32 @@ def get_type_ids(ast):
     return ids
 
 
+def external(file_path, suffix, n_ctx, id_type="leaf"):
+    outfile = "output/{}_ids.txt".format(suffix)
+
+    if os.path.exists(outfile):
+        os.remove(outfile)
+    logging.info("Type of id to get: {}".format(id_type))
+
+    logging.info("Loading dps from: {}".format(file_path))
+    with open(file_path, "r") as f, open(outfile, "w") as fout:
+        for line in file_tqdm(f):
+            dp = json.loads(line.strip())
+            asts = separate_dps(dp, n_ctx)
+            for ast, _ in asts:
+                ids = {}
+                if len(ast) > 1:
+                    if id_type in {"leaf", "all"}:
+                        ids.update(get_leaf_ids(ast))
+                    if id_type in {"value", "all"}:
+                        ids.update(get_value_ids(ast))
+                    if id_type in {"type", "all"}:
+                        ids.update(get_type_ids(ast))
+
+                json.dump(ids, fp=fout)
+                fout.write("\n")
+    logging.info("Wrote to: {}".format(outfile))
+
 def main():
     parser = argparse.ArgumentParser(
         description="Generate ids (leaf, values, types) from AST"
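Each line of the resulting `{suffix}_ids.txt` holds one JSON object per (sliced) AST. Given that `evaluate.py` reads `batch["ids"]["leaf_ids"]`, `get_leaf_ids` presumably returns a dict shaped like `{"leaf_ids": [...]}`, listing the positions of type nodes that carry a value leaf; that is an inference from the calling code, which this diff does not show.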

models/trav_trans/generate_data.py

+19
@@ -15,6 +15,25 @@
 
 logging.basicConfig(level=logging.INFO)
 
+def external(file_path, suffix, context_size):
+    outfile = "output/{}_dps.txt".format(suffix)
+    if os.path.exists(outfile):
+        os.remove(outfile)
+    logging.info("Context size: {}".format(context_size))
+
+    num_dps = 0
+    logging.info("Loading asts from: {}".format(file_path))
+    with open(file_path, "r") as f, open(outfile, "w") as fout:
+        for line in file_tqdm(f):
+            dp = json.loads(line.strip())
+            asts = separate_dps(dp, context_size)
+            for ast, extended in asts:
+                if len(ast) > 1:
+                    json.dump([get_dfs(ast), extended], fp=fout)
+                    fout.write("\n")
+                    num_dps += 1
+
+    logging.info("Wrote {} datapoints to {}".format(num_dps, outfile))
 
 
 def main():
     parser = argparse.ArgumentParser(description="Generate datapoints from AST")

notebook.ipynb

+176-53
Large diffs are not rendered by default.

preprocess.py

+29
@@ -0,0 +1,29 @@

import os
import argparse

import generate_new_trees
import generate_vocab
from models.trav_trans import generate_data, generate_ast_ids


def main():
    parser = argparse.ArgumentParser(description="Preprocess py150 train and eval files")
    parser.add_argument("--file_path", help="Specify py150 file path")
    parser.add_argument("--suffix", help="Specify suffix to distinguish between train/val/test files")
    parser.add_argument("--context_size", default=1000, type=int, help="Specify context size for slicing larger ASTs")
    parser.add_argument("--generate_vocab", action="store_true", help="Specify whether or not to generate a vocab file")
    parser.add_argument("--n_vocab", default=100000, type=int, help="Specify the vocab size")

    args = parser.parse_args()

    # Generate new trees
    generate_new_trees.external(args.file_path, args.suffix)
    # Generate DPS
    generate_data.external("output/{}_new_trees.json".format(args.suffix), args.suffix, args.context_size)
    # Generate vocab
    if args.generate_vocab:
        generate_vocab.external("output/{}_new_trees.json".format(args.suffix), args.n_vocab)
    # Generate AST ids
    generate_ast_ids.external("output/{}_new_trees.json".format(args.suffix), args.suffix, args.context_size)


if __name__ == "__main__":
    main()
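Run once per split; a plausible invocation for the py150 training file (the input path is hypothetical, and `--generate_vocab` is a boolean flag):

```
python preprocess.py --file_path py150/python100k_train.json --suffix train --generate_vocab --n_vocab 100000
```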

trainer.py

+3
@@ -48,11 +48,14 @@ def train(self):
             loss = self.model(x, y, ext, return_loss = True)
             loss.backward()
             if batch_counter % 8 == 0:
+                # Gradients accumulated over the past 8 iterations; step and reset
                 self.optimizer.step()
                 self.optimizer.zero_grad()
                 self.model.zero_grad()
+            # Every 100 batches, record the loss
             if batch_counter % 100 == 0:
                 losses.append([epoch, i, loss.item()])
+            # Every 1,000 batches, print current metrics and save the losses file
             if batch_counter % 1000 == 0:
                 print("Epoch {}, It. {}/{}, Loss {}".format(epoch, i, self.dataset.__len__() / self.batch_size, loss))
                 with open(os.path.join(self.output_dir, "losses.pickle"), "wb") as fout:
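The first added comment documents a standard gradient-accumulation pattern. In outline (a generic PyTorch sketch, not the repo's trainer; note that trainer.py accumulates the raw loss, whereas dividing by the accumulation count, as below, is the common way to keep gradient magnitudes comparable to one large batch):

```python
def train_accumulated(model, dataloader, optimizer, accum_steps=8):
    # Step the optimizer only every `accum_steps` batches; gradients from
    # the intermediate backward() calls sum in place on the parameters.
    optimizer.zero_grad()
    for i, (x, y) in enumerate(dataloader):
        # Simplified from trainer.py's model(x, y, ext, return_loss=True)
        loss = model(x, y, return_loss=True)
        (loss / accum_steps).backward()  # normalize toward one large batch
        if (i + 1) % accum_steps == 0:
            optimizer.step()             # apply the accumulated gradient
            optimizer.zero_grad()
```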
