Skip to content

Commit 99d0f21

Browse files
committed
Add task_id_list option
1 parent 7bfbfc4 commit 99d0f21

File tree

1 file changed: +33 additions, -9 deletions

scripts/make_leaderboard.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def load_evaluation_data(result_dir: str, model: str, task_dirs: List[str]) -> d
4747
evaluation = json.load(f)
4848

4949
for metric, aggregate_output in evaluation.items():
50-
if metric not in eval_mm.metrics.ScorerRegistry._scorers.keys():
50+
if metric not in eval_mm.ScorerRegistry.get_metric_list():
5151
logger.warning(f"Skipping unsupported metric: {metric}")
5252
continue
5353

@@ -56,15 +56,24 @@ def load_evaluation_data(result_dir: str, model: str, task_dirs: List[str]) -> d
5656
return model_results
5757

5858

59-
def process_results(result_dir: str, model_list: List[str]) -> pd.DataFrame:
59+
def process_results(
60+
result_dir: str,
61+
model_list: List[str],
62+
add_avg: bool = False,
63+
task_id_list: Optional[List[str]] = None,
64+
) -> pd.DataFrame:
6065
"""Process all evaluation results into a structured DataFrame."""
61-
task_dirs = [d for d in os.listdir(result_dir) if not d.startswith(".")]
66+
if task_id_list:
67+
task_dirs = task_id_list
68+
else:
69+
task_dirs = [d for d in os.listdir(result_dir) if not d.startswith(".")]
70+
6271
df = pd.DataFrame()
6372

6473
for model in model_list:
6574
logger.info(f"Processing results for {model}")
6675
model_results = load_evaluation_data(result_dir, model, task_dirs)
67-
if not model_results:
76+
if len(model_results) == 1:
6877
continue
6978
df = df._append(model_results, ignore_index=True)
7079

@@ -75,11 +84,12 @@ def process_results(result_dir: str, model_list: List[str]) -> pd.DataFrame:
7584
for k in df.columns
7685
}
7786
)
78-
# すべてのスコアを 100 点満点に正規化
79-
df_normalized = df.apply(lambda x: x / x.max() * 100, axis=0)
87+
if add_avg:
88+
# すべてのスコアを 100 点満点に正規化
89+
df_normalized = df.apply(lambda x: x / x.max() * 100, axis=0)
8090

81-
# 各モデルの全体スコア(平均)を計算し、最後の列に追加
82-
df["Avg/Avg"] = df_normalized.mean(axis=1).round(2)
91+
# 各モデルの全体スコア(平均)を計算し、最後の列に追加
92+
df["Avg/Avg"] = df_normalized.mean(axis=1).round(2)
8393

8494
return df
8595

@@ -172,6 +182,7 @@ def format_output(df: pd.DataFrame, output_format: str) -> str:
172182
df.loc[top2_model, col] = f"<u>{top2_score}</u>"
173183

174184
df = df.fillna("")
185+
175186
if output_format == "markdown":
176187
return df.to_markdown(mode="github", floatfmt=".3g")
177188
elif output_format == "latex":
@@ -187,8 +198,10 @@ def main(
187198
plot_bar: bool = False,
188199
plot_corr: bool = False,
189200
update_pages: bool = False,
201+
add_avg: bool = False,
202+
task_id_list: Optional[List[str]] = None,
190203
):
191-
df = process_results(result_dir, model_list)
204+
df = process_results(result_dir, model_list, add_avg, task_id_list)
192205
if plot_corr:
193206
plot_correlation(df.copy(), "correlation.png")
194207
# plot_correlation(df.T, "correlation_model.png")
@@ -236,6 +249,14 @@ def parse_args():
236249
parser.add_argument(
237250
"--update_pages", action="store_true", help="Update the GitHub Pages JSON"
238251
)
252+
parser.add_argument(
253+
"--add_avg", action="store_true", help="Add average score column"
254+
)
255+
parser.add_argument(
256+
"--task_id_list",
257+
type=str,
258+
help="List of task IDs to include in the leaderboard (e.g. jmmmu,mmmu). If not specified, all tasks will be included.",
259+
)
239260
return parser.parse_args()
240261

241262

@@ -267,6 +288,7 @@ def parse_args():
267288
"microsoft/Phi-4-multimodal-instruct",
268289
"gpt-4o-2024-11-20",
269290
]
291+
print(args.task_id_list)
270292
main(
271293
args.result_dir,
272294
model_list,
@@ -275,4 +297,6 @@ def parse_args():
275297
args.plot_bar,
276298
args.plot_corr,
277299
args.update_pages,
300+
args.add_avg,
301+
args.task_id_list.split(",") if args.task_id_list else None,
278302
)

0 commit comments