Skip to content

Commit 200f46f

Browse files
committed
Fix JIC-VQA dataset preparation
1 parent 9fdb65e commit 200f46f

File tree

6 files changed

+130
-96
lines changed

6 files changed

+130
-96
lines changed

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ For details on the data format and the list of supported data, please check [DAT
3232
- [How to Add Inference Code for a VLM Model](#how-to-add-inference-code-for-a-vlm-model)
3333
- [How to Add Dependencies](#how-to-add-dependencies)
3434
- [Formatting and Linting with ruff](#formatting-and-linting-with-ruff)
35+
- [Testing](#testing)
3536
- [How to Release to PyPI](#how-to-release-to-pypi)
3637
- [How to Update the Website](#how-to-update-the-website)
3738
- [Acknowledgements](#acknowledgements)
@@ -206,6 +207,14 @@ uv run ruff format src
206207
uv run ruff check --fix src
207208
```
208209

210+
### Testing
211+
212+
You can test task classes and metric classes with the following command:
213+
```
214+
bash test.sh
215+
```
216+
217+
209218
### How to Release to PyPI
210219

211220
```

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"backoff>=2.2.1",
2121
"scipy>=1.15.1",
2222
"torch>=2.5.1",
23+
"webdataset>=0.2.111",
2324
]
2425
readme = "README.md"
2526
license = "Apache-2.0"

scripts/consistency_mecha_ja.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import json
33
import pandas as pd
44
import matplotlib.pyplot as plt
5-
import japanize_matplotlib
5+
import japanize_matplotlib # noqa
66
import numpy as np
77

88
# ======================================

scripts/prepare_jic_vqa.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
from datasets import load_dataset
2+
import os
3+
import requests
4+
from PIL import Image
5+
from io import BytesIO
6+
import backoff
7+
import webdataset as wds
8+
from tqdm import tqdm
9+
10+
11+
# Download a single image, retrying transient network failures.
@backoff.on_exception(
    backoff.expo,  # exponential backoff between attempts
    requests.exceptions.RequestException,  # retry only request-level failures
    max_tries=5,  # give up after 5 attempts
)
def download_image(image_url: str) -> Image.Image:
    """Fetch an image over HTTP and return it as an RGB PIL image.

    Args:
        image_url: Direct URL of the image to download.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        requests.exceptions.RequestException: if the request still fails
            after all retries are exhausted (re-raised by ``backoff``).
    """
    # Some hosts reject requests without a browser-like User-Agent.
    user_agent_string = (
        "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0"
    )
    response = requests.get(
        image_url, headers={"User-Agent": user_agent_string}, timeout=10
    )
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")
27+
28+
29+
def download_image_wrap(image_url: str) -> Image.Image | None:
    """Best-effort wrapper around :func:`download_image`.

    Args:
        image_url: Direct URL of the image to download.

    Returns:
        The downloaded image, or ``None`` when the download ultimately
        failed — the error is printed and the caller skips the sample.
    """
    try:
        return download_image(image_url)
    except Exception as e:  # broad on purpose: any failure just drops this sample
        print(f"Failed to process {image_url}: {e}")
        return None
35+
36+
37+
def get_domain_from_question(
    question: str, domain_mapping: dict[str, str] | None = None
) -> str | None:
    """Map a question to its sub-benchmark domain via keyword lookup.

    Args:
        question: The VQA question text (Japanese).
        domain_mapping: Optional keyword -> domain mapping; defaults to the
            module-level ``domain_dict``.

    Returns:
        The first matching domain name, or ``None`` if no keyword occurs
        in the question.
    """
    if domain_mapping is None:
        domain_mapping = domain_dict
    for keyword, domain in domain_mapping.items():
        if keyword in question:
            return domain
    return None
41+
42+
43+
# Source dataset: image URLs, questions, and category labels.
ds = load_dataset("line-corporation/JIC-VQA", split="train")

# Keywords appearing in the question text identify which sub-benchmark
# a sample belongs to (used by get_domain_from_question).
domain_dict = {
    "花": "jaflower30",
    "食べ物": "jafood101",
    "ランドマーク": "jalandmark10",
    "施設": "jafacility20",
}

output_dir = "dataset/jic_vqa"
os.makedirs(output_dir, exist_ok=True)

# Phase 1: download every image once and archive them as a WebDataset tar,
# so the slow/flaky network step never has to be repeated.
if not os.path.exists(f"{output_dir}/images.tar"):
    with wds.TarWriter(f"{output_dir}/images.tar") as sink:
        for example in tqdm(ds, total=len(ds)):
            image_url = example["url"]
            image = download_image_wrap(image_url)
            if image is None:
                # Download failed after retries; skip this sample.
                continue
            # Normalize every image to a fixed size and mode.
            image = image.resize((224, 224))
            image = image.convert("RGB")
            sample = {
                "__key__": str(example["id"]),
                "jpg": image,
                "txt": example["category"],
                "url.txt": image_url,
                "question.txt": example["question"],
            }
            sink.write(sample)

# Phase 2: re-load the tar and reshape it into the final schema.
ds = load_dataset("webdataset", data_files=f"{output_dir}/images.tar", split="train")
print(ds)
print(ds[0])

ds = ds.remove_columns(["__url__"])
ds = ds.rename_columns(
    {
        "txt": "category",
        "url.txt": "url",
        "question.txt": "question",
    }
)

# WebDataset stores .txt entries as bytes; decode where we need text and
# derive the remaining fields.
ds = ds.map(
    lambda x: {
        "input_text": x["question"].decode("utf-8"),
        "url": x["url"],  # kept as raw bytes, matching the previous output schema
        "answer": str(x["category"]),
        "image": x["jpg"],
        "question_id": int(x["__key__"]),
        "domain": get_domain_from_question(x["question"].decode("utf-8")),
    }
)
ds = ds.remove_columns(["question", "__key__", "jpg"])

print(ds)
print(ds[0])
# Example record:
# {'category': 'ガソリンスタンド', 'url': b'https://live.staticflickr.com/5536/11190751074_f97587084e_o.jpg', 'input_text': "この画像にはどの施設が映っていますか?次の四つの選択肢から正しいものを選んでください: ['スーパーマーケット', 'コンビニ', '駐車場', 'ガソリンスタンド']", 'answer': 'ガソリンスタンド', 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=224x224 at 0x7F83A660F710>, 'question_id': '11190751074', 'domain': 'jafacility20'}
ds.to_parquet("dataset/jic_vqa.parquet")

src/eval_mm/tasks/jic_vqa.py

Lines changed: 8 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,112 +1,25 @@
1-
import time
2-
import warnings
3-
from io import BytesIO
4-
5-
import requests
61
from PIL import Image
72
from datasets import Dataset, load_dataset
8-
from huggingface_hub import cached_assets_path
93

104
from ..api.registry import register_task
115
from ..api.task import Task
126
from eval_mm.metrics import ScorerRegistry
13-
from tqdm import tqdm
7+
import os
148

159

1610
@register_task("jic-vqa")
1711
class JICVQA(Task):
1812
@staticmethod
1913
def _prepare_dataset() -> Dataset:
20-
cache_dir = cached_assets_path(
21-
library_name="datasets", namespace="JICVQA", subfolder="download"
22-
)
23-
24-
dataset = load_dataset("line-corporation/JIC-VQA")
25-
input_texts = []
26-
answers = []
27-
images = []
28-
question_ids = []
29-
domains = []
30-
31-
domain_dict = {
32-
"花": "jaflower30",
33-
"食べ物": "jafood101",
34-
"ランドマーク": "jalandmark10",
35-
"施設": "jafacility20",
36-
}
37-
38-
def get_domain_from_question(question):
39-
for keyword, domain in domain_dict.items():
40-
if keyword in question:
41-
return domain
42-
43-
def download_image(url, image_id):
44-
# TODO: Multi-threading for faster download
45-
img_format = url.split(".")[-1]
46-
image_path = cache_dir / f"{image_id}.{img_format}"
47-
if image_path.exists():
48-
return
49-
50-
max_attempts = 5
51-
attempt_errors = []
52-
for _ in range(max_attempts):
53-
try:
54-
response = requests.get(url)
55-
if response.status_code == 200:
56-
image = Image.open(BytesIO(response.content))
57-
image.save(image_path)
58-
print(f"Downloaded: {image_path}")
59-
wait_time = 1.0
60-
time.sleep(wait_time)
61-
return
62-
else:
63-
error_msg = f"Status code: {response.status_code}"
64-
attempt_errors.append(error_msg)
65-
66-
except Exception as e:
67-
error_msg = f"Exception: {e}"
68-
attempt_errors.append(error_msg)
69-
70-
warnings.warn(
71-
f"Failed to download {url} after {max_attempts} attempts. Errors: {attempt_errors}"
14+
if not os.path.exists("dataset/jic_vqa.parquet"):
15+
raise FileNotFoundError(
16+
"Dataset not found. Please run `scripts/prepare_jic_vqa.py` to prepare the dataset."
7217
)
7318

74-
# Phase 1: Download all images
75-
for subset in dataset:
76-
for entry in tqdm(dataset[subset], desc=f"Downloading {subset} images"):
77-
url = entry["url"]
78-
image_id = entry["id"]
79-
download_image(url, image_id)
80-
81-
# Phase 2: Load images and populate data structures
82-
for subset in dataset:
83-
for entry in dataset[subset]:
84-
image_id = entry["id"]
85-
img_format = entry["url"].split(".")[-1]
86-
image_path = cache_dir / f"{image_id}.{img_format}"
87-
88-
if not image_path.exists():
89-
warnings.warn(f"The image path {image_path} does not exist.")
90-
continue
91-
try:
92-
image = Image.open(image_path)
93-
except Exception as e:
94-
print(f"{e} : Failed to open {image_path}")
95-
images.append(image)
96-
input_texts.append(entry["question"])
97-
answers.append(entry["category"])
98-
question_ids.append(image_id)
99-
domain = get_domain_from_question(entry["question"])
100-
domains.append(domain)
101-
102-
data_dict = {
103-
"input_text": input_texts,
104-
"answer": answers,
105-
"image": images,
106-
"question_id": question_ids,
107-
"domain": domains,
108-
}
109-
return Dataset.from_dict(data_dict)
19+
dataset = load_dataset(
20+
"parquet", data_files="dataset/jic_vqa.parquet", split="train"
21+
)
22+
return dataset
11023

11124
@staticmethod
11225
def doc_to_text(doc) -> str:

test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
#!/usr/bin/env bash
# Run unit tests for the task and metric classes.
# Fail fast: without `set -e` a failing first pytest run would be masked,
# because the script's exit status would reflect only the last command.
set -euo pipefail

uv run pytest src/eval_mm/tasks/*.py
uv run pytest src/eval_mm/metrics/*.py

0 commit comments

Comments
 (0)