-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
46 lines (41 loc) · 2.87 KB
/
dataset.py
File metadata and controls
46 lines (41 loc) · 2.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import json
from datasets import Dataset, Features, Value, Sequence, Image, load_dataset
def _load_scienceqa_test_split(root_dir: str):
    """Build the ScienceQA evaluation dataset from the problems JSON file.

    Keeps only problems whose ``split`` is ``'test'`` and whose ``image``
    entry is not ``None``. The ``image`` column holds file paths of the form
    ``{root_dir}/ScienceQA/test_images/{question_id}/{image}`` and is declared
    as an ``Image()`` feature.

    Args:
        root_dir: Directory that contains the ``ScienceQA/test_images`` tree.

    Returns:
        The object returned by ``Dataset.from_dict`` for the filtered problems.
    """
    # NOTE: this path is relative to the current working directory, not to
    # ``root_dir``.
    problems_file = 'submodules/ScienceQA/data/scienceqa/problems.json'
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(problems_file, 'r', encoding='utf-8') as f:
        problems = json.load(f)

    chosen_split = 'test'
    filtered_problems = {
        qid: data
        for qid, data in problems.items()
        if data['split'] == chosen_split and data['image'] is not None
    }

    dataset_features = Features({
        'question_id': Value('string'),
        'question': Value('string'),
        'choices': Sequence(Value('string')),
        'answer': Value('string'),
        'hint': Value('string'),
        'task': Value('string'),
        'grade': Value('string'),
        'subject': Value('string'),
        'topic': Value('string'),
        'category': Value('string'),
        'skill': Value('string'),
        'lecture': Value('string'),
        'solution': Value('string'),
        'image': Image(),
    })

    # Columns copied verbatim from each problem record, in feature order.
    copied_fields = (
        'question', 'choices', 'answer', 'hint', 'task', 'grade', 'subject',
        'topic', 'category', 'skill', 'lecture', 'solution',
    )

    dataset_dict = {'question_id': list(filtered_problems)}
    for field in copied_fields:
        dataset_dict[field] = [
            data[field] for data in filtered_problems.values()
        ]
    # Double-quoted f-string with single-quoted key: nesting the same quote
    # type inside an f-string is a SyntaxError before Python 3.12.
    dataset_dict['image'] = [
        f"{root_dir}/ScienceQA/test_images/{qid}/{data['image']}"
        for qid, data in filtered_problems.items()
    ]

    return Dataset.from_dict(dataset_dict, features=dataset_features)


def get_eval_dataset(dataset_name: str, root_dir: str):
    """Return the evaluation (test) split for a supported dataset.

    Supported names are ``'A-OKVQA'``, ``'CVQA'`` and ``'ScienceQA'``. The
    first two are loaded with ``load_dataset`` from ``{root_dir}/A-OKVQA`` and
    ``{root_dir}/cvqa`` respectively, using ``split='test'``. ScienceQA is
    assembled from its problems JSON file.

    Args:
        dataset_name: Name of the dataset to load (case-sensitive).
        root_dir: Directory under which the dataset files are stored.

    Returns:
        The loaded dataset object.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    if dataset_name == 'A-OKVQA':
        return load_dataset(f'{root_dir}/A-OKVQA', split='test')
    if dataset_name == 'CVQA':
        return load_dataset(f'{root_dir}/cvqa', split='test')
    if dataset_name == 'ScienceQA':
        return _load_scienceqa_test_split(root_dir)
    raise ValueError(f'Unsupported dataset "{dataset_name}"')