# ReasonMap-Plus is an extension of the original ReasonMap dataset, designed to
# provide denser reward signals for visual understanding and reasoning tasks.
# Reference papers:
# 1. Can MLLMs Guide Me Home? A Benchmark Study on Fine-Grained Visual
#    Reasoning from Transit Maps: https://arxiv.org/abs/2505.18675
# 2. RewardMap: Tackling Sparse Rewards in Fine-grained Visual Reasoning via
#    Multi-Stage Reinforcement Learning: https://arxiv.org/abs/2510.02240
#
# If any problem occurs, please open an issue on GitHub
# (https://github.com/fscdc/RewardMap or https://github.com/fscdc/ReasonMap).

import re
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Any
from vlmeval.dataset.image_base import ImageBaseDataset
from vlmeval.smp import load

_BOXED_PAT = re.compile(r'(?:\\boxed|boxed)\{([^}]*)\}', re.IGNORECASE)
_TEXT_PAT = re.compile(r'\\text\{([^}]*)\}', re.IGNORECASE)
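# Illustrative behaviour (via _extract_boxed below, which applies these patterns):
#   r"... \boxed{7} ..."              -> "7"
#   r"... \boxed{\text{Line 2}} ..."  -> "Line 2"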

_YES = {"yes", "y", "true", "t", "1"}
_NO = {"no", "n", "false", "f", "0"}


def _strip(s: Any) -> str:
    return ("" if s is None else str(s)).strip()


def _lower(s: Any) -> str:
    return _strip(s).lower()


def _extract_boxed(s: str) -> str | None:
    """Return the content of the last \\boxed{...} span in ``s`` (unwrapping \\text{...}), or None."""
    m = list(_BOXED_PAT.finditer(s))
    if not m:
        return None
    raw = m[-1].group(1).strip()
    texts = _TEXT_PAT.findall(raw)
    return " ".join(t.strip() for t in texts) if texts else raw


def _extract_after_phrases(s: str) -> str:
    """Return the text following the last matched answer cue, or ``s`` stripped."""
    phrases = [
        "the final answer is", "final answer is",
        "the answer is", "answer is",
        "the correct answer is", "correct answer is",
        "final answer:", "final:", "answer:", "ans:"
    ]
    lo = s.lower()
    for ph in phrases:
        if ph in lo:
            part = s[lo.rfind(ph) + len(ph):].strip()
            cand = re.split(r'(?:\n|\. |\.$)', part, maxsplit=1)[0]
            return cand.strip()
    return s.strip()
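# Illustrative: for a response such as "The answer is No.", the cue matches and
# the extracted candidate is "No".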


def _normalize_yesno(s: str) -> str | None:
    t = _lower(s)
    if t in _YES:
        return "yes"
    if t in _NO:
        return "no"
    return None


def _normalize_abcd(s: str) -> str | None:
    m = re.search(r'\b([ABCD])\b', s, flags=re.IGNORECASE)
    return m.group(1).upper() if m else None


def _extract_int(s: str) -> int | None:
    m = re.search(r'[-+]?\d+', s)
    return int(m.group(0)) if m else None


def normalize_prediction(pred_raw: Any, typ: str) -> str:
    """Normalize a raw model response into a comparable answer string.

    The question ``typ`` decides the target format: yes/no for "torf" questions,
    a single option letter for "counting1", and an integer for the remaining
    counting questions.
    """
    s = _strip(pred_raw)
    if not s:
        return ""

    # Prefer an explicit \boxed{...} answer; otherwise fall back to answer cues.
    boxed = _extract_boxed(s)
    cand = boxed if boxed else _extract_after_phrases(s)

    t = (typ or "").lower()
    if "torf" in t:
        yn = _normalize_yesno(cand)
        if yn is None:
            yn = _normalize_yesno(s)
        return yn or cand

    if "counting1" in t:
        abcd = _normalize_abcd(cand)
        if abcd is None:
            abcd = _normalize_abcd(s)
        return abcd or cand

    if t.startswith("counting"):
        num = _extract_int(cand)
        if num is None:
            num = _extract_int(s)
        return str(num) if num is not None else cand

    return cand
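# Illustrative examples (derived from the helpers above, not from the dataset):
#   normalize_prediction(r"So the final answer is \boxed{7}.", "counting2")  -> "7"
#   normalize_prediction("I think the correct answer is (B).", "counting1")  -> "B"
#   normalize_prediction("The answer is No.", "torf1")                       -> "no"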


class ReasonMap_Plus(ImageBaseDataset):
    """VLMEvalKit dataset for ReasonMap-Plus: fine-grained VQA over transit maps."""

    TYPE = "VQA"
    DATASET_URL = {
        "ReasonMap-Plus": "https://opencompass.openxlab.space/utils/VLMEval/ReasonMap-Plus.tsv"
    }

    DATASET_MD5 = {
        "ReasonMap-Plus": "205d3ac1c3af07d3e4930f25e01008be"
    }

    @classmethod
    def supported_datasets(cls):
        return ['ReasonMap-Plus']

    def build_prompt(self, line):
        """Build the interleaved message list: one image followed by the question text."""
        if not isinstance(line, pd.Series):
            line = self.data.iloc[line]

        # Prefer the "image" column; fall back to "image_path" when only paths are stored.
        img_val = line.get("image", None)
        if not img_val:
            img_val = line.get("image_path", "")
        prompt = line.get("question", "")

        return [
            dict(type="image", value=img_val),
            dict(type="text", value=prompt),
        ]

    def evaluate(self, eval_file, **judge_kwargs):
        """Score predictions by exact matching; report plain and difficulty-weighted accuracy."""
        df = load(eval_file)
        if len(df) == 0:
            return pd.DataFrame([dict(metric="accuracy", value=0.0, n=0)])

        # Normalize every raw prediction according to its question type.
        df["_pred_norm"] = [
            normalize_prediction(p, t)
            for p, t in zip(df.get("prediction", ""), df.get("type", ""))
        ]

        def _score_one(a, p, t):
            # a: ground-truth answer (1/0 for torf yes/no, option index 0-3 for
            # counting1, an integer for other counting types); p: normalized prediction.
            tlo = (t or "").lower()
            try:
                if "torf" in tlo:
                    gt = "yes" if int(a) == 1 else "no"
                    pp = _normalize_yesno(p)
                    return 1 if (pp == gt) else 0

                if "counting1" in tlo:
                    mapping = {"A": 0, "B": 1, "C": 2, "D": 3}
                    pp = _normalize_abcd(p)
                    if pp is None:
                        return 0
                    return 1 if mapping[pp] == int(a) else 0

                if tlo.startswith("counting"):
                    return 1 if int(str(p)) == int(a) else 0

                return 1 if _strip(a).lower() == _strip(p).lower() else 0
            except Exception:
                return 0

        difficulty_weights = {
            "easy": 1.0,
            "middle": 1.5,
            "hard": 2.0
        }
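        # Weighted accuracy: a correct item contributes its map-difficulty weight,
        # and the sum is normalized by the total weight over all items, so harder
        # maps count proportionally more than easy ones.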

        def _score_weighted_one(a, p, t, difficulty):
            # Identical correctness check to _score_one, scaled by the item's
            # difficulty weight (unknown difficulty labels raise KeyError).
            return difficulty_weights[difficulty] * _score_one(a, p, t)

        df["_correct"] = [
            _score_one(a, p, t)
            for a, p, t in zip(df.get("answer", ""), df["_pred_norm"], df.get("type", ""))
        ]

        df["_weighted_correct"] = [
            _score_weighted_one(a, p, t, difficulty)
            for a, p, t, difficulty in zip(
                df.get("answer", ""),
                df["_pred_norm"],
                df.get("type", ""),
                df.get("difficulty_city", ""),
            )
        ]

        # Total achievable weight; use the builtin sum (np.sum over a generator
        # is deprecated) and guard against a zero total.
        total = sum(difficulty_weights[d] for d in df.get("difficulty_city", ""))

        overall = float(np.mean(df["_correct"])) if len(df) else 0.0
        weighted_overall = (
            float(np.sum(df["_weighted_correct"]) / total) if total else 0.0
        )

        out_rows = [
            dict(metric="accuracy", value=overall, n=len(df)),
            dict(metric="weighted_accuracy", value=weighted_overall, n=len(df)),
        ]

        # Per-type breakdown (torf / counting sub-tasks).
        for tname, sub in df.groupby(df.get("type", "")):
            total_sub = sum(
                difficulty_weights[d] for d in sub.get("difficulty_city", "")
            )
            if len(sub):
                out_rows.append(
                    dict(
                        metric=f"accuracy[{tname}]",
                        value=float(np.mean(sub["_correct"])),
                        n=len(sub),
                    )
                )
                sub_weighted = (
                    float(np.sum(sub["_weighted_correct"]) / total_sub)
                    if total_sub else 0.0
                )
                out_rows.append(
                    dict(
                        metric=f"weighted_accuracy[{tname}]",
                        value=sub_weighted,
                        n=len(sub),
                    )
                )

        out_df = pd.DataFrame(out_rows, columns=["metric", "value", "n"])
        # Best-effort: dump the metrics next to the evaluation file.
        try:
            eval_path = Path(eval_file)
            out_path = eval_path.with_name(f"{eval_path.stem}_metrics.tsv")
            out_df.to_csv(out_path, sep="\t", index=False)
        except (TypeError, OSError):
            pass

        return out_df
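
# Usage sketch (assumption: run through VLMEvalKit's standard entry point, which
# resolves this class via supported_datasets / DATASET_URL and checks DATASET_MD5):
#   python run.py --data ReasonMap-Plus --model <your_model>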