
Commit 431733a

silviase and claude committed
feat: Add DocVQA task implementation
- Add DocVQA task for document visual question answering
- Use substring-match scorer to handle answer variations
- Update task registry and configuration files
- Support 5,349 validation examples from lmms-lab/DocVQA

DocVQA is an extractive QA task where models extract answers from document images. Multiple valid answers are provided per question.

🤖 Generated with Claude Code

Co-Authored-By: Claude <[email protected]>
1 parent 79166d0 commit 431733a
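The central design choice named in the message is the substring-match scorer: DocVQA answers are short extractive spans with several accepted variants, so a prediction can be counted correct when any gold answer occurs inside it. A minimal sketch of that idea in Python (illustrative only; the repository's actual scorer is not part of this commit and its normalization rules may differ):

def substring_match(prediction: str, answers: list[str]) -> bool:
    # Hypothetical helper: True if any accepted answer occurs inside
    # the prediction, after simple whitespace and case folding.
    pred = prediction.strip().lower()
    return any(ans.strip().lower() in pred for ans in answers)

# Both short and long phrasings of the same answer are accepted.
assert substring_match("The invoice total is 42 USD.", ["42 USD", "42"])
assert not substring_match("I cannot read the document.", ["42 USD", "42"])

This is also why doc_to_answer in the new task file returns every annotated variant rather than a single canonical string.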

File tree

4 files changed (+105 -3 lines):

- scripts/make_leaderboard.py
- scripts/nvlink/config.sh
- src/eval_mm/tasks/__init__.py
- src/eval_mm/tasks/docvqa.py


scripts/make_leaderboard.py

Lines changed: 13 additions & 3 deletions
@@ -21,6 +21,9 @@
     "mmmu": "MMMU",
     "cc-ocr": "CC-OCR",
     "cvqa": "CVQA",
+    "ai2d": "AI2D",
+    "blink": "BLINK",
+    "docvqa": "DocVQA",
 }

 TASK_CLUSTER_ALIAS = {
@@ -31,11 +34,14 @@
     "VG-VQA": "視覚中心",
     "Heron": "視覚中心",
     "JVB-ItW": "視覚中心",
-    "MulIm-VQA": "非日本語",
-    "MMMU": "非日本語",
-    "LLAVA": "非日本語",
+    "MulIm-VQA": "その他",
+    "MMMU": "英語",
+    "LLAVA": "英語",
     "CC-OCR": "言語・知識中心",
     "CVQA": "視覚中心",
+    "AI2D": "英語",
+    "BLINK": "英語",
+    "DocVQA": "英語",
 }

 METRIC_ALIAS = {
@@ -49,6 +55,10 @@
     "mmmu": "Acc",
     "cc-ocr": "macro_f1",
     "substring-match": "Acc",
+    "cvqa": "Acc",
+    "ai2d": "Acc",
+    "blink": "Acc",
+    "docvqa": "Acc",
 }

 MODEL_LIST = [
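Taken together, these tables drive the leaderboard columns: a task key maps to a display name, the display name to a cluster label, and the task key to a metric label. A hypothetical lookup, assuming the first (unnamed) hunk belongs to a task-alias dictionary called TASK_ALIAS and that the script combines the tables roughly like this (the rendering code itself is outside this diff):

task_key = "docvqa"
display = TASK_ALIAS[task_key]         # "DocVQA" (dictionary name assumed)
cluster = TASK_CLUSTER_ALIAS[display]  # "英語" (the English-task cluster)
metric = METRIC_ALIAS[task_key]        # "Acc"
print(f"{display} [{cluster}]: {metric}")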

scripts/nvlink/config.sh

Lines changed: 4 additions & 0 deletions
@@ -171,9 +171,11 @@ declare -a task_list=(
     "mmmu"
     "llava-bench-in-the-wild"
     "jic-vqa"
+    "cvqa"
     "mecha-ja"
     "ai2d"
     "blink"
+    "docvqa"
 )

 # === Metrics Mapping ===
@@ -190,6 +192,8 @@ declare -A METRIC_MAP=(
     ["mecha-ja"]="mecha-ja"
     ["ai2d"]="ai2d"
     ["blink"]="blink"
+    ["cvqa"]="substring-match"
+    ["docvqa"]="substring-match"
 )

 # === Function to load .env file ===
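The METRIC_MAP entries show two patterns: most tasks score with a metric named after themselves ("ai2d" maps to "ai2d"), while the new "cvqa" and "docvqa" entries redirect to the shared substring-match scorer. An illustrative Python equivalent of how a driver might pair each task with its scorer (the actual consumer of this bash config is not shown in this diff):

METRIC_MAP = {
    "mecha-ja": "mecha-ja",
    "ai2d": "ai2d",
    "blink": "blink",
    "cvqa": "substring-match",    # new: shares the substring scorer
    "docvqa": "substring-match",  # new: shares the substring scorer
}

for task in ["ai2d", "cvqa", "docvqa"]:
    print(f"evaluating {task} with metric {METRIC_MAP[task]}")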

src/eval_mm/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
 from .cvqa import CVQA
 from .ai2d import AI2D
 from .blink import BLINK
+from .docvqa import DocVQA
 from .task_registry import TaskRegistry
 from .task import TaskConfig

@@ -32,6 +33,7 @@
     "CVQA",
     "AI2D",
     "BLINK",
+    "DocVQA",
     "TaskRegistry",
     "TaskConfig",
 ]

src/eval_mm/tasks/docvqa.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+from eval_mm.tasks.task import Task
+from eval_mm.tasks.task_registry import register_task
+from datasets import load_dataset, Dataset
+from PIL import Image
+
+
+@register_task("docvqa", "DocVQA", "doc-vqa")
+class DocVQA(Task):
+    """DocVQA task implementation.
+
+    DocVQA is a VQA dataset for understanding images of document pages.
+    It uses extractive QA where models need to extract answers from document images.
+    Multiple valid answers are provided for each question.
+    """
+
+    def __init__(self, config):
+        super().__init__(config)
+
+    @staticmethod
+    def _prepare_dataset() -> Dataset:
+        """Load DocVQA validation set."""
+        # Load the DocVQA config from lmms-lab/DocVQA dataset
+        ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
+
+        # Rename questionId to question_id for consistency
+        ds = ds.rename_column("questionId", "question_id")
+
+        return ds
+
+    @staticmethod
+    def doc_to_text(doc) -> str:
+        """Convert document to text prompt.
+
+        DocVQA is an extractive QA task, so we just return the question.
+        """
+        return doc['question']
+
+    @staticmethod
+    def doc_to_visual(doc) -> list[Image.Image]:
+        """Extract image from document."""
+        return [doc['image']]
+
+    @staticmethod
+    def doc_to_id(doc) -> str:
+        """Return unique question ID."""
+        return str(doc['question_id'])
+
+    @staticmethod
+    def doc_to_answer(doc) -> list[str]:
+        """Return list of valid answers.
+
+        DocVQA provides multiple valid answers for each question.
+        We return all of them for evaluation with substring-match scorer.
+        """
+        return doc['answers']
+
+
+def test_docvqa_task():
+    """Test DocVQA task implementation."""
+    from eval_mm.tasks.task import TaskConfig
+
+    # Create task instance
+    task = DocVQA(TaskConfig(max_dataset_len=10))
+
+    # Load dataset
+    print("Loading DocVQA dataset...")
+    ds = task.dataset
+    print(f"Dataset size: {len(ds)}")
+
+    # Test with first example
+    example = ds[0]
+    print("\nFirst example:")
+    print(f"  ID: {task.doc_to_id(example)}")
+    print(f"  Question: {task.doc_to_text(example)}")
+    print(f"  Image: {task.doc_to_visual(example)[0]}")
+    print(f"  Valid answers: {task.doc_to_answer(example)}")
+
+    # Verify data types
+    assert isinstance(task.doc_to_text(example), str)
+    assert isinstance(task.doc_to_visual(example), list)
+    assert all(isinstance(img, Image.Image) for img in task.doc_to_visual(example))
+    assert isinstance(task.doc_to_id(example), str)
+    assert isinstance(task.doc_to_answer(example), list)
+    assert all(isinstance(ans, str) for ans in task.doc_to_answer(example))
+
+    print("\nAll tests passed!")
