"""Measuring Coding Challenge Competence With APPS
https://arxiv.org/abs/2105.09938
APPS is a benchmark for code generation with 10000 problems. With three difficulty levels: introductory, interview and competition.
It can be used to evaluate the ability of language models to generate code from natural language specifications.
Homepage: https://github.com/hendrycks/apps
"""
import json
from evaluate import load
from bigcode_eval.base import Task
_CITATION = """
@article{hendrycksapps2021,
title={Measuring Coding Challenge Competence With APPS},
author={Dan Hendrycks and Steven Basart and Saurav Kadavath and Mantas Mazeika and Akul Arora and Ethan Guo and Collin Burns and Samir Puranik and Horace He and Dawn Song and Jacob Steinhardt},
journal={NeurIPS},
year={2021}
}
"""
LEVELS = ["introductory", "interview", "competition"]

def create_all_tasks():
    """Creates a dictionary of tasks from the list of levels
    :return: {task_name: task}
        e.g. {apps-interview: Task, apps-competition: Task}
    """
    return {f"apps-{level}": create_task(level) for level in LEVELS}

def create_task(level):
    class APPS(GeneralAPPS):
        def __init__(self, **kwargs):
            super().__init__(level, **kwargs)

    return APPS
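
# Illustrative usage sketch: the level-specific classes produced by
# create_all_tasks() are normally picked up through the harness's task
# registry, but they can also be instantiated directly. This assumes the
# "codeparrot/apps" dataset can be loaded by the Task base class.
#
#   tasks = create_all_tasks()                 # {"apps-introductory": <class>, ...}
#   apps_intro = tasks["apps-introductory"]()  # defaults to k_list=[1, 10, 100]
#   doc = apps_intro.get_dataset()[0]
#   print(apps_intro.get_prompt(doc))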

class GeneralAPPS(Task):
    """A task represents an entire benchmark including its dataset, problems,
    answers, generation settings and evaluation methods.
    """

    DATASET_PATH = "codeparrot/apps"
    DATASET_NAME = None

    def __init__(self, level, k_list=[1, 10, 100]):
        self.DATASET_NAME = level
        super().__init__(
            stop_words=["\nQUESTION", "\n---", "\nANSWER"],
            requires_execution=True,
        )
        self.k_list = k_list
    def get_dataset(self):
        """Returns the dataset for the task, or an iterable of any object that get_prompt can handle."""
        return self.dataset["test"]
    def get_prompt(self, doc):
        """Generate prompts for APPS.
        Finetuning setup: the prompt is the question, plus starter code and function name if they exist.
        We also specify the type of the prompt, i.e. whether it is call-based or standard-input-based.
        """
        starter_code = None if len(doc["starter_code"]) == 0 else doc["starter_code"]
        try:
            input_output = json.loads(doc["input_output"])
            fn_name = (
                None if not input_output.get("fn_name") else input_output["fn_name"]
            )
        except ValueError:
            fn_name = None
        prompt = "\nQUESTION:\n"
        prompt += doc["question"]
        if starter_code:
            prompt += starter_code
        if not fn_name:
            call_format = "\nUse Standard Input format"
            prompt += call_format
        else:
            call_format = "\nUse Call-Based format"
            prompt += call_format
        prompt += "\nANSWER:\n"
        return prompt
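
    # Resulting prompt layout (illustrative), for a call-based problem with
    # starter code; the question text and starter code come straight from the doc:
    #
    #   \nQUESTION:\n<question><starter_code>\nUse Call-Based format\nANSWER:\n
    #
    # Problems without a parsed fn_name get "\nUse Standard Input format" instead.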
def get_reference(self, doc):
"""Builds the reference solution for the doc (sample from the test dataset)."""
return None
    def postprocess_generation(self, generation, idx):
        """Defines the postprocessing for a LM generation.
        :param generation: str
            code generation from LM
        :param idx: int
            index of doc in the dataset to which the generation belongs
            (not used for APPS)
        """
        try:
            generation = generation.split("\nANSWER:", 1)[1]
        except IndexError:
            # happens when prompts were very long and got truncated
            pass
        return generation
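
    # For example, a generation that echoes the prompt, such as
    # "\nQUESTION:\n...\nANSWER:\ndef solve(): ...", is trimmed above to
    # "\ndef solve(): ..."; generations without "\nANSWER:" (e.g. from
    # truncated prompts) are returned unchanged.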
    def process_results(self, generations, references):
        """Takes the list of LM generations and evaluates them against ground-truth references,
        returning the metric for the generations.
        :param generations: list(list(str))
            list of lists containing generations
        :param references: list(str)
            list of str containing references (not needed for the APPS task)
        """
        code_metric = load("codeparrot/apps_metric")
        results = code_metric.compute(
            predictions=generations, k_list=self.k_list, level=self.DATASET_NAME
        )
        return results
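
# Illustrative end-to-end sketch, assuming code execution is enabled and the
# "codeparrot/apps_metric" evaluation module can be loaded. The candidate
# solution below is a hypothetical placeholder, not a real reference answer.
#
#   task = create_task("introductory")(k_list=[1])
#   generations = [["print(input())"]]   # one candidate for the first problem
#   results = task.process_results(generations, references=[""])
#   # `results` is whatever dict codeparrot/apps_metric returns for this level;
#   # the exact keys depend on that metric's implementation.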