Commit 20b0d88

Add support for baichuan (#365)
1 parent 2bdea7a commit 20b0d88

6 files changed, +361 -0 lines changed
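For orientation, a minimal sketch of serving a Baichuan checkpoint once this commit lands. The model name and prompt are illustrative assumptions, not part of the diff; any checkpoint whose config.json reports model_type "baichuan" and architecture "BaiChuanForCausalLM" should resolve through the registries added below.

from vllm import LLM, SamplingParams

# Hypothetical checkpoint name, used here only for illustration.
llm = LLM(model="baichuan-inc/Baichuan-7B")
params = SamplingParams(temperature=0.8, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)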

vllm/model_executor/model_loader.py

+1

@@ -11,6 +11,7 @@
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
+    "BaiChuanForCausalLM": BaiChuanForCausalLM,
     "BloomForCausalLM": BloomForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,

vllm/model_executor/models/__init__.py

+2

@@ -1,3 +1,4 @@
+from vllm.model_executor.models.baichuan import BaiChuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
@@ -8,6 +9,7 @@
 from vllm.model_executor.models.opt import OPTForCausalLM
 
 __all__ = [
+    "BaiChuanForCausalLM",
     "BloomForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
vllm/model_executor/models/baichuan.py

+293
@@ -0,0 +1,293 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only BaiChuan model compatible with HuggingFace weights.
+
+The input of the model is flattened to a 1D tensor of tokens. The model uses
+InputMetadata to extract the original 2D shape of the input.
+"""
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from vllm.sequence import SequenceOutputs
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                              load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.parallel_utils.tensor_parallel import (
+    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
+
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
+class BaiChuanMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+    ):
+        super().__init__()
+        self.gate_up_proj = ColumnParallelLinear(hidden_size,
+                                                 2 * intermediate_size,
+                                                 bias=False,
+                                                 gather_output=False,
+                                                 perform_initialization=False)
+        self.down_proj = RowParallelLinear(intermediate_size,
+                                           hidden_size,
+                                           bias=False,
+                                           input_is_parallel=True,
+                                           perform_initialization=False)
+        if hidden_act != "silu":
+            raise ValueError(f"Unsupported activation: {hidden_act}. "
+                             "Only silu is supported for now.")
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class BaiChuanAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
+        )
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tensor_model_parallel_world_size == 0
+        self.num_heads = (self.total_num_heads //
+                          tensor_model_parallel_world_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.scaling = self.head_dim**-0.5
+
+        # pylint: disable=invalid-name
+        self.W_pack = ColumnParallelLinear(
+            hidden_size,
+            3 * hidden_size,
+            bias=False,
+            gather_output=False,
+            perform_initialization=False,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+            perform_initialization=False,
+        )
+
+        self.attn = PagedAttentionWithRoPE(self.num_heads,
+                                           self.head_dim,
+                                           self.scaling,
+                                           rotary_dim=self.head_dim)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        qkv, _ = self.W_pack(hidden_states)
+        q, k, v = qkv.chunk(chunks=3, dim=-1)
+        k_cache, v_cache = kv_cache
+        attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
+                                input_metadata, cache_event)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class BaiChuanDecoderLayer(nn.Module):
+
+    def __init__(self, config: BaiChuanConfig):
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.self_attn = BaiChuanAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+        )
+        self.mlp = BaiChuanMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size,
+                                       eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size,
+                                                eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        kv_cache: KVCache,
+        input_metadata: InputMetadata,
+        cache_event: Optional[torch.cuda.Event],
+    ) -> torch.Tensor:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            input_metadata=input_metadata,
+            cache_event=cache_event,
+        )
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class BaiChuanModel(nn.Module):
+
+    def __init__(self, config: BaiChuanConfig):
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            perform_initialization=False)
+        self.layers = nn.ModuleList([
+            BaiChuanDecoderLayer(config)
+            for _ in range(config.num_hidden_layers)
+        ])
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> torch.Tensor:
+        hidden_states = self.embed_tokens(input_ids)
+        for i in range(len(self.layers)):
+            if cache_events is None:
+                cache_event = None
+            else:
+                cache_event = cache_events[i]
+            layer = self.layers[i]
+            hidden_states = layer(
+                positions,
+                hidden_states,
+                kv_caches[i],
+                input_metadata,
+                cache_event,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class BaiChuanForCausalLM(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.model = BaiChuanModel(config)
+        self.lm_head = ColumnParallelLinear(config.hidden_size,
+                                            config.vocab_size,
+                                            bias=False,
+                                            gather_output=False,
+                                            perform_initialization=False)
+        self.sampler = Sampler(config.vocab_size)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[KVCache],
+        input_metadata: InputMetadata,
+        cache_events: Optional[List[torch.cuda.Event]],
+    ) -> Dict[int, SequenceOutputs]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   input_metadata, cache_events)
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+                                   input_metadata)
+        return next_tokens
+
+    _column_parallel_weights = [
+        "embed_tokens.weight", "lm_head.weight", "W_pack.weight",
+        "gate_proj.weight", "up_proj.weight"
+    ]
+    _row_parallel_weights = ["o_proj.weight", "down_proj.weight"]
+
+    def load_weights(self,
+                     model_name_or_path: str,
+                     cache_dir: Optional[str] = None,
+                     use_np_cache: bool = False):
+        tensor_model_parallel_rank = get_tensor_model_parallel_rank()
+        state_dict = self.state_dict()
+
+        for name, loaded_weight in hf_model_weights_iterator(
+                model_name_or_path, cache_dir, use_np_cache):
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            is_gate_up_weight = False
+            for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]):
+                if weight_name not in name:
+                    continue
+                param = state_dict[name.replace(weight_name, "gate_up_proj")]
+                shard_size = param.shape[0] // 2
+                loaded_weight = loaded_weight[
+                    shard_size * tensor_model_parallel_rank:shard_size *
+                    (tensor_model_parallel_rank + 1)]
+                param_slice = param.data[shard_size * stride_id:shard_size *
+                                         (stride_id + 1)]
+                assert param_slice.shape == loaded_weight.shape
+                param_slice.copy_(loaded_weight)
+                is_gate_up_weight = True
+                break
+            if is_gate_up_weight:
+                continue
+
+            param = state_dict[name]
+            load_tensor_parallel_weights(param, loaded_weight, name,
+                                         self._column_parallel_weights,
+                                         self._row_parallel_weights,
+                                         tensor_model_parallel_rank)
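One subtlety in load_weights above: the checkpoint stores gate_proj and up_proj separately, while the module fuses them into a single gate_up_proj and row-shards each half per tensor-parallel rank. A standalone sketch of that layout, with toy sizes rather than the real dimensions:

import torch

# Toy sizes: intermediate=16, hidden=8, tensor-parallel world size of 2.
I, H, TP = 16, 8, 2
shard = I // TP  # rows of each projection owned by one rank

gate = torch.randn(I, H)  # stand-in for the checkpoint's gate_proj.weight
up = torch.randn(I, H)    # stand-in for up_proj.weight

for rank in range(TP):
    # Each rank's fused gate_up_proj.weight is its gate shard stacked on top
    # of its up shard -- the stride_id 0/1 slices in load_weights above.
    fused = torch.cat([gate[rank * shard:(rank + 1) * shard],
                       up[rank * shard:(rank + 1) * shard]])
    assert fused.shape == (2 * shard, H)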

vllm/transformers_utils/config.py

+1

@@ -4,6 +4,7 @@
 
 _CONFIG_REGISTRY = {
     "mpt": MPTConfig,
+    "baichuan": BaiChuanConfig,
 }
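The registry above is keyed by the model_type string in a checkpoint's config.json. A hedged sketch of the dispatch it enables; the load_config helper here is illustrative, not the file's actual function:

import json
import os

from transformers import AutoConfig

from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

_CONFIG_REGISTRY = {"baichuan": BaiChuanConfig}


def load_config(model_dir: str):
    # Illustrative: read model_type from config.json and prefer a registered
    # custom config class; otherwise fall back to AutoConfig.
    with open(os.path.join(model_dir, "config.json")) as f:
        model_type = json.load(f).get("model_type")
    if model_type in _CONFIG_REGISTRY:
        return _CONFIG_REGISTRY[model_type].from_pretrained(model_dir)
    return AutoConfig.from_pretrained(model_dir)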

vllm/transformers_utils/configs/__init__.py

+2
@@ -1,5 +1,7 @@
 from vllm.transformers_utils.configs.mpt import MPTConfig
+from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
 
 __all__ = [
     "MPTConfig",
+    "BaiChuanConfig",
 ]
vllm/transformers_utils/configs/baichuan.py

+62
@@ -0,0 +1,62 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class BaiChuanConfig(PretrainedConfig):
+    model_type = "baichuan"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=64000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        hidden_act="silu",
+        max_position_embeddings=4096,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        tie_word_embeddings=False,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
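A quick check of the defaults above, which bake in the Baichuan-7B sizing:

from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

config = BaiChuanConfig()
assert config.model_type == "baichuan"
assert config.hidden_size == 4096
assert config.num_attention_heads == 32
assert config.hidden_size % config.num_attention_heads == 0  # head_dim = 128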
