
Commit d015821

[Bench] Add bench for MMLU eval
Usage:

```
python python/mlc_llm/bench/eval/mmlu.py --model dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC --dataset /path/to/dataset --device cuda --log-dir debug/mmlu
```

Note that chat mode is currently problematic and needs to be fixed.
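`--dataset` should point at an MMLU checkout whose test split lives under `test/`; a rough sketch of the layout the script expects, inferred from how it reads the CSV files:

```
/path/to/dataset/
└── test/
    ├── abstract_algebra_test.csv   # 6 columns per row: question, A, B, C, D, answer letter
    ├── anatomy_test.csv
    └── ...                         # one <subject>_test.csv per MMLU subject
```

Subjects and request concurrency can be narrowed with `-s`/`-bs`, e.g. `-s anatomy virology -bs 32`.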
1 parent c25834d commit d015821

File tree

1 file changed: +245 -0 lines changed


python/mlc_llm/bench/eval/mmlu.py

+245
@@ -0,0 +1,245 @@
"""Eval MMLU with MLCEngine."""

import argparse
import asyncio
import csv
import json
import string
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
import tqdm.asyncio  # load the asyncio submodule: tqdm.asyncio.tqdm.gather is used below

from mlc_llm import AsyncMLCEngine

SUBJECTS = [
    "abstract_algebra",
    "anatomy",
    "astronomy",
    "business_ethics",
    "clinical_knowledge",
    "college_biology",
    "college_chemistry",
    "college_computer_science",
    "college_mathematics",
    "college_medicine",
    "college_physics",
    "computer_security",
    "conceptual_physics",
    "econometrics",
    "electrical_engineering",
    "elementary_mathematics",
    "formal_logic",
    "global_facts",
    "high_school_biology",
    "high_school_chemistry",
    "high_school_computer_science",
    "high_school_european_history",
    "high_school_geography",
    "high_school_government_and_politics",
    "high_school_macroeconomics",
    "high_school_mathematics",
    "high_school_microeconomics",
    "high_school_physics",
    "high_school_psychology",
    "high_school_statistics",
    "high_school_us_history",
    "high_school_world_history",
    "human_aging",
    "human_sexuality",
    "international_law",
    "jurisprudence",
    "logical_fallacies",
    "machine_learning",
    "management",
    "marketing",
    "medical_genetics",
    "miscellaneous",
    "moral_disputes",
    "moral_scenarios",
    "nutrition",
    "philosophy",
    "prehistory",
    "professional_accounting",
    "professional_law",
    "professional_medicine",
    "professional_psychology",
    "public_relations",
    "security_studies",
    "sociology",
    "us_foreign_policy",
    "virology",
    "world_religions",
]
PADDING_LEN = max(len(subject) for subject in SUBJECTS)
DEVICES = ["cuda", "rocm", "metal", "vulkan"]
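# Zero-shot prompt: the question and its four options, ending with "Answer:" so that the
# model's next token should be the answer letter, which is then scored from its logprobs.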
PROMPT_TEMPLATE = string.Template("$Q\nA. $A\nB. $B\nC. $C\nD. $D\nAnswer:")


def parse_args():
    """Parse command line arguments."""

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument(
        "--dataset", type=Path, required=True, help="Path to MMLU test dataset home."
    )
    parser.add_argument("--device", type=str, choices=["auto"] + DEVICES, default="auto")
    parser.add_argument("--model-lib", type=str, default=None)
    parser.add_argument("-s", "--subject", nargs="+", type=str, choices=SUBJECTS, default=SUBJECTS)
    parser.add_argument("-bs", "--batch-size", type=int, default=16)
    parser.add_argument("--log-dir", type=Path, default=None)
    return parser.parse_args()

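# One completion request is sent per question, with concurrency capped by the semaphore
# (--batch-size). A single generated token with its top-5 logprobs is enough to rank A/B/C/D.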
async def send_request(
    async_engine: AsyncMLCEngine,
    prompts: List[str],
    semaphore: asyncio.Semaphore,
    subject: str,
):
    """Send the evaluation requests to the engine."""
    tasks = []

    async def generate_task(prompt):
        async with semaphore:
            return await async_engine.completions.create(
                prompt=prompt,
                stream=False,
                max_tokens=1,
                temperature=1.0,
                logprobs=True,
                top_logprobs=5,
            )

    for prompt in prompts:
        task = asyncio.create_task(generate_task(prompt))
        tasks.append(task)

    return await tqdm.asyncio.tqdm.gather(
        *tasks,
        desc=f"Running {subject.ljust(PADDING_LEN)}",
        bar_format="{desc} {percentage:3.0f}%|{bar}{r_bar}",
    )

async def evaluate(  # pylint: disable=too-many-arguments, too-many-locals
    model: str,
    device: str,
    dataset: Path,
    model_lib: Optional[str],
    subjects: List[str],
    semaphore: asyncio.Semaphore,
    log_dir: Optional[Path],  # pylint: disable=redefined-outer-name
):
    """Evaluate MMLU for the model."""
    async_engine = AsyncMLCEngine(model, device=device, model_lib=model_lib, mode="server")

    results: Dict[str, Any] = {}
    for subject in subjects:
        with open(dataset / "test" / f"{subject}_test.csv", encoding="utf-8") as csvfile:
            tests = list(csv.reader(csvfile, delimiter=",", quotechar='"'))
            assert all(len(test) == 6 for test in tests)

        logs = []
        num_correct = 0
        prompts = [
            PROMPT_TEMPLATE.substitute(Q=test[0], A=test[1], B=test[2], C=test[3], D=test[4])
            for test in tests
        ]
        responses = await send_request(async_engine, prompts, semaphore, subject)

        assert len(responses) == len(tests)
        for response, test in zip(responses, tests):
            token_logprobs = {}
            logprobs = response.choices[0].logprobs.content[0].top_logprobs
            for logprob in logprobs:
                if logprob.token not in token_logprobs:
                    token_logprobs[logprob.token] = logprob.logprob

            # Fall back to a very low logprob when a choice letter is missing from the top-5.
            abcd_logprobs = {}
            for choice in ["A", "B", "C", "D"]:
                abcd_logprobs[choice] = token_logprobs[choice] if choice in token_logprobs else -100

            pred = {0: "A", 1: "B", 2: "C", 3: "D"}[int(np.argmax(list(abcd_logprobs.values())))]
            num_correct += pred == test[5]

            logs.append(
                {
                    "Question": {
                        "Q": test[0],
                        "A": test[1],
                        "B": test[2],
                        "C": test[3],
                        "D": test[4],
                    },
                    "Answer": test[5],
                    "Response": {
                        "pred": pred,
                        "logprobs": list(abcd_logprobs.values()),
                    },
                }
            )

        results[subject] = {
            "correct": num_correct,
            "total": len(tests),
            "accuracy": num_correct / len(tests),
        }

        if log_dir:
            with open(log_dir / "subjects" / f"{subject}.json", "w", encoding="utf-8") as f:
                json.dump(logs, f, indent=2)

    total_correct, total_tests = 0, 0
    for subject, v in results.items():
        num_correct, num_tests, accuracy = v["correct"], v["total"], v["accuracy"]
        print(f"{subject}: {num_correct} / {num_tests} = {accuracy * 100:.2f}%")
        total_correct += num_correct
        total_tests += num_tests

    total_accuracy = total_correct / total_tests
    results["total"] = {
        "correct": total_correct,
        "total": total_tests,
        "accuracy": total_accuracy,
    }
    print(f"Total accuracy: {total_correct} / {total_tests} = {total_accuracy * 100:.2f}%")

    if log_dir:
        results = {
            "config": {
                "model": model,
                "device": device,
                "model_lib": model_lib,
                "subjects": subjects,
            },
            "results": results,
        }
        with open(log_dir / "summary.json", "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)

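# Each run writes its logs under <log-dir>/<timestamp>/: per-subject details in subjects/
# and an overall summary.json.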
if __name__ == "__main__":
    args = parse_args()
    start_time = datetime.now()
    log_dir: Optional[Path] = None
    if args.log_dir is not None:
        time_dir = start_time.strftime("%Y-%m-%d_%H-%M-%S")
        log_dir = args.log_dir / time_dir
        (log_dir / "subjects").mkdir(parents=True, exist_ok=True)
    asyncio.run(
        evaluate(
            model=args.model,
            device=args.device,
            dataset=args.dataset,
            model_lib=args.model_lib,
            subjects=args.subject,
            semaphore=asyncio.Semaphore(args.batch_size),
            log_dir=log_dir,
        )
    )
    end_time = datetime.now()
    print(f"Time used: {end_time - start_time}")

0 commit comments
