Skip to content

Commit 1012234

Browse files
committed
[Bench] Add bench for GSM8K eval
1 parent 75b970b commit 1012234

File tree

1 file changed

+315
-0
lines changed

1 file changed

+315
-0
lines changed

python/mlc_llm/bench/eval/gsm8k.py

+315
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
"""Eval GSM8K with MLCEngine."""
2+
3+
import argparse
4+
import asyncio
5+
import json
6+
import random
7+
import re
8+
from datetime import datetime
9+
from pathlib import Path
10+
from typing import List, Literal, Optional
11+
12+
import tqdm
13+
14+
from mlc_llm import AsyncMLCEngine
15+
16+
# Device backends accepted by --device (in addition to "auto").
DEVICES = ["cuda", "rocm", "metal", "vulkan"]
# Phrase that precedes the final answer in the few-shot examples (and is
# therefore expected in the model's completions for strict matching).
ANSWER_TRIGGER = "The answer is"
# Sentinel returned when no answer can be extracted from a text.
INVALID_ANS = "[invalid]"
19+
20+
21+
def extract_answer(text: str, regex: re.Pattern, select_index: int) -> str:
    """Pull a numeric answer out of *text* using *regex*.

    Returns INVALID_ANS when the pattern never matches; otherwise the match
    at *select_index*, stripped of a leading "$", a trailing ".", and commas.
    """
    candidates = regex.findall(text)
    if not candidates:
        return INVALID_ANS
    picked = candidates[select_index]
    if isinstance(picked, tuple):
        # findall yields a tuple when the pattern has multiple groups;
        # keep the first group that actually captured something.
        picked = next(group for group in picked if group)
    return picked.strip().lstrip("$").rstrip(".").replace(",", "")
32+
33+
34+
def extract_ground_truth(text: str) -> str:
    """Return the GSM8K ground-truth answer, which follows the "#### " marker."""
    gt_pattern = re.compile(r"#### (\-?[0-9\.\,]+)")
    return extract_answer(text, gt_pattern, 0)
37+
38+
39+
def strict_extract_answer(text: str) -> str:
    """Extract the answer that follows the exact "The answer is" trigger phrase."""
    strict_pattern = re.compile(r"The answer is \$?(\-?[0-9\.\,]+).")
    return extract_answer(text, strict_pattern, 0)
42+
43+
44+
def flexible_extract_answer(text: str) -> str:
    """Extract the last number-like token appearing anywhere in the text."""
    flexible_pattern = re.compile(r"(-?[$0-9.,]{2,})|(-?[0-9]+)")
    # select_index=-1 picks the final match, i.e. the last number mentioned.
    return extract_answer(text, flexible_pattern, -1)
