@@ -10,6 +10,7 @@ def parse_args():
1010 parser = ArgumentParser ()
1111 parser .add_argument ("--task_id" , type = str , default = "japanese-heron-bench" )
1212 parser .add_argument ("--result_dir" , type = str , default = "result" )
13+ parser .add_argument ("--model_list" , type = str , nargs = "+" , default = [])
1314
1415 return parser .parse_args ()
1516
@@ -23,18 +24,11 @@ def scrollable_text(text):
2324if __name__ == "__main__" :
2425 args = parse_args ()
2526
26- task = eval_mm .tasks .TaskRegistry ().get_task_cls (args .task_id )(
27- eval_mm .tasks .TaskConfig ()
28- )
27+ task = eval_mm .TaskRegistry ().load_task (args .task_id )
2928
3029 # Load model prediction
31- model_list = [
32- "google/gemma-3-12b-it" ,
33- "google/gemma-3-27b-it" ,
34- "microsoft/Phi-4-multimodal-instruct" ,
35- ]
3630 predictions_per_model = {}
37- for model_id in model_list :
31+ for model_id in args . model_list :
3832 prediction_path = os .path .join (
3933 args .result_dir , args .task_id , model_id , "prediction.jsonl"
4034 )
@@ -50,8 +44,8 @@ def scrollable_text(text):
5044
5145 SAMPLES_PER_PAGE = 30 # 1ページに表示する件数
5246 # Question ID, Image, Question, Answer, Prediction_model1, Prediction_model2,..
53- column_width_list = [1 , 3 , 3 , 3 ] + [4 ] * len (model_list )
54- st .write (f"# { args .task_id } dataset " )
47+ column_width_list = [1 , 3 , 3 , 3 ] + [4 ] * len (args . model_list )
48+ st .write (f"# { args .task_id } " )
5549
5650 def show_sample (idx ):
5751 sample = ds [idx ]
@@ -64,8 +58,8 @@ def show_sample(idx):
6458 cols [3 ].markdown (
6559 scrollable_text (task .doc_to_answer (sample )), unsafe_allow_html = True
6660 )
67- for model_id in model_list :
68- cols [4 + model_list .index (model_id )].markdown (
61+ for model_id in args . model_list :
62+ cols [4 + args . model_list .index (model_id )].markdown (
6963 scrollable_text (predictions_per_model [model_id ][idx ]["text" ]),
7064 unsafe_allow_html = True ,
7165 )
@@ -93,8 +87,10 @@ def show_sample(idx):
9387 header_cols [1 ].markdown ("Image" )
9488 header_cols [2 ].markdown ("Question" )
9589 header_cols [3 ].markdown ("Answer" )
96- for model_id in model_list :
97- header_cols [4 + model_list .index (model_id )].markdown (f"Prediction ({ model_id } )" )
90+ for model_id in args .model_list :
91+ header_cols [4 + args .model_list .index (model_id )].markdown (
92+ f"Prediction ({ model_id } )"
93+ )
9894
9995 # サンプルを表示
10096 for idx in range (start_idx , end_idx ):
0 commit comments