1+ from eval_mm .tasks .task import Task
2+ from eval_mm .tasks .task_registry import register_task
3+ from datasets import load_dataset , Dataset
4+ from PIL import Image
5+
6+
@register_task("docvqa", "DocVQA", "doc-vqa")
class DocVQA(Task):
    """DocVQA task implementation.

    DocVQA is a VQA dataset for understanding images of document pages.
    It uses extractive QA where models need to extract answers from
    document images. Multiple valid answers are provided for each question.
    """

    # NOTE: no __init__ override is defined — the original only delegated
    # to super().__init__(config), so the inherited Task.__init__ is
    # behaviorally identical and the redundant override was removed.

    @staticmethod
    def _prepare_dataset() -> Dataset:
        """Load the DocVQA validation split.

        Returns:
            Dataset: the ``DocVQA`` config of the ``lmms-lab/DocVQA``
            dataset, validation split, with ``questionId`` renamed to
            ``question_id`` for consistency with other tasks.
        """
        ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")
        # Uniform id column name so downstream code can rely on it.
        return ds.rename_column("questionId", "question_id")

    @staticmethod
    def doc_to_text(doc) -> str:
        """Return the text prompt for *doc*.

        DocVQA is an extractive QA task, so the prompt is just the question.
        """
        return doc["question"]

    @staticmethod
    def doc_to_visual(doc) -> list[Image.Image]:
        """Return the document page image as a single-element list."""
        return [doc["image"]]

    @staticmethod
    def doc_to_id(doc) -> str:
        """Return the unique question ID, coerced to ``str``."""
        return str(doc["question_id"])

    @staticmethod
    def doc_to_answer(doc) -> list[str]:
        """Return the list of valid answers for *doc*.

        DocVQA provides multiple valid answers per question; all of them
        are returned for evaluation with a substring-match scorer.
        """
        return doc["answers"]
56+
57+
def test_docvqa_task():
    """Smoke-test the DocVQA task implementation.

    Loads a 10-example slice of the dataset (requires network access),
    prints the first example, and asserts the accessor return-type
    contracts used by the evaluation harness.
    """
    from eval_mm.tasks.task import TaskConfig

    # Limit to 10 examples to keep the test fast.
    task = DocVQA(TaskConfig(max_dataset_len=10))

    print("Loading DocVQA dataset...")
    ds = task.dataset
    print(f"Dataset size: {len(ds)}")

    # Inspect the first example.
    example = ds[0]
    # Fix: these two prints carried an f-prefix with no placeholders.
    print("\nFirst example:")
    print(f"  ID: {task.doc_to_id(example)}")
    print(f"  Question: {task.doc_to_text(example)}")
    print(f"  Image: {task.doc_to_visual(example)[0]}")
    print(f"  Valid answers: {task.doc_to_answer(example)}")

    # Hoist repeated accessor calls before asserting their contracts.
    visuals = task.doc_to_visual(example)
    answers = task.doc_to_answer(example)
    assert isinstance(task.doc_to_text(example), str)
    assert isinstance(visuals, list)
    assert all(isinstance(img, Image.Image) for img in visuals)
    assert isinstance(task.doc_to_id(example), str)
    assert isinstance(answers, list)
    assert all(isinstance(ans, str) for ans in answers)

    print("\nAll tests passed!")