Skip to content

Commit 03757c7

Browse files
silviase and claude committed
feat: Add AI2D task and refactor registries to use decorator pattern
- Add new AI2D task that loads the lmms-lab/ai2d dataset
- Refactor task registry to use @register_task decorator pattern
- Refactor scorer registry to use @register_scorer decorator pattern
- Update all task and scorer classes to use the new decorators
- Support multiple registration names for each task/scorer
- Maintain full backward compatibility with existing API

This change eliminates duplication between __init__.py and registry files, making it easier to add new tasks and scorers.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 0a5e161 commit 03757c7

29 files changed

+167 -67 lines changed

src/eval_mm/metrics/cc_ocr_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from typing import List, Dict, Any, cast # Added cast for type hinting clarity
44

55
from .scorer import Scorer, AggregateOutput, ScorerConfig
6+
from .scorer_registry import register_scorer
67

78

89
def token_normalize(
@@ -158,6 +159,7 @@ def calculate_metrics(
158159

159160

160161
# CCOCRScorer class, specialized for Japanese (character-level, no alphanum_only)
162+
@register_scorer("cc-ocr", "CC-OCR", "CCOCRScorer")
161163
class CCOCRScorer(Scorer):
162164
def __init__(self, config: ScorerConfig):
163165
super().__init__(config)

src/eval_mm/metrics/exact_match_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
from .scorer import Scorer, AggregateOutput
2+
from .scorer_registry import register_scorer
23

34

5+
@register_scorer("exact-match", "ExactMatch", "ExactMatchScorer")
46
class ExactMatchScorer(Scorer):
57
@staticmethod
68
def score(refs: list[str], preds: list[str]) -> list[int]:

src/eval_mm/metrics/heron_bench_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
from collections import defaultdict
33
import numpy as np
44
from eval_mm.metrics.scorer import Scorer, AggregateOutput
5+
from .scorer_registry import register_scorer
56
import re
67
import json
78

@@ -110,6 +111,7 @@ def ask_gpt4_batch(
110111
return completions
111112

112113

114+
@register_scorer("heron-bench", "HeronBench", "HeronBenchScorer")
113115
class HeronBenchScorer(Scorer):
114116
def score(self, refs, preds: list[str]) -> list[dict[str, int]]:
115117
docs = self.config.docs

src/eval_mm/metrics/jdocqa_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from eval_mm.metrics.scorer import Scorer, AggregateOutput
2+
from .scorer_registry import register_scorer
23
from sacrebleu import sentence_bleu
34
from unicodedata import normalize
45

@@ -38,6 +39,7 @@ def bleu_ja(refs, pred):
3839
return bleu_score.score / 100
3940

4041

42+
@register_scorer("jdocqa", "JDocQA", "JDocQAScorer")
4143
class JDocQAScorer(Scorer):
4244
def score(self, refs: list[str], preds: list[str]) -> list[int]:
4345
docs = self.config.docs

src/eval_mm/metrics/jic_vqa_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
from .scorer import Scorer, AggregateOutput
2+
from .scorer_registry import register_scorer
23

34

5+
@register_scorer("jic-vqa", "JICVQA", "JICVQAScorer")
46
class JICVQAScorer(Scorer):
57
@staticmethod
68
def score(refs: list[str], preds: list[str]) -> list[int]:

src/eval_mm/metrics/jmmmu_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from datasets import Dataset
66
from eval_mm.metrics.scorer import Scorer, AggregateOutput
7+
from .scorer_registry import register_scorer
78

89

910
DOMAIN_CAT2SUB_CAT = {
@@ -411,6 +412,7 @@ def get_score(doc: Dataset, pred: str, random_choice: bool = False) -> int:
411412
return score
412413

413414

415+
@register_scorer("jmmmu", "JMMMU", "JMMMUScorer")
414416
class JMMMUScorer(Scorer):
415417
def score(self, refs: list[str], preds: list[str]) -> list[int]:
416418
docs = self.config.docs

src/eval_mm/metrics/llm_as_a_judge_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from eval_mm.metrics.scorer import Scorer, AggregateOutput
2+
from .scorer_registry import register_scorer
23
from tqdm import tqdm
34
import re
45

@@ -24,6 +25,7 @@
2425
"""
2526

2627

28+
@register_scorer("llm-as-a-judge", "LLM-as-a-Judge", "LlmAsaJudgeScorer")
2729
class LlmAsaJudgeScorer(Scorer):
2830
def score(
2931
self,

src/eval_mm/metrics/mecha_ja_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
# mecha-ja-scorer.py
22
from .scorer import Scorer, AggregateOutput
3+
from .scorer_registry import register_scorer
34
import re
45
from collections import defaultdict
56

@@ -9,6 +10,7 @@
910
}
1011

1112

13+
@register_scorer("mecha-ja", "MECHAJa", "MECHAJaScorer")
1214
class MECHAJaScorer(Scorer):
1315
@staticmethod
1416
def _parse_rotation_id(qid: str) -> str:

src/eval_mm/metrics/mmmu_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55
from datasets import Dataset
66
from .scorer import Scorer, AggregateOutput
7+
from .scorer_registry import register_scorer
78

89
DOMAIN_CAT2SUB_CAT = {
910
"Art and Design": ["Art", "Art_Theory", "Design", "Music"],
@@ -410,6 +411,7 @@ def get_score(doc: Dataset, pred: str, random_choice: bool) -> int:
410411
return score
411412

412413

414+
@register_scorer("mmmu", "MMMU", "MMMUScorer")
413415
class MMMUScorer(Scorer):
414416
def score(self, refs: list[str], preds: list[str]) -> list[int]:
415417
docs = self.config.docs

src/eval_mm/metrics/rougel_scorer.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import emoji
77
import unicodedata
88
from .scorer import Scorer, AggregateOutput
9+
from .scorer_registry import register_scorer
910
from concurrent.futures import ProcessPoolExecutor, Future
1011

1112

@@ -72,6 +73,7 @@ def rouge_ja(refs: list[str], preds: list[str]) -> dict:
7273
return {type: result[type].mid.fmeasure * 100 for type in rouge_types}
7374

7475

76+
@register_scorer("rougel", "RougeL", "RougeLScorer")
7577
class RougeLScorer(Scorer):
7678
@staticmethod
7779
def score(refs: list[str], preds: list[str]) -> list[float]:

0 commit comments

Comments
 (0)