Commit d11f257

Add GPU example for MiniCPM-o-2_6 (#12735)
* Add init example for omni mode
* Small fix
* Small fix
* Add chat example
* Remove legacy link
* Further update link
* Add readme
* Small fix
* Update main readme link
* Update based on comments
* Small fix
* Small fix
* Small fix
1 parent dcca522 · commit d11f257

6 files changed: +419 −1 lines changed

README.md (+1)

```diff
@@ -337,6 +337,7 @@ Over 70 models have been optimized/verified on `ipex-llm`, including *LLaMA/LLaM
 | MiniCPM-V-2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2) |
 | MiniCPM-Llama3-V-2_5 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-Llama3-V-2_5) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
 | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
+| MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
 | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
 | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
 | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
```

README.zh-CN.md (+1)

```diff
@@ -337,6 +337,7 @@ See the demo of running [*Text-Generation-WebUI*](https://ipex-llm.readthedocs.i
 | MiniCPM-V-2 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2) |
 | MiniCPM-Llama3-V-2_5 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-Llama3-V-2_5) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
 | MiniCPM-V-2_6 | [link](python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm-v-2_6) | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-V-2_6) | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
+| MiniCPM-o-2_6 | | [link](python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/) |
 | StableDiffusion | | [link](python/llm/example/GPU/HuggingFace/Multimodal/StableDiffusion) |
 | Bce-Embedding-Base-V1 | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Embedding) |
 | Speech_Paraformer-Large | | | [Python link](python/llm/example/NPU/HF-Transformers-AutoModels/Multimodal) |
```
python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/README.md (+164, new file)

# MiniCPM-o-2_6

In this directory, you will find examples of how to apply IPEX-LLM INT4 optimizations to the MiniCPM-o-2_6 model on [Intel GPUs](../../../README.md). For illustration purposes, we use [openbmb/MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) as the reference MiniCPM-o-2_6 model.

The following examples guide you through applying IPEX-LLM optimizations to MiniCPM-o-2_6 for text, audio, image, and video inputs.

## 0. Requirements & Installation

To run these examples with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine; please refer to [here](../../../README.md#requirements) for more information.
### 0.1 Install IPEX-LLM

- For **Intel Core™ Ultra Processors (Series 2) with processor number 2xxV (code name Lunar Lake)** on Windows:

  ```cmd
  conda create -n llm python=3.11 libuv
  conda activate llm

  :: or --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/lnl/cn/
  pip install --pre --upgrade ipex-llm[xpu_lnl] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/lnl/us/
  pip install torchaudio==2.3.1.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/lnl/us/
  ```

- For **Intel Arc B-Series GPU (code name Battlemage)** on Linux:

  ```bash
  conda create -n llm python=3.11
  conda activate llm

  # or --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
  pip install --pre --upgrade ipex-llm[xpu-arc] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
  pip install torchaudio==2.3.1.post0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
  ```

> [!NOTE]
> We will update for installation on more Intel GPU platforms.
### 0.2 Install Required Packages for MiniCPM-o-2_6

```bash
conda activate llm

# refer to: https://huggingface.co/openbmb/MiniCPM-o-2_6#usage
pip install transformers==4.44.2 trl
pip install librosa==0.9.0
pip install soundfile==0.12.1
pip install moviepy
```
### 0.3 Runtime Configuration

- For **Intel Core™ Ultra Processors (Series 2) with processor number 2xxV (code name Lunar Lake)** on Windows:

  ```cmd
  set SYCL_CACHE_PERSISTENT=1
  ```

- For **Intel Arc B-Series GPU (code name Battlemage)** on Linux:

  ```bash
  unset OCL_ICD_VENDOR
  export SYCL_CACHE_PERSISTENT=1
  ```

> [!NOTE]
> We will update for runtime configuration on more Intel GPU platforms.
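
Optionally, as a quick sanity check that the GPU is visible to PyTorch after the steps above (a minimal sketch; it assumes the `xpu` wheels installed in section 0.1, where `intel_extension_for_pytorch` registers the `xpu` device):

```python
import torch
import intel_extension_for_pytorch  # registers the 'xpu' device with PyTorch

print(torch.xpu.is_available())      # should print True on a working setup
print(torch.xpu.get_device_name(0))  # name of the first Intel GPU device
```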
## 1. Example: Chat in Omni Mode

In [omni.py](./omni.py), we show a use case for a MiniCPM-o-2_6 model to chat in omni mode with IPEX-LLM INT4 optimizations on Intel GPUs. In this example, the model takes a video as input and runs inference based on the video's images and audio.

For example, the video input may show a clip of an athlete swimming, with background audio asking "What is the athlete doing?". The model in omni mode then answers based on the frames of the video and the question in the audio.
### 1.1 Running example

```bash
python omni.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --video-path VIDEO_PATH
```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id for the MiniCPM-o-2_6 model (e.g. `openbmb/MiniCPM-o-2_6`) to be downloaded, or the path to the Hugging Face checkpoint folder. It defaults to `'openbmb/MiniCPM-o-2_6'`.
- `--video-path VIDEO_PATH`: argument defining the video input.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.

> [!TIP]
> In omni mode, please make sure that the video input contains sound.

> [!TIP]
> You can safely ignore the warning `Some weights of the model checkpoint at xxx were not used when initializing MiniCPMO`.
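
[omni.py](./omni.py) itself is not reproduced in this commit view. As a rough sketch of the omni-mode preprocessing, based on the [MiniCPM-o-2_6 model card](https://huggingface.co/openbmb/MiniCPM-o-2_6) rather than on the exact example code, the video is split into one-second units, each pairing a frame with its slice of audio:

```python
# A minimal sketch of omni-mode content building (assumption: it follows the
# model card's get_video_chunk_content helper; the actual omni.py may differ).
import math
import librosa
import numpy as np
from PIL import Image
from moviepy import VideoFileClip  # on moviepy 1.x: from moviepy.editor import VideoFileClip

def get_video_chunk_content(video_path, temp_audio_path="temp_audio.wav"):
    video = VideoFileClip(video_path)
    # extract the soundtrack as 16 kHz mono PCM, then load it as a numpy array
    video.audio.write_audiofile(temp_audio_path, codec="pcm_s16le", fps=16000)
    audio_np, sr = librosa.load(temp_audio_path, sr=16000, mono=True)

    # one unit per second of video: a "<unit>" marker, one frame, and 1 s of audio
    contents = []
    for i in range(math.ceil(video.duration)):
        frame = Image.fromarray(video.get_frame(i).astype(np.uint8))
        contents.extend(["<unit>", frame, audio_np[sr * i: sr * (i + 1)]])
    return contents

# the flattened contents form the user message passed to model.chat, e.g.
# msgs = [{'role': 'user', 'content': get_video_chunk_content(args.video_path)}]
```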
## 2. Example: Chat with text/audio/image input

In [chat.py](./chat.py), we show a use case for a MiniCPM-o-2_6 model to chat based on text, audio, or image input, or a combination of two of them, with IPEX-LLM INT4 optimizations on Intel GPUs.
### 2.1 Running example

- Chat with text input
  ```bash
  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT
  ```

- Chat with audio input
  ```bash
  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --audio-path AUDIO_PATH
  ```

- Chat with image input
  ```bash
  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --image-path IMAGE_PATH
  ```

- Chat with text + audio inputs
  ```bash
  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --audio-path AUDIO_PATH
  ```

- Chat with text + image inputs
  ```bash
  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --prompt PROMPT --image-path IMAGE_PATH
  ```

- Chat with audio + image inputs
  ```bash
  python chat.py --repo-id-or-model-path REPO_ID_OR_MODEL_PATH --audio-path AUDIO_PATH --image-path IMAGE_PATH
  ```

Arguments info:
- `--repo-id-or-model-path REPO_ID_OR_MODEL_PATH`: argument defining the Hugging Face repo id for the MiniCPM-o-2_6 model (e.g. `openbmb/MiniCPM-o-2_6`) to be downloaded, or the path to the Hugging Face checkpoint folder. It defaults to `'openbmb/MiniCPM-o-2_6'`.
- `--prompt PROMPT`: argument defining the text input.
- `--audio-path AUDIO_PATH`: argument defining the audio input.
- `--image-path IMAGE_PATH`: argument defining the image input.
- `--n-predict N_PREDICT`: argument defining the max number of tokens to predict. It defaults to `32`.

> [!TIP]
> You can safely ignore the warning `Some weights of the model checkpoint at xxx were not used when initializing MiniCPMO`.
### 2.2 Sample Outputs

#### [openbmb/MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6)

The sample input image (fetched from the [COCO dataset](https://cocodataset.org/#explore?id=264959)) is:

<a href="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg"><img width=400px src="http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg" ></a><br>
http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
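
To try it locally, you could fetch the sample image first, for example:

```bash
wget http://farm6.staticflickr.com/5268/5602445367_3504763978_z.jpg
```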
And the sample audio is a person saying "What is in this image".

- Chat with text + image inputs
  ```log
  Inference time: xxxx s
  -------------------- Input Image Path --------------------
  5602445367_3504763978_z.jpg
  -------------------- Input Audio Path --------------------
  None
  -------------------- Input Prompt --------------------
  What is in this image?
  -------------------- Chat Output --------------------
  The image features a young child holding and displaying her white teddy bear. She is wearing a pink dress, which complements the color of the stuffed toy she
  ```

- Chat with audio + image inputs:
  ```log
  Inference time: xxxx s
  -------------------- Input Image Path --------------------
  5602445367_3504763978_z.jpg
  -------------------- Input Audio Path --------------------
  test_audio.wav
  -------------------- Input Prompt --------------------
  None
  -------------------- Chat Output --------------------
  In this image, there is a young girl holding and displaying her stuffed teddy bear. She appears to be the main subject of the photo, with her toy
  ```
python/llm/example/GPU/HuggingFace/Multimodal/MiniCPM-o-2_6/chat.py (+119, new file)

```python
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import time
import torch
import librosa
import argparse
from PIL import Image
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModel


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Chat with MiniCPM-o-2_6 with text/audio/image')
    parser.add_argument('--repo-id-or-model-path', type=str, default="openbmb/MiniCPM-o-2_6",
                        help='The Hugging Face or ModelScope repo id for the MiniCPM-o-2_6 model to be downloaded'
                             ', or the path to the checkpoint folder')
    parser.add_argument('--image-path', type=str,
                        help='The path to the image for inference.')
    parser.add_argument('--audio-path', type=str,
                        help='The path to the audio for inference.')
    parser.add_argument('--prompt', type=str,
                        help='Prompt for inference.')
    parser.add_argument('--n-predict', type=int, default=32,
                        help='Max tokens to predict')

    args = parser.parse_args()

    model_path = args.repo_id_or_model_path
    image_path = args.image_path
    audio_path = args.audio_path

    # Initialize only the modality modules the inputs require, and keep the
    # vision ('vpm', 'resampler') and audio ('apm') modules out of the INT4 conversion
    modules_to_not_convert = []
    init_vision = False
    init_audio = False
    if image_path is not None and os.path.exists(image_path):
        init_vision = True
        modules_to_not_convert += ["vpm", "resampler"]
    if audio_path is not None and os.path.exists(audio_path):
        init_audio = True
        modules_to_not_convert += ["apm"]

    # Load model in 4 bit,
    # which converts the relevant layers in the model into INT4 format
    model = AutoModel.from_pretrained(model_path,
                                      load_in_low_bit="sym_int4",
                                      optimize_model=True,
                                      trust_remote_code=True,
                                      attn_implementation='sdpa',
                                      use_cache=True,
                                      init_vision=init_vision,
                                      init_audio=init_audio,
                                      init_tts=False,
                                      modules_to_not_convert=modules_to_not_convert)

    model = model.half().to('xpu')

    tokenizer = AutoTokenizer.from_pretrained(model_path,
                                              trust_remote_code=True)

    # The following code for generation is adapted from
    # https://huggingface.co/openbmb/MiniCPM-o-2_6#addressing-various-audio-understanding-tasks and
    # https://huggingface.co/openbmb/MiniCPM-o-2_6#chat-with-single-image
    content = []
    if init_vision:
        image_input = Image.open(image_path).convert('RGB')
        content.append(image_input)
    if args.prompt is not None:
        content.append(args.prompt)
    if init_audio:
        audio_input, _ = librosa.load(audio_path, sr=16000, mono=True)
        content.append(audio_input)
    messages = [{'role': 'user', 'content': content}]

    with torch.inference_mode():
        # ipex_llm model needs a warmup, then inference time can be accurate
        model.chat(
            msgs=messages,
            tokenizer=tokenizer,
            sampling=True,
            max_new_tokens=args.n_predict,
        )

        st = time.time()
        response = model.chat(
            msgs=messages,
            tokenizer=tokenizer,
            sampling=True,
            max_new_tokens=args.n_predict,
        )
        torch.xpu.synchronize()
        end = time.time()

    print(f'Inference time: {end-st} s')
    print('-'*20, 'Input Image Path', '-'*20)
    print(image_path)
    print('-'*20, 'Input Audio Path', '-'*20)
    print(audio_path)
    print('-'*20, 'Input Prompt', '-'*20)
    print(args.prompt)
    print('-'*20, 'Chat Output', '-'*20)
    print(response)
```