47+
48+
49+
def create_few_shot_prompt(n_shot: int, use_cot: bool, random_order=False) -> str:
    """
    Build the few-shot prefix for a GSM8K question.

    Parameters
    ----------
    n_shot : how many worked examples to include (at most the 8 available).
    use_cot : include the chain-of-thought reasoning before each answer.
    random_order : shuffle the example order before taking the first n_shot.

    Note
    ----
    The examples are taken from the paper https://arxiv.org/pdf/2201.11903.pdf page 35.
    """
    # (question, chain-of-thought, answer) triples.
    examples = [
        (
            "There are 15 trees in the grove. "
            "Grove workers will plant trees in the grove today. "
            "After they are done, there will be 21 trees. "
            "How many trees did the grove workers plant today?",
            "There are 15 trees originally. "
            "Then there were 21 trees after some more were planted. "
            "So there must have been 21 - 15 = 6.",
            "6",
        ),
        (
            "If there are 3 cars in the parking lot and 2 more cars arrive, "
            "how many cars are in the parking lot?",
            "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5.",
            "5",
        ),
        (
            "Leah had 32 chocolates and her sister had 42. If they ate 35, "
            "how many pieces do they have left in total?",
            "Originally, Leah had 32 chocolates. "
            "Her sister had 42. So in total they had 32 + 42 = 74. "
            "After eating 35, they had 74 - 35 = 39.",
            "39",
        ),
        (
            "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason "
            "has 12 lollipops. How many lollipops did Jason give to Denny?",
            "Jason started with 20 lollipops. Then he had 12 after giving some "
            "to Denny. So he gave Denny 20 - 12 = 8.",
            "8",
        ),
        (
            "Shawn has five toys. For Christmas, he got two toys each from his "
            "mom and dad. How many toys does he have now?",
            "Shawn started with 5 toys. If he got 2 toys each from his mom and "
            "dad, then that is 4 more toys. 5 + 4 = 9.",
            "9",
        ),
        (
            "There were nine computers in the server room. Five more computers "
            "were installed each day, from monday to thursday. "
            "How many computers are now in the server room?",
            "There were originally 9 computers. For each of 4 days, 5 more "
            "computers were added. So 5 * 4 = 20 computers were added. "
            "9 + 20 is 29.",
            "29",
        ),
        (
            "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On "
            "wednesday, he lost 2 more. "
            "How many golf balls did he have at the end of wednesday?",
            "Michael started with 58 golf balls. After losing 23 on tuesday, "
            "he had 58 - 23 = 35. After losing 2 more, "
            "he had 35 - 2 = 33 golf balls.",
            "33",
        ),
        (
            "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?",
            "Olivia had 23 dollars. "
            "5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. "
            "So she has 23 - 15 dollars left. 23 - 15 is 8.",
            "8",
        ),
    ]

    # Shuffling an index list (rather than the examples) keeps the RNG usage
    # identical regardless of the example contents.
    order = list(range(len(examples)))
    if random_order:
        random.shuffle(order)

    parts = []
    for idx in order[:n_shot]:
        q, chain_of_thought, ans = examples[idx]
        if use_cot:
            parts.append(f"Q: {q}\nA: {chain_of_thought} {ANSWER_TRIGGER} {ans}.\n\n")
        else:
            parts.append(f"Question: {q}\nAnswer: {ANSWER_TRIGGER} {ans}.\n\n")
    return "".join(parts)
155+
156+
157+
def create_prompt(question: str, n_shot: int, use_cot: bool, random_order: bool = False) -> str:
    """Return the few-shot prefix followed by *question* awaiting its answer."""
    prefix = create_few_shot_prompt(n_shot, use_cot, random_order)
    # The trailing "A:"/"Answer:" cue matches the format of the examples above.
    suffix = f"Q: {question}\nA:" if use_cot else f"Question: {question}\nAnswer:"
    return prefix + suffix
