Skip to content

Commit ce8557c

Browse files
committed
Add MNIST task
1 parent 3ab3196 commit ce8557c

File tree

6 files changed

+76
-25
lines changed

6 files changed

+76
-25
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ uv run --group normal python examples/sample.py \
3636
--model_id llava-hf/llava-1.5-7b-hf \
3737
--task_id japanese-heron-bench \
3838
--result_dir result \
39-
--metrics "heron-bench" \
40-
--judge_model "gpt-4o-2024-11-20" \
39+
--metrics heron-bench \
40+
--judge_model gpt-4o-2024-11-20 \
4141
--overwrite
4242
```
4343

@@ -136,9 +136,9 @@ See `eval_all.sh` for the complete list of model dependencies.
136136
When adding a new group, remember to configure [conflict](https://docs.astral.sh/uv/concepts/projects/config/#conflicting-dependencies).
137137

138138
## Browse Predictions with Streamlit
139-
139+
140140
```bash
141-
uv run streamlit run scripts/browse_prediction.py --task_id "japanese-heron-bench" --result_dir "result"
141+
uv run streamlit run scripts/browse_prediction.py -- --task_id japanese-heron-bench --result_dir result --model_list llava-hf/llava-1.5-7b-hf
142142
```
143143

144144
![Streamlit](./assets/streamlit_visualization.png)

examples/sample.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@
1414
def parse_args():
1515
parser = argparse.ArgumentParser()
1616
parser.add_argument("--model_id", default="llava-hf/llava-1.5-7b-hf")
17-
parser.add_argument("--task_id", default="japanese-heron-bench")
17+
parser.add_argument(
18+
"--task_id",
19+
default="japanese-heron-bench",
20+
help=f"Task ID to evaluate. Available: {eval_mm.TaskRegistry().get_task_list()}",
21+
)
1822
parser.add_argument("--judge_model", default="gpt-4o-2024-11-20")
1923
parser.add_argument("--batch_size_for_evaluation", type=int, default=10)
2024
parser.add_argument("--overwrite", action="store_true")
@@ -27,7 +31,13 @@ def parse_args():
2731
parser.add_argument("--do_sample", action="store_true", default=False)
2832
parser.add_argument("--use_cache", action="store_true", default=True)
2933
parser.add_argument("--max_dataset_len", type=int)
30-
parser.add_argument("--metrics", default="llm_as_a_judge_heron_bench")
34+
parser.add_argument(
35+
"--metrics",
36+
type=str,
37+
nargs="+",
38+
default=["heron-bench"],
39+
help=f"Metrics to evaluate. Available: {eval_mm.ScorerRegistry().get_metric_list()}",
40+
)
3141
parser.add_argument(
3242
"--rotate_choices", action="store_true", help="This option is used in MECHA-ja"
3343
)
@@ -137,7 +147,6 @@ def save_final_results(preds, task, metrics, scores_by_metric, output_path):
137147

138148
def main():
139149
args = parse_args()
140-
metrics = args.metrics.split(",")
141150

142151
gen_kwargs = GenerationConfig(
143152
max_new_tokens=args.max_new_tokens,
@@ -163,10 +172,10 @@ def main():
163172
logger.info("Inference only mode. Skipping evaluation.")
164173
return
165174

166-
scores_by_metric, aggregated_metrics = evaluate(args, task, preds, metrics)
175+
scores_by_metric, aggregated_metrics = evaluate(args, task, preds, args.metrics)
167176

168177
prediction_path = os.path.join(output_dir, "prediction.jsonl")
169-
save_final_results(preds, task, metrics, scores_by_metric, prediction_path)
178+
save_final_results(preds, task, args.metrics, scores_by_metric, prediction_path)
170179

171180
evaluation_path = os.path.join(output_dir, "evaluation.jsonl")
172181
with open(evaluation_path, "w") as f:

scripts/browse_prediction.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def parse_args():
1010
parser = ArgumentParser()
1111
parser.add_argument("--task_id", type=str, default="japanese-heron-bench")
1212
parser.add_argument("--result_dir", type=str, default="result")
13+
parser.add_argument("--model_list", type=str, nargs="+", default=[])
1314

1415
return parser.parse_args()
1516

@@ -23,18 +24,11 @@ def scrollable_text(text):
2324
if __name__ == "__main__":
2425
args = parse_args()
2526

26-
task = eval_mm.tasks.TaskRegistry().get_task_cls(args.task_id)(
27-
eval_mm.tasks.TaskConfig()
28-
)
27+
task = eval_mm.TaskRegistry().load_task(args.task_id)
2928

3029
# Load model prediction
31-
model_list = [
32-
"google/gemma-3-12b-it",
33-
"google/gemma-3-27b-it",
34-
"microsoft/Phi-4-multimodal-instruct",
35-
]
3630
predictions_per_model = {}
37-
for model_id in model_list:
31+
for model_id in args.model_list:
3832
prediction_path = os.path.join(
3933
args.result_dir, args.task_id, model_id, "prediction.jsonl"
4034
)
@@ -50,8 +44,8 @@ def scrollable_text(text):
5044

5145
SAMPLES_PER_PAGE = 30 # 1ページに表示する件数
5246
# Question ID, Image, Question, Answer, Prediction_model1, Prediction_model2,..
53-
column_width_list = [1, 3, 3, 3] + [4] * len(model_list)
54-
st.write(f"# {args.task_id} dataset")
47+
column_width_list = [1, 3, 3, 3] + [4] * len(args.model_list)
48+
st.write(f"# {args.task_id}")
5549

5650
def show_sample(idx):
5751
sample = ds[idx]
@@ -64,8 +58,8 @@ def show_sample(idx):
6458
cols[3].markdown(
6559
scrollable_text(task.doc_to_answer(sample)), unsafe_allow_html=True
6660
)
67-
for model_id in model_list:
68-
cols[4 + model_list.index(model_id)].markdown(
61+
for model_id in args.model_list:
62+
cols[4 + args.model_list.index(model_id)].markdown(
6963
scrollable_text(predictions_per_model[model_id][idx]["text"]),
7064
unsafe_allow_html=True,
7165
)
@@ -93,8 +87,10 @@ def show_sample(idx):
9387
header_cols[1].markdown("Image")
9488
header_cols[2].markdown("Question")
9589
header_cols[3].markdown("Answer")
96-
for model_id in model_list:
97-
header_cols[4 + model_list.index(model_id)].markdown(f"Prediction ({model_id})")
90+
for model_id in args.model_list:
91+
header_cols[4 + args.model_list.index(model_id)].markdown(
92+
f"Prediction ({model_id})"
93+
)
9894

9995
# サンプルを表示
10096
for idx in range(start_idx, end_idx):

scripts/make_leaderboard.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,8 @@ def parse_args():
286286
parser.add_argument(
287287
"--task_id_list",
288288
type=str,
289-
help="List of task IDs to include in the leaderboard (e.g. jmmmu,mmmu). If not specified, all tasks will be included.",
289+
nargs="+",
290+
help=f"List of task IDs to include in the leaderboard. Available: {TASK_ALIAS.keys()}",
290291
)
291292
return parser.parse_args()
292293

src/eval_mm/tasks/mnist.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from eval_mm.tasks.task import Task
2+
from datasets import load_dataset, Dataset
3+
from PIL import Image
4+
5+
6+
class MNIST(Task):
    """Handwritten-digit recognition task backed by the ylecun/mnist test split."""

    def __init__(self, config):
        super().__init__(config)

    @staticmethod
    def _prepare_dataset() -> Dataset:
        """Load the MNIST test split and attach a sequential question_id to each row."""
        dataset = load_dataset("ylecun/mnist", split="test")
        return dataset.map(
            lambda example, idx: {"question_id": idx}, with_indices=True
        )

    @staticmethod
    def doc_to_text(doc) -> str:
        """Return the fixed Japanese prompt asking which digit is in the image."""
        prompt = "画像に写っている数字は何ですか? 数字のみを出力してください。"
        return prompt

    @staticmethod
    def doc_to_visual(doc) -> list[Image.Image]:
        """Wrap the sample's single image in a list, as the Task interface expects."""
        image = doc["image"]
        return [image]

    @staticmethod
    def doc_to_id(doc) -> int:
        """Return the integer question_id added by _prepare_dataset."""
        return doc["question_id"]

    @staticmethod
    def doc_to_answer(doc) -> str:
        """Return the ground-truth digit label as a string."""
        label = doc["label"]
        return str(label)
31+
32+
33+
def test_task():
    """Smoke-test the MNIST task.

    Builds the task with a default TaskConfig, loads the dataset, and checks
    that each doc_to_* accessor returns the type the Task interface promises.
    """
    # Local import keeps the test-only dependency out of module scope.
    from eval_mm.tasks.task import TaskConfig

    task = MNIST(TaskConfig())
    ds = task.dataset
    # Hoist the repeated ds[0] lookup; also drops the leftover debug print.
    sample = ds[0]
    assert isinstance(task.doc_to_text(sample), str)
    assert isinstance(task.doc_to_visual(sample), list)
    assert isinstance(task.doc_to_visual(sample)[0], Image.Image)
    assert isinstance(task.doc_to_id(sample), int)
    assert isinstance(task.doc_to_answer(sample), str)

src/eval_mm/tasks/task_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from .jic_vqa import JICVQA
1010
from .mecha_ja import MECHAJa
1111
from .mmmlu import MMMLU
12+
from .mnist import MNIST
1213
from .task import TaskConfig, Task
1314

1415

@@ -27,6 +28,7 @@ class TaskRegistry:
2728
"jic-vqa": JICVQA,
2829
"mecha-ja": MECHAJa,
2930
"mmmlu": MMMLU,
31+
"mnist": MNIST,
3032
}
3133

3234
@classmethod

0 commit comments

Comments
 (0)