
Commit 1d79f95
Author: silviase
Parent: d86fc42

refactor(tasks): add _prepare_test_dataset across tasks and stop relying on _maybe_slice_split; optimize MMMU early-stop; add HF caches cleanup in CI

23 files changed, +307 −95 lines

.github/workflows/test.yml
Lines changed: 8 additions & 1 deletion

@@ -11,6 +11,9 @@ jobs:
       # HF token is required to avoid 429 rate limits on HF Hub
       HF_TOKEN: ${{ secrets.HF_TOKEN }}
       HUGGINGFACE_HUB_TOKEN: ${{ secrets.HF_TOKEN }}
+      HF_DATASETS_CACHE: ${{ runner.temp }}/hf-datasets
+      HUGGINGFACE_HUB_CACHE: ${{ runner.temp }}/hf-hub
+      TRANSFORMERS_CACHE: ${{ runner.temp }}/hf-transformers

     steps:
       - uses: actions/checkout@v4
@@ -29,8 +32,12 @@ jobs:
       - name: Run tests (metrics)
         run: uv run pytest src/eval_mm/metrics/*.py

+      - name: Clear HF caches before task tests
+        run: |
+          rm -rf "$HF_DATASETS_CACHE" "$HUGGINGFACE_HUB_CACHE" || true
+
       - name: Run tests (tasks)
-        run: uv run pytest src/eval_mm/tasks/*.py
+        run: uv run pytest src/eval_mm/tasks/*.py

       # Optional model smoke; enable when runners have resources
       # - name: Run model smoke tests
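Note: the three new env vars point every Hugging Face cache at the runner's temp directory, so the cleanup step only has to remove known paths. A quick sanity check that the datasets library honors the override (a minimal sketch; the /tmp path is illustrative, standing in for ${{ runner.temp }}/hf-datasets):

    import os

    # The override must be set before `datasets` is imported, because the
    # library reads HF_DATASETS_CACHE at import time.
    os.environ["HF_DATASETS_CACHE"] = "/tmp/hf-datasets"  # illustrative path

    from datasets import config

    print(config.HF_DATASETS_CACHE)  # -> /tmp/hf-datasets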

src/eval_mm/tasks/ai2d.py
Lines changed: 7 additions & 1 deletion

@@ -10,7 +10,13 @@ def __init__(self, config):
         super().__init__(config)

     def _prepare_dataset(self) -> Dataset:
-        ds = load_dataset("lmms-lab/ai2d", split=self._maybe_slice_split("test"))
+        ds = load_dataset("lmms-lab/ai2d", split="test")
+        ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
+        return ds
+
+    def _prepare_test_dataset(self) -> Dataset:
+        n = getattr(self.config, "max_dataset_len", 10)
+        ds = load_dataset("lmms-lab/ai2d", split=f"test[:{n}]")
         ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
         return ds
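Note: the _prepare_dataset/_prepare_test_dataset pair above recurs in every task file below. The Task base class presumably picks between them and, per the comment in blink.py, applies a final length cap; that dispatch is not part of this diff, so the following is only a sketch of the assumed shape:

    import os

    from datasets import Dataset


    class Task:
        # Hypothetical base-class dispatch (not shown in this commit).

        def __init__(self, config):
            self.config = config

        def dataset(self) -> Dataset:
            # Under pytest, use the cheap per-task loader and cap its length,
            # mirroring the "final length cap" comment in blink.py below.
            if os.getenv("PYTEST_CURRENT_TEST"):
                ds = self._prepare_test_dataset()
                n = getattr(self.config, "max_dataset_len", 10)
                return ds.select(range(min(n, len(ds))))
            return self._prepare_dataset()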

src/eval_mm/tasks/blink.py
Lines changed: 6 additions & 3 deletions

@@ -35,9 +35,7 @@ def _prepare_dataset(self) -> Dataset:
         total = 0

         for config_name in BLINK.CONFIGS:
-            ds = load_dataset(
-                "BLINK-Benchmark/BLINK", config_name, split=self._maybe_slice_split("val")
-            )
+            ds = load_dataset("BLINK-Benchmark/BLINK", config_name, split="val")
             ds = ds.map(lambda x: {"config_name": config_name})
             all_datasets.append(ds)
             total += len(ds)
@@ -52,6 +50,11 @@ def _prepare_dataset(self) -> Dataset:
         )

         return combined_dataset
+
+    def _prepare_test_dataset(self) -> Dataset:
+        # Reuse the same incremental loading logic; Task base will apply
+        # final length cap if needed.
+        return self._prepare_dataset()

     @staticmethod
     def doc_to_text(doc) -> str:

src/eval_mm/tasks/cc_ocr.py
Lines changed: 26 additions & 29 deletions

@@ -23,35 +23,6 @@ class CCOCR(Task):
     default_metric = "ccocr"

     def _prepare_dataset(self) -> Dataset:
-        # Use streaming during tests to avoid empty slices after filtering
-        n = getattr(self.config, "max_dataset_len", None)
-        test_subset = os.getenv("PYTEST_CURRENT_TEST") or os.getenv("EVAL_MM_TEST_SUBSET") == "1"
-        if n is not None and test_subset:
-            stream = load_dataset(
-                "wulipc/CC-OCR", "multi_lan_ocr", split="test", streaming=True
-            )
-            buf = {
-                "index": [],
-                "question_id": [],
-                "question": [],
-                "answer": [],
-                "input_text": [],
-                "image": [],
-            }
-            count = 0
-            for ex in stream:
-                if ex.get("l2-category") == "Japanese":
-                    buf["index"].append(str(count))
-                    buf["question_id"].append(str(count))
-                    buf["question"].append(ex["question"])
-                    buf["answer"].append(ex["answer"])
-                    buf["input_text"].append(ex["question"])
-                    buf["image"].append(ex["image"])
-                    count += 1
-                    if count >= n:
-                        break
-            return Dataset.from_dict(buf)
-
         ds = load_dataset("wulipc/CC-OCR", "multi_lan_ocr", split="test")
         ds = ds.filter(lambda example: example["l2-category"] == "Japanese")
         ds = ds.map(
@@ -67,6 +38,32 @@ def _prepare_dataset(self) -> Dataset:
         )
         return ds

+    def _prepare_test_dataset(self) -> Dataset:
+        # Stream to collect first N Japanese samples without downloading full split
+        n = getattr(self.config, "max_dataset_len", 10)
+        stream = load_dataset("wulipc/CC-OCR", "multi_lan_ocr", split="test", streaming=True)
+        buf = {
+            "index": [],
+            "question_id": [],
+            "question": [],
+            "answer": [],
+            "input_text": [],
+            "image": [],
+        }
+        count = 0
+        for ex in stream:
+            if ex.get("l2-category") == "Japanese":
+                buf["index"].append(str(count))
+                buf["question_id"].append(str(count))
+                buf["question"].append(ex["question"])
+                buf["answer"].append(ex["answer"])
+                buf["input_text"].append(ex["question"])
+                buf["image"].append(ex["image"])
+                count += 1
+                if count >= n:
+                    break
+        return Dataset.from_dict(buf)
+
     @staticmethod
     def doc_to_text(doc) -> str:
         return doc["input_text"]

src/eval_mm/tasks/chartqa.py
Lines changed: 7 additions & 1 deletion

@@ -19,7 +19,13 @@ def __init__(self, config):
     def _prepare_dataset(self) -> Dataset:
         """Load ChartQA validation set."""
         # Load the ChartQA dataset from lmms-lab
-        ds = load_dataset("lmms-lab/ChartQA", split=self._maybe_slice_split("test"))
+        ds = load_dataset("lmms-lab/ChartQA", split="test")
+        ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
+        return ds
+
+    def _prepare_test_dataset(self) -> Dataset:
+        n = getattr(self.config, "max_dataset_len", 10)
+        ds = load_dataset("lmms-lab/ChartQA", split=f"test[:{n}]")
         ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
         return ds
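Note: the simpler tasks use the datasets split-slicing syntax rather than streaming. The trade-off: a sliced split such as test[:10] still downloads the split's data files but materializes only the requested rows, which is why the tasks that must filter first (CC-OCR, CVQA) stream instead. A minimal illustration (the dataset id is taken from the diff above):

    from datasets import load_dataset

    # Python-style slice inside the split string: the files backing "test"
    # are still fetched, but only the first 10 rows become the Dataset.
    ds = load_dataset("lmms-lab/ChartQA", split="test[:10]")
    assert len(ds) == 10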

src/eval_mm/tasks/chartqapro.py
Lines changed: 6 additions & 1 deletion

@@ -21,9 +21,14 @@ def __init__(self, config):
     def _prepare_dataset(self) -> Dataset:
         """Load ChartQAPro test set."""
         # Load the ChartQAPro dataset from ahmed-masry
-        ds = load_dataset("ahmed-masry/ChartQAPro", split=self._maybe_slice_split("test"))
+        ds = load_dataset("ahmed-masry/ChartQAPro", split="test")

         return ds
+
+    def _prepare_test_dataset(self) -> Dataset:
+        n = getattr(self.config, "max_dataset_len", 10)
+        ds = load_dataset("ahmed-masry/ChartQAPro", split=f"test[:{n}]")
+        return ds

     @staticmethod
     def doc_to_text(doc) -> str:

src/eval_mm/tasks/cvqa.py
Lines changed: 32 additions & 34 deletions

@@ -37,40 +37,6 @@ class CVQA(Task):
     default_metric = "substring-match"

     def _prepare_dataset(self) -> Dataset:
-        # Use streaming during tests to ensure we pick N Japanese samples
-        # even if they are sparse early in the split.
-        n = getattr(self.config, "max_dataset_len", None)
-        test_subset = os.getenv("PYTEST_CURRENT_TEST") or os.getenv("EVAL_MM_TEST_SUBSET") == "1"
-        if n is not None and test_subset:
-            stream = load_dataset("afaji/cvqa", split="test", streaming=True)
-            buf = {
-                "index": [],
-                "question_id": [],
-                "question": [],
-                "question_en": [],
-                "options": [],
-                "translated_options": [],
-                "answer": [],
-                "answer_text": [],
-                "image": [],
-            }
-            count = 0
-            for ex in stream:
-                if ex.get("Subset") == "('Japanese', 'Japan')":
-                    buf["index"].append(str(count))
-                    buf["question_id"].append(str(count))
-                    buf["question"].append(ex["Question"])
-                    buf["question_en"].append(ex.get("Translated Question"))
-                    buf["options"].append(ex["Options"])
-                    buf["translated_options"].append(ex.get("Translated Options"))
-                    buf["answer"].append(ex["Label"])  # 0~3
-                    buf["answer_text"].append(OPTIONS_MAP[ex["Label"]])
-                    buf["image"].append(ex["image"])  # keep original to lazily decode later
-                    count += 1
-                    if count >= n:
-                        break
-            return Dataset.from_dict(buf)
-
         ds = load_dataset("afaji/cvqa", split="test")
         ds = ds.filter(lambda x: x["Subset"] == "('Japanese', 'Japan')")
         ds = ds.map(
@@ -89,6 +55,38 @@ def _prepare_dataset(self) -> Dataset:
         )
         return ds

+    def _prepare_test_dataset(self) -> Dataset:
+        # Stream to pick the first N Japanese samples and build a tiny Dataset
+        n = getattr(self.config, "max_dataset_len", 10)
+        stream = load_dataset("afaji/cvqa", split="test", streaming=True)
+        buf = {
+            "index": [],
+            "question_id": [],
+            "question": [],
+            "question_en": [],
+            "options": [],
+            "translated_options": [],
+            "answer": [],
+            "answer_text": [],
+            "image": [],
+        }
+        count = 0
+        for ex in stream:
+            if ex.get("Subset") == "('Japanese', 'Japan')":
+                buf["index"].append(str(count))
+                buf["question_id"].append(str(count))
+                buf["question"].append(ex["Question"])
+                buf["question_en"].append(ex.get("Translated Question"))
+                buf["options"].append(ex["Options"])
+                buf["translated_options"].append(ex.get("Translated Options"))
+                buf["answer"].append(ex["Label"])  # 0~3
+                buf["answer_text"].append(OPTIONS_MAP[ex["Label"]])
+                buf["image"].append(ex["image"])  # keep original to lazily decode later
+                count += 1
+                if count >= n:
+                    break
+        return Dataset.from_dict(buf)
+
     @staticmethod
     def doc_to_text(doc) -> str:
         # Lazily construct the prompt to reduce preprocessing cost

src/eval_mm/tasks/docvqa.py
Lines changed: 7 additions & 1 deletion

@@ -19,12 +19,18 @@ def __init__(self, config):
     def _prepare_dataset(self) -> Dataset:
         """Load DocVQA validation set."""
         # Load the DocVQA config from lmms-lab/DocVQA dataset
-        ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=self._maybe_slice_split("validation"))
+        ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split="validation")

         # Rename questionId to question_id for consistency
         ds = ds.rename_column("questionId", "question_id")

         return ds
+
+    def _prepare_test_dataset(self) -> Dataset:
+        n = getattr(self.config, "max_dataset_len", 10)
+        ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=f"validation[:{n}]")
+        ds = ds.rename_column("questionId", "question_id")
+        return ds

     @staticmethod
     def doc_to_text(doc) -> str:

src/eval_mm/tasks/infographicvqa.py
Lines changed: 11 additions & 1 deletion

@@ -22,13 +22,23 @@ def _prepare_dataset(self) -> Dataset:
         ds = load_dataset(
             "lmms-lab/DocVQA",
             "InfographicVQA",
-            split=self._maybe_slice_split("validation"),
+            split="validation",
         )

         # Rename questionId to question_id for consistency
         ds = ds.rename_column("questionId", "question_id")

         return ds
+
+    def _prepare_test_dataset(self) -> Dataset:
+        n = getattr(self.config, "max_dataset_len", 10)
+        ds = load_dataset(
+            "lmms-lab/DocVQA",
+            "InfographicVQA",
+            split=f"validation[:{n}]",
+        )
+        ds = ds.rename_column("questionId", "question_id")
+        return ds

     @staticmethod
     def doc_to_text(doc) -> str:

src/eval_mm/tasks/ja_multi_image_vqa.py
Lines changed: 8 additions & 1 deletion

@@ -14,7 +14,14 @@ class JAMultiImageVQA(Task):
     default_metric = "rougel"

     def _prepare_dataset(self) -> Dataset:
-        ds = load_dataset("SakanaAI/JA-Multi-Image-VQA", split=self._maybe_slice_split("test"))
+        ds = load_dataset("SakanaAI/JA-Multi-Image-VQA", split="test")
+        ds = ds.rename_column("question", "input_text")
+        ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
+        return ds
+
+    def _prepare_test_dataset(self) -> Dataset:
+        n = getattr(self.config, "max_dataset_len", 10)
+        ds = load_dataset("SakanaAI/JA-Multi-Image-VQA", split=f"test[:{n}]")
         ds = ds.rename_column("question", "input_text")
         ds = ds.map(lambda example, idx: {"question_id": idx}, with_indices=True)
         return ds
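Note: together with the CI cache cleanup above, these per-task loaders keep `pytest src/eval_mm/tasks/*.py` down to a handful of rows per task. The tests themselves are not shown in this diff; a plausible shape, with an invented config stub and an assumed import path:

    from types import SimpleNamespace

    from eval_mm.tasks.ja_multi_image_vqa import JAMultiImageVQA  # module path assumed


    def test_prepare_test_dataset_is_small():
        config = SimpleNamespace(max_dataset_len=5)  # hypothetical config shape
        ds = JAMultiImageVQA(config)._prepare_test_dataset()
        assert len(ds) == 5
        assert {"question_id", "input_text"} <= set(ds.column_names)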
