
Commit 9c5158d

adding Multimodal BKC (#3649) (#3658)
* adding Multimodal BKC
* format correction
* Add input-mode support in sharding script and revert "verify=False"
* BKC update
* add peft install & remove deepspeed inf examples for phi4
* fix typo
1 parent 3a31bfc commit 9c5158d

File tree: 6 files changed (+147, -23)

examples/cpu/llm/inference/README.md (+106, -8)

@@ -110,16 +110,15 @@ python run.py --help # for more detailed usages
 |---|---|
 | generation | default: beam search (beam size = 4), "--greedy" for greedy search |
 | input tokens or prompt | provide fixed sizes for input prompt size, use "--input-tokens" for <INPUT_LENGTH> in [32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 32768, 130944]; if "--input-tokens" is not used, use "--prompt" to choose other strings as inputs|
-| input images | default: None, use "--image-url" to choose the image link address for vision-text tasks |
-| vision text tasks | default: False, use "--vision-text-model" to choose if your model (like llama3.2 11B model) is running for vision-text generation tasks, default False meaning text generation tasks only|
 | output tokens | default: 32, use "--max-new-tokens" to choose any other size |
 | batch size | default: 1, use "--batch-size" to choose any other size |
 | token latency | enable "--token-latency" to print out the first or next token latency |
 | generation iterations | use "--num-iter" and "--num-warmup" to control the repeated iterations of generation, default: 100-iter/10-warmup |
 | streaming mode output | greedy search only (work with "--greedy"), use "--streaming" to enable the streaming generation output |
 | KV Cache dtype | default: auto, use "--kv-cache-dtype=fp8_e5m2" to enable e5m2 KV Cache. More information refer to [vLLM FP8 E5M2 KV Cache](https://docs.vllm.ai/en/v0.6.6/quantization/fp8_e5m2_kvcache.html) |
 | input mode | default: 0, use "--input-mode" to choose input mode for multimodal models. 0: language; 1: vision; 2: speech; 3: vision and speech |
-| input audios | default: None, use "--audio" to choose the audio link address for speech tasks |
+| input images | default: None, use "--image-url" to choose the image file address for vision-text tasks |
+| input audios | default: None, use "--audio" to choose the audio file address for speech tasks |
 
 *Note:* You may need to log in your HuggingFace account to access the model files. Please refer to [HuggingFace login](https://huggingface.co/docs/huggingface_hub/quick-start#login).
 
@@ -522,19 +521,118 @@ There are some model-specific requirements to be aware of, as follows:
 
 - For Llava models from remote hub, additional setup is required, i.e., `bash ./tools/prepare_llava.sh`.
 
-## 2.3 Instructions for Running LLM with Intel® Xeon® CPU Max Series
+## 2.3 Instructions for Running Multimodal LLMs
+
+Multimodal LLMs are large language models capable of processing multiple types of inputs,
+such as images and audio, in addition to text prompts.
+We have optimized the performance of some popular multimodal LLMs, such as `microsoft/Phi-4-multimodal-instruct`
+and `meta-llama/Llama-3.2-11B-Vision-Instruct`; the optimizations can be showcased with the provided `run.py` script.
+The additional arguments that need to be specified in the commands are highlighted here:
+
+| Special args for multimodal | Notes |
+|---|---|
+| input mode | Use "--input-mode" to choose the input mode for multimodal models. 0: language; 1: vision; 2: speech; 3: vision and speech |
+| input image | Use "--image-url" to specify the image link address or local path for vision-text tasks |
+| input audio | Use "--audio" to specify the local path of the audio file for speech tasks |
+
+In addition, for multimodal tasks we need to set the text prompt and bind it to the input image/audio.
+The binding is realized with special tokens, such as the image tag `<|image|>` and the audio tag `<|audio|>`.
+We provide the following example commands to showcase the argument settings in detail.
+
+### 2.3.1 Phi-4-multimodal-instruct
+
+You can download the sample image and audio to your local folder beforehand.
+Also, the `peft` package is required for running the model.
+
+```bash
+wget https://www.ilankelman.org/stopsigns/australia.jpg
+wget https://voiceage.com/wbsamples/in_mono/Trailer.wav
+pip install peft
+```
+
+- BF16, single instance
+
+We provide example commands running in BF16 precision for all the input modes.
+The OMP thread number and `numactl` setup parts are omitted.
+
+Example command for pure text input
+
+```bash
+python run.py --input-mode 0 --benchmark -m microsoft/Phi-4-multimodal-instruct --ipex --token-latency --greedy --dtype bfloat16 --max-new-tokens 128 --prompt "<|system|>You are a helpful assistant.<|end|><|user|>How to explain Internet for a medieval knight?<|end|><|assistant|>"
+```
+
+Example command for image comprehension
+
+```bash
+python run.py --input-mode 1 --benchmark -m microsoft/Phi-4-multimodal-instruct --ipex --token-latency --greedy --dtype bfloat16 --max-new-tokens 128 --prompt "<|user|><|image_1|>What is shown in this image?<|end|><|assistant|>" --image-url australia.jpg
+```
+
+Example command for speech comprehension
+
+```bash
+python run.py --input-mode 2 --benchmark -m microsoft/Phi-4-multimodal-instruct --ipex --token-latency --greedy --dtype bfloat16 --max-new-tokens 128 --prompt "<|user|><|audio_1|>Transcribe the audio to text, and then translate the audio to French. Use <sep> as a separator between the original transcript and the translation.<|end|><|assistant|>" --audio Trailer.wav
+```
+
+Example command for image and speech comprehension
+
+```bash
+python run.py --input-mode 3 --benchmark -m microsoft/Phi-4-multimodal-instruct --ipex --token-latency --greedy --dtype bfloat16 --max-new-tokens 128 --prompt "<|user|><|image_1|><|audio_1|><|end|><|assistant|>" --audio Trailer.wav --image-url australia.jpg
+```
+
+- Weight-only quantization INT8, single instance
+
+For WoQ INT8 precision, we need to replace the arguments `--ipex` and `--dtype bfloat16`
+with `--ipex-weight-only-quantization`, `--weight-dtype INT8` and `--quant-with-amp`.
+In addition, `--group-size 128` is needed as group-wise quantization should be applied.
+
+Example command for image and speech comprehension
+
+```bash
+python run.py --input-mode 3 --benchmark -m microsoft/Phi-4-multimodal-instruct --token-latency --greedy --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --group-size 128 --max-new-tokens 128 --prompt "<|user|><|image_1|><|audio_1|><|end|><|assistant|>" --audio Trailer.wav --image-url australia.jpg
+```
+
+### 2.3.2 meta-llama/Llama-3.2-11B-Vision-Instruct
+
+The `Llama-3.2-11B-Vision-Instruct` model supports image comprehension tasks.
+`--input-mode 1` should always be specified for this model.
+
+- BF16, single instance
+
+```bash
+python run.py --input-mode 1 --benchmark -m meta-llama/Llama-3.2-11B-Vision-Instruct --ipex --dtype bfloat16 --prompt "<|image|>Describe the contents of this image." --image-url australia.jpg
+```
+
+- Weight-only quantization INT8, single instance
+
+```bash
+python run.py --input-mode 1 --benchmark -m meta-llama/Llama-3.2-11B-Vision-Instruct --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --prompt "<|image|>Describe the contents of this image." --image-url australia.jpg
+```
+
+- BF16, distributed inference
+
+```bash
+deepspeed --bind_cores_to_rank run.py --input-mode 1 --benchmark -m meta-llama/Llama-3.2-11B-Vision-Instruct --ipex --dtype bfloat16 --prompt "<|image|>Describe the contents of this image." --image-url australia.jpg --autotp --shard-model
+```
+
+- Weight-only quantization INT8, distributed inference
+
+```bash
+deepspeed --bind_cores_to_rank run.py --input-mode 1 --benchmark -m meta-llama/Llama-3.2-11B-Vision-Instruct --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --prompt "<|image|>Describe the contents of this image." --image-url australia.jpg --autotp --shard-model
+```
+
+## 2.4 Instructions for Running LLM with Intel® Xeon® CPU Max Series
 
 Intel® Xeon® CPU Max Series are equipped with high bandwidth memory (HBM), which further accelerates LLM inference. For the common case that HBM and DDR are both installed in a Xeon® CPU Max Series server, the memory mode can be configured to Flat Mode or Cache Mode.
 Details about memory modes can be found at Section 3.1 in [the Xeon® CPU Max Series Configuration Guide](https://cdrdv2-public.intel.com/769060/354227-intel-xeon-cpu-max-series-configuration-and-tuning-guide.pdf).
 
-### 2.3.1 Single Instance Inference with Xeon® CPU Max Series
+### 2.4.1 Single Instance Inference with Xeon® CPU Max Series
 
-#### 2.3.1.1 Cache Mode HBM
+#### 2.4.1.1 Cache Mode HBM
 
 In cache mode, only DDR address space is visible to software and HBM functions as a transparent memory-side cache for DDR.
 Therefore the usage is the same with [the common usage](#221-run-generation-with-one-instance).
 
-#### 2.3.1.2 Flat Mode HBM
+#### 2.4.1.2 Flat Mode HBM
 
 In flat mode, HBM and DDR are exposed to software as separate address spaces.
 Therefore we need to check the `HBM_NODE_INDEX` of interest with commands like `lscpu`, then the LLM inference invoking command would be like:
@@ -561,7 +659,7 @@ OMP_NUM_THREADS=<HBM node cores num> numactl -p <HBM_NODE_INDEX> -C <HBM cores l
 OMP_NUM_THREADS=56 numactl -p 2 -C 0-55 python run.py --benchmark -m meta-llama/Meta-Llama-3.1-8B-Instruct --dtype bfloat16 --ipex
 ```
 
-### 2.3.2 Distributed Inference with Xeon® CPU Max Series
+### 2.4.2 Distributed Inference with Xeon® CPU Max Series
 
 As HBM has memory capacity limitations, we need to shard the model in advance with DDR memory.
 Please follow [the example](#31-how-to-shard-model-for-distributed-tests-with-deepspeed-autotp).
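
The new section 2.3 explains that the text prompt is bound to the image or audio input through special tokens such as `<|image|>`. As a rough, non-authoritative illustration of what that binding corresponds to at the Hugging Face `transformers` level (assuming `transformers` >= 4.45 with mllama support; `run.py` adds IPEX-specific optimizations and its internals may differ), a minimal sketch for the Llama-3.2 vision model:

```python
# Minimal sketch only; assumes transformers>=4.45 (mllama support) and the
# sample image downloaded as in section 2.3.1. run.py wraps a similar flow
# and adds IPEX optimizations that are not shown here.
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
processor = AutoProcessor.from_pretrained(model_id)
model = MllamaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16)

image = Image.open("australia.jpg")
# The <|image|> tag marks where the image is attached to the text prompt.
prompt = "<|image|>Describe the contents of this image."

inputs = processor(images=image, text=prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)
print(processor.decode(output[0], skip_special_tokens=True))
```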

examples/cpu/llm/inference/distributed/run_generation_with_deepspeed.py (+10, -4)

@@ -67,7 +67,7 @@
 parser.add_argument(
     "--vision-text-model",
     action="store_true",
-    help="whether or not it is vision-text multi-model structure",
+    help="[deprecated] whether it is vision-text multi-model structure",
 )
 parser.add_argument(
     "--dtype",
@@ -239,6 +239,12 @@
 if args.verbose:
     logger.setLevel(logging.DEBUG)
 
+if args.vision_text_model:
+    logger.warning(
+        "'--vision-text-model' flag is deprecated. Please set '--input-mode 1' instead."
+    )
+    args.input_mode = "1"
+
 num_tokens = args.max_new_tokens
 use_ipex = args.ipex or args.ipex_weight_only_quantization
 
@@ -347,7 +353,7 @@ def get_checkpoint_files(model_name_or_path):
 
 print_rank0(f"*** Loading the model {model_name}")
 model_type = next((x for x in MODEL_CLASSES.keys() if x in model_name.lower()), "auto")
-if model_type == "llama" and args.vision_text_model:
+if model_type == "llama" and args.input_mode == "1":
     model_type = "mllama"
 if model_type in ["maira-2", "deepseek-v2", "deepseek-v3", "deepseek-r1"]:
     model_type = model_type.replace("-", "")
@@ -814,7 +820,7 @@ def load_image(image_file):
     if image_file.startswith("http://") or image_file.startswith("https://"):
         import requests
 
-        raw_image = Image.open(requests.get(args.image_url, stream=True).raw)
+        raw_image = Image.open(requests.get(image_file, stream=True).raw)
     else:
         raw_image = Image.open(image_file)
     return raw_image
@@ -854,7 +860,7 @@ def load_image(image_file):
     if image_file.startswith("http://") or image_file.startswith("https://"):
         import requests
 
-        raw_image = Image.open(requests.get(args.image_url, stream=True).raw)
+        raw_image = Image.open(requests.get(image_file, stream=True).raw)
     else:
         raw_image = Image.open(image_file)
     return raw_image
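
The two `load_image` hunks above are a small bug fix: the helper previously ignored its `image_file` argument and always fetched the global `args.image_url`, so any per-call path was silently dropped. For reference, a self-contained sketch of the corrected helper (the in-tree version is defined inside the generation scripts and may sit at a different nesting level):

```python
from PIL import Image


def load_image(image_file):
    """Load an image from a local path or an http(s) URL."""
    if image_file.startswith("http://") or image_file.startswith("https://"):
        import requests

        # Fetch the URL that was actually passed in, not the global args.image_url.
        raw_image = Image.open(requests.get(image_file, stream=True).raw)
    else:
        raw_image = Image.open(image_file)
    return raw_image
```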

examples/cpu/llm/inference/run.py (+2, -2)

@@ -629,8 +629,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
     model_path = Path(str(args.output_dir) + str(MODEL_CLASSES[model_type]))
     if not model_path.exists():
         Path.mkdir(model_path)
-    if args.vision_text_model:
-        shard_cmd.extend(["--vision-text-model"])
+    if args.vision_text_model or args.input_mode == "1":
+        shard_cmd.extend(["--input-mode", "1"])
     shard_cmd.extend(
         ["--save-path", str(args.output_dir) + str(MODEL_CLASSES[model_type])]
     )

examples/cpu/llm/inference/single_instance/run_generation.py (+9, -3)

@@ -103,7 +103,7 @@
 parser.add_argument(
     "--vision-text-model",
     action="store_true",
-    help="whether or not it is vision-text multi-model structure",
+    help="[deprecated] whether it is vision-text multi-model structure",
 )
 parser.add_argument(
     "--kv-cache-dtype",
@@ -126,6 +126,12 @@
 args = parser.parse_args()
 print(args)
 
+if args.vision_text_model:
+    logger.warning(
+        "'--vision-text-model' flag is deprecated. Please set '--input-mode 1' instead."
+    )
+    args.input_mode = "1"
+
 # import ipex
 if args.ipex:
     import intel_extension_for_pytorch as ipex
@@ -144,7 +150,7 @@
 model_type = next(
     (x for x in MODEL_CLASSES.keys() if x in args.model_id.lower()), "auto"
 )
-if model_type == "llama" and args.vision_text_model:
+if model_type == "llama" and args.input_mode == "1":
     model_type = "mllama"
 if model_type in ["maira-2", "deepseek-v2", "deepseek-v3", "deepseek-r1"]:
     model_type = model_type.replace("-", "")
@@ -303,7 +309,7 @@ def load_image(image_file):
     if image_file.startswith("http://") or image_file.startswith("https://"):
         import requests
 
-        raw_image = Image.open(requests.get(args.image_url, stream=True).raw)
+        raw_image = Image.open(requests.get(image_file, stream=True).raw)
     else:
         raw_image = Image.open(image_file)
     return raw_image
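
Across these scripts the same backward-compatibility pattern is applied: the boolean `--vision-text-model` flag is kept but marked deprecated, and it is mapped onto the new `--input-mode` option so that downstream checks only consult `args.input_mode`. A condensed, standalone sketch of that pattern (argument names and the warning text are taken from the diff; the surrounding parser setup is simplified):

```python
import argparse
import logging

logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vision-text-model",
    action="store_true",
    help="[deprecated] whether it is vision-text multi-model structure",
)
parser.add_argument(
    "--input-mode",
    default="0",
    choices=["0", "1", "2", "3"],
    type=str,
    help="Input mode for multimodal models. 0: language; 1: vision; 2: speech; 3: vision_speech",
)
args = parser.parse_args()

# Map the deprecated flag onto the new option with a warning, so the rest of
# the script only needs to look at args.input_mode.
if args.vision_text_model:
    logger.warning(
        "'--vision-text-model' flag is deprecated. Please set '--input-mode 1' instead."
    )
    args.input_mode = "1"
```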

examples/cpu/llm/inference/single_instance/run_quantization.py (+9, -4)

@@ -62,7 +62,7 @@
 parser.add_argument(
     "--vision-text-model",
     action="store_true",
-    help="whether or not it is vision-text multi-model structure",
+    help="[deprecated] whether it is vision-text multi-model structure",
 )
 parser.add_argument(
     "--max-new-tokens", default=32, type=int, help="output max new tokens"
@@ -302,6 +302,11 @@
     logger.setLevel(logging.DEBUG)
     ipex.set_logging_level(logging.DEBUG)
 
+if args.vision_text_model:
+    logger.warning(
+        "'--vision-text-model' flag is deprecated. Please set '--input-mode 1' instead."
+    )
+    args.input_mode = "1"
 
 # disable
 try:
@@ -354,15 +359,15 @@
     "llava", config.architectures[0], re.IGNORECASE
 ):
 
-    if args.vision_text_model:
+    if args.input_mode == "1":
         model = MLLAMAConfig(args.model_id)
         from PIL import Image
 
         def load_image(image_file):
             if image_file.startswith("http://") or image_file.startswith("https://"):
                 import requests
 
-                raw_image = Image.open(requests.get(args.image_url, stream=True).raw)
+                raw_image = Image.open(requests.get(image_file, stream=True).raw)
             else:
                 raw_image = Image.open(image_file)
             return raw_image
@@ -455,7 +460,7 @@ def load_image(image_file):
     if image_file.startswith("http://") or image_file.startswith("https://"):
         import requests
 
-        raw_image = Image.open(requests.get(args.image_url, stream=True).raw)
+        raw_image = Image.open(requests.get(image_file, stream=True).raw)
     else:
         raw_image = Image.open(image_file)
     return raw_image

examples/cpu/llm/inference/utils/create_shard_model.py (+11, -2)

@@ -43,15 +43,24 @@
 parser.add_argument(
     "--vision-text-model",
     action="store_true",
-    help="whether or not it is vision-text multi-model structure",
+    help="[deprecated] whether it is vision-text multi-model structure",
+)
+parser.add_argument(
+    "--input-mode",
+    default="0",
+    choices=["0", "1", "2", "3"],
+    type=str,
+    help="Input mode for multimodal models. 0: language; 1: vision; 2: speech; 3: vision_speech",
 )
 args = parser.parse_args()
 print(args)
+if args.vision_text_model:
+    args.input_mode = "1"
 if args.local_rank == 0:
     model_type = next(
         (x for x in MODEL_CLASSES.keys() if x in args.model_id.lower()), "auto"
     )
-    if model_type == "llama" and args.vision_text_model:
+    if model_type == "llama" and args.input_mode == "1":
         model_type = "mllama"
     if model_type in ["maira-2", "deepseek-v2", "deepseek-v3", "deepseek-r1"]:
         model_type = model_type.replace("-", "")