165+
166+
167+
def parse_args():
    """Parse command line arguments for the GSM8K eval script.

    Returns an argparse.Namespace with: model, dataset, device, model_lib,
    n_shot, disable_cot, batch_size, log_dir.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument(
        "--dataset", type=Path, required=True, help="Path to GSM8K test dataset home."
    )
    parser.add_argument("--device", type=str, choices=["auto"] + DEVICES, default="auto")
    parser.add_argument("--model-lib", type=str, default=None)
    parser.add_argument("--n-shot", type=int, default=8)
    # "--disable-cot" matches the dash style of every other flag; the original
    # underscore spelling is kept as an alias for backward compatibility.
    parser.add_argument(
        "--disable-cot",
        "--disable_cot",
        dest="disable_cot",
        action="store_true",
        default=False,
    )
    parser.add_argument("-bs", "--batch-size", type=int, default=16)
    parser.add_argument("--log-dir", type=Path, default=None)
    return parser.parse_args()
182+
183+
184+
async def send_request(
    async_engine: AsyncMLCEngine,
    prompts: List[str],
    semaphore: asyncio.Semaphore,
):
    """Issue one completion request per prompt, bounding concurrency with *semaphore*."""

    async def run_one(prompt):
        # The semaphore caps how many requests are in flight at once.
        async with semaphore:
            return await async_engine.completions.create(
                prompt=prompt,
                stream=False,
                max_tokens=512,
                stop=["Q:", "Question:"],
                temperature=0.0,
            )

    tasks = [asyncio.create_task(run_one(prompt)) for prompt in prompts]
    # tqdm's gather shows a progress bar while awaiting all responses in order.
    return await tqdm.asyncio.tqdm.gather(*tasks)
207+
208+
209+
async def evaluate(  # pylint: disable=too-many-arguments, too-many-locals
    model: str,
    device: str,
    dataset: Path,
    model_lib: Optional[str],
    n_shot: int,
    use_cot: bool,
    batch_size: int,
    log_dir: Optional[Path],  # pylint: disable=redefined-outer-name
):
    """Evaluate GSM8K for the model.

    Loads ``test.jsonl`` from *dataset*, builds few-shot prompts, queries the
    engine with at most *batch_size* concurrent requests, and prints strict /
    flexible matching accuracy. When *log_dir* is given, a summary and
    per-question logs are written there as JSON.
    """
    # Pick an engine mode suited to the requested concurrency.
    mode: Literal["local", "interactive", "server"] = (
        "server" if batch_size > 4 else "interactive" if batch_size == 1 else "local"
    )
    async_engine = AsyncMLCEngine(model, device=device, model_lib=model_lib, mode=mode)

    with open(dataset / "test.jsonl", "r", encoding="utf-8") as file:
        tests = [json.loads(line) for line in file]

    prompts = [create_prompt(test["question"], n_shot, use_cot) for test in tests]
    responses = await send_request(async_engine, prompts, asyncio.Semaphore(batch_size))
    assert len(responses) == len(tests)

    num_strict_correct, num_flexible_correct = 0, 0
    num_tests = len(tests)
    logs = []

    for response, test in zip(responses, tests):
        response_text = response.choices[0].text.strip()
        gt_answer = extract_ground_truth(test["answer"])
        assert gt_answer != INVALID_ANS
        strict_answer = strict_extract_answer(response_text)
        flexible_answer = flexible_extract_answer(response_text)

        # Reuse the answers extracted above instead of re-running the regexes,
        # so the counters always agree with the per-question log entries below.
        if gt_answer == strict_answer:
            # A strict match is counted toward both metrics.
            num_strict_correct += 1
            num_flexible_correct += 1
        elif gt_answer == flexible_answer:
            # Fall back to flexible extraction when the strict match fails.
            num_flexible_correct += 1

        logs.append(
            {
                "question": test["question"],
                "response": response_text,
                "ground_truth": gt_answer,
                "strict_answer": strict_answer,
                "flexible_answer": flexible_answer,
                "strict_match": gt_answer == strict_answer,
                "flexible_match": gt_answer == flexible_answer,
            }
        )

    results = {
        "config": {
            "model": model,
            "device": device,
            "model_lib": model_lib,
            "n_shot": n_shot,
            "use_cot": use_cot,
        },
        "results": {
            "strict_match": num_strict_correct,
            "flexible_match": num_flexible_correct,
            "total": num_tests,
        },
    }
    print(
        f"Strict Matching Accuracy: {num_strict_correct} / {num_tests} = "
        f"{num_strict_correct / num_tests * 100:.2f}%"
    )
    print(
        f"Flexible Matching Accuracy: {num_flexible_correct} / {num_tests} = "
        f"{num_flexible_correct / num_tests * 100:.2f}%"
    )

    if log_dir:
        with open(log_dir / "summary.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        with open(log_dir / "logs.json", "w", encoding="utf-8") as f:
            json.dump(logs, f, indent=2)
292+
293+
294+
if __name__ == "__main__":
    args = parse_args()
    start_time = datetime.now()
    # Each run writes its logs into its own timestamped subdirectory of --log-dir.
    log_dir: Optional[Path] = None
    if args.log_dir is not None:
        time_dir = start_time.strftime("%Y-%m-%d_%H-%M-%S")
        log_dir = args.log_dir / time_dir
        log_dir.mkdir(parents=True, exist_ok=True)
    asyncio.run(
        evaluate(
            model=args.model,
            device=args.device,
            dataset=args.dataset,
            model_lib=args.model_lib,
            n_shot=args.n_shot,
            # Chain-of-thought prompting is on by default; --disable_cot turns it off.
            use_cot=not args.disable_cot,
            batch_size=args.batch_size,
            log_dir=log_dir,
        )
    )
    end_time = datetime.now()
    print(f"Time used: {end_time - start_time}")

0 commit comments

Comments
 (0)