
Commit 9aaa4d0

feat: add InternVL3 size variants to vllm registry and eval script
1 parent 44ff382 commit 9aaa4d0

3 files changed (+274, -20 lines)

eval_with_vllm.sh

Lines changed: 27 additions & 19 deletions
@@ -6,30 +6,37 @@ export CUDA_VISIBLE_DEVICES=0,1
 # Model name to group name mapping
 declare -A MODEL_GROUP_MAP=(
     ["Qwen/Qwen3-VL-30B-A3B-Instruct"]="vllm_normal"
+    # ["moonshotai/Kimi-VL-A3B-Instruct"]="vllm_normal"  # not working right now
+    ["OpenGVLab/InternVL3-1B"]="vllm_normal"
+    ["OpenGVLab/InternVL3-2B"]="vllm_normal"
+    ["OpenGVLab/InternVL3-8B"]="vllm_normal"
+    ["OpenGVLab/InternVL3-14B"]="vllm_normal"
+    ["OpenGVLab/InternVL3-38B"]="vllm_normal"
+    ["OpenGVLab/InternVL3-78B"]="vllm_normal"
 )

 declare -a task_list=(
     "japanese-heron-bench"
-    # "ja-vlm-bench-in-the-wild"
-    # "ja-vg-vqa-500"
-    # "jmmmu"
-    # "ja-multi-image-vqa"
-    # "jdocqa"
-    # "mmmu"
-    # "llava-bench-in-the-wild"
-    # "jic-vqa"
-    # "cvqa"
-    # "cc-ocr"
-    # "mecha-ja"
-    # "ai2d"
+    "ja-vlm-bench-in-the-wild"
+    "ja-vg-vqa-500"
+    "jmmmu"
+    "ja-multi-image-vqa"
+    "jdocqa"
+    "mmmu"
+    "llava-bench-in-the-wild"
+    "jic-vqa"
+    "cvqa"
+    "cc-ocr"
+    "mecha-ja"
+    "ai2d"
     # "blink"
-    # "docvqa"
-    # "infographicvqa"
-    # "textvqa"
-    # "chartqa"
+    "docvqa"
+    "infographicvqa"
+    "textvqa"
+    "chartqa"
     # "chartqapro"
     # "mathvista"
-    # "okvqa"
+    "okvqa"
 )

 # === Metrics Mapping ===

@@ -68,8 +75,9 @@ for RESULT_DIR in "${result_dir_list[@]}"; do
     METRIC=${METRIC_MAP[$task]}
     for model_name in "${!MODEL_GROUP_MAP[@]}"; do
         model_group=${MODEL_GROUP_MAP[$model_name]}
-        source .uv/$model_group-env/bin/activate
-        uv run --active python examples/sample_vllm.py \
+        source .uv/vllm_normal-env/bin/activate
+        uv pip list
+        python examples/sample_vllm.py \
            --model_id "$model_name" \
            --task_id "$task" \
            --metrics "$METRIC" \

examples/sample_vllm.py

Lines changed: 1 addition & 1 deletion
@@ -178,7 +178,7 @@ def main():
     )
     task = eval_mm.TaskRegistry.load_task(args.task_id, task_config)

-    output_dir = os.path.join(args.result_dir, args.task_id, args.model_id + "_vllm")
+    output_dir = os.path.join(args.result_dir, args.task_id, args.model_id)
     os.makedirs(output_dir, exist_ok=True)

     preds, _ = load_or_generate_predictions(args, task, gen_kwargs, output_dir)
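
Dropping the "_vllm" suffix means results land under the bare model id; since these ids contain a slash, os.path.join nests one directory per organization. A minimal sketch of the resulting layout ("results" here is a stand-in for args.result_dir):

import os

# Stand-in values; "results" mirrors args.result_dir in sample_vllm.py.
result_dir = "results"
task_id = "japanese-heron-bench"
model_id = "OpenGVLab/InternVL3-8B"

output_dir = os.path.join(result_dir, task_id, model_id)
print(output_dir)  # results/japanese-heron-bench/OpenGVLab/InternVL3-8B
os.makedirs(output_dir, exist_ok=True)  # also creates the OpenGVLab/ parent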

examples/vllm_registry.py

Lines changed: 246 additions & 0 deletions
@@ -1,8 +1,10 @@
+import argparse
 from collections.abc import Callable
 from dataclasses import dataclass
 from typing import Optional

 from PIL import Image
+from transformers import AutoTokenizer
 from vllm import EngineArgs
 from vllm.lora.request import LoRARequest

@@ -14,6 +16,16 @@ class ModelRequestData:
     lora_requests: Optional[list[LoRARequest]] = None


+INTERNVL_MODELS: tuple[str, ...] = (
+    "OpenGVLab/InternVL3-1B",
+    "OpenGVLab/InternVL3-2B",
+    "OpenGVLab/InternVL3-8B",
+    "OpenGVLab/InternVL3-14B",
+    "OpenGVLab/InternVL3-38B",
+    "OpenGVLab/InternVL3-78B",
+)
+
+
 class VLLMModelRegistry:
     def __init__(self, model_id: str):
         self.model_id = model_id

@@ -30,8 +42,18 @@ def __init__(self, model_id: str):
                 self._engine_args_qwen3_vl,
                 self._load_qwen3_vl,
             ),
+            "moonshotai/Kimi-VL-A3B-Instruct": (
+                self._engine_args_kimi_vl,
+                self._load_kimi_vl,
+            ),
         }

+        for internvl_model in INTERNVL_MODELS:
+            registry[internvl_model] = (
+                self._engine_args_internvl,
+                self._load_internvl,
+            )
+
         try:
             self._engine_resolver, self._request_builder = registry[model_id]
         except KeyError as exc:  # pragma: no cover - defensive programming
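
Every registered id resolves to an (engine-args resolver, request builder) pair, so the six InternVL3 variants share a single pair installed by the loop above, and unknown ids fail fast in __init__ via the except KeyError branch. A minimal sketch of the lookup from the caller's side, using the get_engine_args and build_requests entry points that appear later in this file (the import path is an assumption; run from examples/ or adjust sys.path):

# Hypothetical smoke test for the registry wiring.
from vllm_registry import INTERNVL_MODELS, VLLMModelRegistry

registry = VLLMModelRegistry(INTERNVL_MODELS[2])  # "OpenGVLab/InternVL3-8B"
engine_args = registry.get_engine_args()
print(engine_args.model)          # OpenGVLab/InternVL3-8B
print(engine_args.max_model_len)  # 8192, per _engine_args_internvl below
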
@@ -88,3 +110,227 @@ def _load_qwen3_vl(
             prompts.append(prompt)

         return ModelRequestData(prompts=prompts)
+
+    def _engine_args_kimi_vl(self) -> EngineArgs:
+        return EngineArgs(
+            model=self.model_id,
+            max_model_len=4096,
+            trust_remote_code=True,
+            limit_mm_per_prompt={self.modality: 5},
+        )
+
+    def _load_kimi_vl(
+        self, texts: list[str], images_list: list[list[Image.Image]]
+    ) -> ModelRequestData:
+        if len(texts) != len(images_list):
+            msg = "texts and images_list must have identical length"
+            raise ValueError(msg)
+
+        prompts: list[str] = []
+        for text, images in zip(texts, images_list):
+            num_images = len(images)
+            if num_images > 0:
+                placeholder = "".join("<|media_pad|>" for _ in range(num_images))
+                vision_block = (
+                    "<|media_start|>image<|media_content|>"
+                    f"{placeholder}<|media_end|>"
+                )
+            else:
+                vision_block = ""
+
+            prompt = (
+                "<|im_user|>user<|im_middle|>"
+                f"{vision_block}{text}<|im_end|>"
+                "<|im_assistant|>assistant<|im_middle|>"
+            )
+            prompts.append(prompt)
+
+        return ModelRequestData(prompts=prompts)
+
+    def _engine_args_internvl(self) -> EngineArgs:
+        return EngineArgs(
+            model=self.model_id,
+            trust_remote_code=True,
+            max_model_len=8192,
+            limit_mm_per_prompt={self.modality: 5},
+        )
+
+    def _load_internvl(
+        self, texts: list[str], images_list: list[list[Image.Image]]
+    ) -> ModelRequestData:
+        if len(texts) != len(images_list):
+            msg = "texts and images_list must have identical length"
+            raise ValueError(msg)
+
+        if not hasattr(self, "_internvl_tokenizer"):
+            self._internvl_tokenizer = AutoTokenizer.from_pretrained(
+                self.model_id,
+                trust_remote_code=True,
+            )
+
+        tokenizer = self._internvl_tokenizer
+
+        prompts: list[str] = []
+        for text, images in zip(texts, images_list):
+            num_images = len(images)
+            if num_images > 0:
+                message_content = [{"type": "image"} for _ in range(num_images)]
+                if text:
+                    message_content.append({"type": "text", "text": text})
+            else:
+                message_content = [{"type": "text", "text": text}]
+
+            messages = [[{"role": "user", "content": message_content}]]
+            prompt = tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True,
+            )
+            prompts.append(prompt[0])
+
+        stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
+        stop_token_ids = [
+            token_id
+            for token_id in (
+                tokenizer.convert_tokens_to_ids(token) for token in stop_tokens
+            )
+            if token_id is not None
+        ]
+
+        return ModelRequestData(prompts=prompts, stop_token_ids=stop_token_ids)
+
+
+def _generate_dummy_images(count: int) -> list[Image.Image]:
+    """Return placeholder PIL images for prompt-construction tests."""
+
+    return [Image.new("RGB", (1, 1), color=0) for _ in range(count)]
+
+
+def preview_qwen3_vl_requests(
+    texts: list[str], image_counts: list[int]
+) -> ModelRequestData:
+    """Build prompts for Qwen3-VL using dummy images (testing helper)."""
+
+    if len(texts) != len(image_counts):
+        msg = "texts and image_counts must have identical length"
+        raise ValueError(msg)
+
+    images_list = [_generate_dummy_images(count) for count in image_counts]
+    registry = VLLMModelRegistry("Qwen/Qwen3-VL-30B-A3B-Instruct")
+    return registry.build_requests(texts, images_list)
+
+
+def preview_kimi_vl_requests(
+    texts: list[str], image_counts: list[int]
+) -> ModelRequestData:
+    """Build prompts for Kimi-VL using dummy images (testing helper)."""
+
+    if len(texts) != len(image_counts):
+        msg = "texts and image_counts must have identical length"
+        raise ValueError(msg)
+
+    images_list = [_generate_dummy_images(count) for count in image_counts]
+    registry = VLLMModelRegistry("moonshotai/Kimi-VL-A3B-Instruct")
+    return registry.build_requests(texts, images_list)
+
+
+def preview_internvl_requests(
+    texts: list[str], image_counts: list[int]
+) -> ModelRequestData:
+    """Build prompts for InternVL using dummy images (testing helper)."""
+
+    if len(texts) != len(image_counts):
+        msg = "texts and image_counts must have identical length"
+        raise ValueError(msg)
+
+    images_list = [_generate_dummy_images(count) for count in image_counts]
+    registry = VLLMModelRegistry("OpenGVLab/InternVL3-2B")
+    return registry.build_requests(texts, images_list)
+
+
+def _parse_cli_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(
+        description="Preview prompts generated by the VLLM model registry.",
+    )
+    parser.add_argument(
+        "--model-id",
+        required=True,
+        choices=[
+            "Qwen/Qwen3-VL-30B-A3B-Instruct",
+            "moonshotai/Kimi-VL-A3B-Instruct",
+            *INTERNVL_MODELS,
+        ],
+        help="Registered model identifier to preview.",
+    )
+    parser.add_argument(
+        "--texts",
+        default=["What is in the image?"],
+        nargs="+",
+        help="One or more user messages to build prompts for.",
+    )
+    parser.add_argument(
+        "--image-counts",
+        nargs="+",
+        type=int,
+        help="Number of images to attach per message (broadcast if a single value).",
+    )
+    parser.add_argument(
+        "--show-engine-args",
+        action="store_true",
+        help="Print the EngineArgs associated with the selected model.",
+    )
+    return parser.parse_args()
+
+
+def _broadcast_counts(texts: list[str], image_counts: Optional[list[int]]) -> list[int]:
+    if image_counts is None:
+        return [0] * len(texts)
+
+    if len(image_counts) == 1 and len(texts) > 1:
+        image_counts = image_counts * len(texts)
+
+    if len(image_counts) != len(texts):
+        msg = "image_counts must match the number of texts"
+        raise ValueError(msg)
+
+    if any(count < 0 for count in image_counts):
+        msg = "image_counts must be non-negative"
+        raise ValueError(msg)
+
+    return image_counts
+
+
+def _preview_cli() -> None:
+    args = _parse_cli_args()
+    texts = args.texts
+    image_counts = _broadcast_counts(texts, args.image_counts)
+
+    preview_dispatch: dict[str, Callable[[list[str], list[int]], ModelRequestData]] = {
+        "Qwen/Qwen3-VL-30B-A3B-Instruct": preview_qwen3_vl_requests,
+        "moonshotai/Kimi-VL-A3B-Instruct": preview_kimi_vl_requests,
+    }
+
+    for internvl_model in INTERNVL_MODELS:
+        preview_dispatch[internvl_model] = preview_internvl_requests
+
+    preview_fn = preview_dispatch[args.model_id]
+    registry = VLLMModelRegistry(args.model_id)
+    request_data = preview_fn(texts, image_counts)
+
+    if args.show_engine_args:
+        engine_args = registry.get_engine_args()
+        print("EngineArgs:")
+        print(engine_args)
+        print()
+
+    for idx, prompt in enumerate(request_data.prompts):
+        print(f"Prompt[{idx}]:")
+        print(prompt)
+        print("---")
+
+    stop_ids = request_data.stop_token_ids
+    print("Stop token IDs:", stop_ids if stop_ids is not None else "None")
+
+
+if __name__ == "__main__":  # pragma: no cover - CLI helper
+    _preview_cli()
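
The preview helpers and the CLI above make prompt construction inspectable without launching a vLLM engine; on the InternVL path only the tokenizer is downloaded. A sketch of driving a helper directly (the texts and image counts are arbitrary):

# Hypothetical usage of the testing helpers added above; the first call
# fetches the InternVL3-2B tokenizer (trust_remote_code=True).
from vllm_registry import preview_internvl_requests

request_data = preview_internvl_requests(
    ["What is in the image?", "Compare these photos."],
    [1, 2],  # one dummy image for the first text, two for the second
)
for prompt in request_data.prompts:
    print(prompt)
print("stop token ids:", request_data.stop_token_ids)

The CLI reaches the same code path, e.g. python examples/vllm_registry.py --model-id OpenGVLab/InternVL3-2B --image-counts 2 --show-engine-args.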
