Commit 4ba8219

Support PP inference for chatglm3 (#11375)
1 parent 9a3a21e

File tree: 5 files changed, +118 -26 lines changed

python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
python/llm/example/GPU/Pipeline-Parallel-Inference/run_chatglm_arc_2_card.sh
python/llm/src/ipex_llm/transformers/models/chatglm2.py
python/llm/src/ipex_llm/transformers/pipeline_parallel.py

python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
Lines changed: 16 additions & 0 deletions

@@ -12,6 +12,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
 - [Qwen/Qwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
 - [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
 - [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
+- [THUDM/chatglm3-6b](./run_chatglm_arc_2_card.sh)
 - [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
 - [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
 - [microsoft/Phi-3-mini-4k-instruct](./run_phi3_arc_2_card.sh)

@@ -71,6 +72,21 @@ bash run_qwen1.5_arc_2_card.sh
 
 </details>
 
+<details>
+<summary> Show chatglm example </summary>
+
+#### Run chatglm3-6B on two Intel Arc A770
+
+You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for chatglm to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
+
+```bash
+bash run_chatglm_arc_2_card.sh
+```
+
+</details>
+
+</details>
+
 <details>
 <summary> Show Baichuan2 example </summary>
 

python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
Lines changed: 16 additions & 8 deletions

@@ -19,7 +19,7 @@
 import time
 import argparse
 
-from ipex_llm.transformers import AutoModelForCausalLM, init_pipeline_parallel
+from ipex_llm.transformers import AutoModel, AutoModelForCausalLM, init_pipeline_parallel
 from transformers import AutoTokenizer
 
 init_pipeline_parallel()

@@ -41,13 +41,21 @@
 
     # Load model in 4 bit,
     # which convert the relevant layers in the model into INT4 format
-    model = AutoModelForCausalLM.from_pretrained(model_path,
-                                                 load_in_4bit=True,
-                                                 optimize_model=True,
-                                                 trust_remote_code=True,
-                                                 use_cache=True,
-                                                 torch_dtype=torch.float16,
-                                                 pipeline_parallel_stages=args.gpu_num)
+    try:
+        model = AutoModelForCausalLM.from_pretrained(model_path,
+                                                     load_in_4bit=True,
+                                                     optimize_model=True,
+                                                     trust_remote_code=True,
+                                                     use_cache=True,
+                                                     torch_dtype=torch.float16,
+                                                     pipeline_parallel_stages=args.gpu_num)
+    except:
+        model = AutoModel.from_pretrained(model_path,
+                                          load_in_4bit=True,
+                                          optimize_model=True,
+                                          trust_remote_code=True,
+                                          use_cache=True,
+                                          pipeline_parallel_stages=args.gpu_num)
 
     # Load tokenizer
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
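
The loading change in generate.py is a try-first, fall-back pattern: the script still tries `AutoModelForCausalLM` first and, if that raises, retries with `AutoModel`, which is how chatglm checkpoints are commonly loaded. A minimal sketch of the same logic in isolation; the helper name `load_pp_model` and the `except Exception` clause are illustrative, not part of the commit:

```python
# Sketch of the fallback added to generate.py; `load_pp_model` is a hypothetical helper.
import torch
from ipex_llm.transformers import AutoModel, AutoModelForCausalLM

def load_pp_model(model_path: str, gpu_num: int):
    kwargs = dict(load_in_4bit=True,        # convert the relevant layers to INT4
                  optimize_model=True,
                  trust_remote_code=True,
                  use_cache=True,
                  pipeline_parallel_stages=gpu_num)
    try:
        # Most supported models (Llama, Qwen1.5, Baichuan2, Phi-3, ...) load here.
        return AutoModelForCausalLM.from_pretrained(model_path,
                                                    torch_dtype=torch.float16,
                                                    **kwargs)
    except Exception:
        # chatglm3-6b falls back to the plain AutoModel entry point; note the
        # fallback branch also omits the explicit torch_dtype, as in the commit.
        return AutoModel.from_pretrained(model_path, **kwargs)
```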

python/llm/example/GPU/Pipeline-Parallel-Inference/run_chatglm_arc_2_card.sh (new file)
Lines changed: 31 additions & 0 deletions

@@ -0,0 +1,31 @@
+#
+# Copyright 2016 The BigDL Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+source /opt/intel/oneapi/setvars.sh
+export MASTER_ADDR=127.0.0.1
+export MASTER_PORT=9090
+export FI_PROVIDER=tcp
+export USE_XETLA=OFF
+export OMP_NUM_THREADS=6
+if [[ $KERNEL_VERSION != *"6.5"* ]]; then
+    export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+fi
+export TORCH_LLM_ALLREDUCE=0
+
+NUM_GPUS=2 # number of used GPU
+# To run chatglm3-6b
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
+    generate.py --repo-id-or-model-path 'THUDM/chatglm3-6b' --gpu-num $NUM_GPUS

python/llm/src/ipex_llm/transformers/models/chatglm2.py
Lines changed: 4 additions & 2 deletions

@@ -74,10 +74,12 @@ def chatglm2_model_forward(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-    batch_size, seq_length = input_ids.shape
-
     if inputs_embeds is None:
+        batch_size, seq_length = input_ids.shape
         inputs_embeds = self.embedding(input_ids)
+    else:
+        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
+        seq_length, batch_size, _ = inputs_embeds.shape
 
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or (
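
In a pipeline-parallel run only the first rank embeds `input_ids`; later ranks receive the previous stage's hidden states through `inputs_embeds`, so the patched `chatglm2_model_forward` now derives `batch_size` and `seq_length` from whichever input is present. A rough shape sketch of the new `inputs_embeds` branch (the hidden size of 4096 is chatglm3-6b's value and is an assumption of this sketch):

```python
# Shape walk-through of the inputs_embeds branch; 4096 is chatglm3-6b's hidden
# size, assumed here for illustration.
import torch

batch_size, seq_length, hidden_size = 1, 32, 4096
# Hidden states handed over from the previous pipeline stage: [batch, seq, hidden]
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size)

# The upstream ChatGLM blocks work in [seq, batch, hidden] layout, so transpose
# first and then recover seq_length/batch_size from the transposed tensor.
inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
seq_length, batch_size, _ = inputs_embeds.shape
print(inputs_embeds.shape)  # torch.Size([32, 1, 4096])
```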

python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Lines changed: 51 additions & 16 deletions

@@ -71,6 +71,19 @@ def forward(self, hidden_states, past_key_value=None, use_cache=False, **kwargs)
         return outputs
 
 
+class Dummy_GLMBlock(nn.Module):
+    def __init__(self, *args):
+        super().__init__()
+        # to avoid AttributeError
+        self.input_layernorm = DummyLayer()
+        self.mlp = Dummy_MLPLayer()
+
+    def forward(
+            self, hidden_states, attention_mask, rotary_pos_emb, kv_cache=None, use_cache=True,
+    ):
+        return hidden_states, kv_cache
+
+
 def init_pipeline_parallel():
     import oneccl_bindings_for_pytorch
     os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")

@@ -79,28 +92,49 @@ def init_pipeline_parallel():
 
 
 def pipeline_parallel(model, pipeline_parallel_stages):
-    slice_size = (model.config.num_hidden_layers + pipeline_parallel_stages - 1) // \
-        pipeline_parallel_stages
+    global num_layers
+    if hasattr(model.config, 'num_hidden_layers'):
+        num_layers = model.config.num_hidden_layers
+    elif hasattr(model.config, 'num_layers'):
+        # for chatglm3-6b
+        num_layers = model.config.num_layers
+
+    slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages
 
     local_rank = dist.get_rank()
 
     global layer_start
     global layer_end
     layer_start = slice_size * local_rank
-    layer_end = layer_start + min(slice_size, model.config.num_hidden_layers - layer_start)
-
-    for i in range(model.config.num_hidden_layers):
-        if i < layer_start or i >= layer_end:
-            model._modules['model'].layers[i] = Dummy_DecoderLayer()
-        else:
-            # align layer_idx and len(past_key_values), otherwise abnormal output
-            model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start
-
-    if local_rank != 0:
-        model._modules['model'].embed_tokens = DummyLayer()
-    if local_rank != pipeline_parallel_stages - 1:
-        model._modules['model'].norm = DummyLayer()
-        model._modules['lm_head'] = DummyLayer()
+    layer_end = layer_start + min(slice_size, num_layers - layer_start)
+
+    if model.config.architectures is not None \
+            and model.config.architectures[0] in ["ChatGLMModel", "ChatGLMForConditionalGeneration"]:
+        # for chatglm3-6b
+        for i in range(num_layers):
+            if i < layer_start or i >= layer_end:
+                model._modules['transformer'].encoder.layers[i] = Dummy_GLMBlock()
+            else:
+                model._modules['transformer'].encoder.layers[i].self_attention.num_layers = \
+                    i - layer_start
+
+        if local_rank != 0:
+            model._modules['transformer'].embedding = DummyLayer()
+        if local_rank != pipeline_parallel_stages - 1:
+            model._modules['transformer'].encoder.final_layernorm = DummyLayer()
+            model._modules['transformer'].output_layer = DummyLayer()
+    else:
+        for i in range(num_layers):
+            if i < layer_start or i >= layer_end:
+                model._modules['model'].layers[i] = Dummy_DecoderLayer()
+            else:
+                model._modules['model'].layers[i].self_attn.layer_idx = i - layer_start
+
+        if local_rank != 0:
+            model._modules['model'].embed_tokens = DummyLayer()
+        if local_rank != pipeline_parallel_stages - 1:
+            model._modules['model'].norm = DummyLayer()
+            model._modules['lm_head'] = DummyLayer()
 
     model.pipeline_parallel_stages = pipeline_parallel_stages
     model = model.to(f'xpu:{local_rank}')

@@ -176,6 +210,7 @@ def pipeline_parallel_generate(self,
 
     global layer_start
     global layer_end
+    global num_layers
 
     self.first_token_time = 0
     self.next_token_time = []
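
The reworked `pipeline_parallel` reads the layer count from `num_hidden_layers` or, for chatglm3-6b, from `num_layers`, then slices the blocks as evenly as possible across ranks and swaps everything outside a rank's slice for dummy modules. A worked example of that arithmetic for chatglm3-6b on two GPUs; the layer count of 28 is chatglm3-6b's `num_layers` and is an assumption of this sketch:

```python
# Layer partitioning of the patched pipeline_parallel() for 2 stages.
# num_layers = 28 is chatglm3-6b's config.num_layers (assumed here).
num_layers = 28
pipeline_parallel_stages = 2
slice_size = (num_layers + pipeline_parallel_stages - 1) // pipeline_parallel_stages  # 14

for local_rank in range(pipeline_parallel_stages):
    layer_start = slice_size * local_rank
    layer_end = layer_start + min(slice_size, num_layers - layer_start)
    print(f"rank {local_rank}: keeps GLM blocks [{layer_start}, {layer_end})")

# rank 0: keeps blocks [0, 14); final_layernorm and output_layer become DummyLayer
# rank 1: keeps blocks [14, 28); the embedding becomes DummyLayer
# Blocks outside a rank's slice become Dummy_GLMBlock, which passes
# hidden_states and kv_cache through unchanged.
```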
