Commit 5a6211f

fix minicpm for transformers>=4.39 (#11533)

* fix minicpm for transformers>=4.39

1 parent 0209427 · commit 5a6211f

File tree: 6 files changed, +320 -5 lines

python/llm/example/CPU/HF-Transformers-AutoModels/Model/minicpm/README.md (+2)

@@ -17,6 +17,7 @@ conda activate llm
 
 # install ipex-llm with 'all' option
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
+pip install "transformers>=4.36"
 ```
 On Windows:
 
@@ -25,6 +26,7 @@ conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
+pip install "transformers>=4.36"
 ```
 
 ### 2. Run
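Not part of the diff, but as a quick sanity check after following the updated install steps, one can confirm that the installed transformers version satisfies the new `>=4.36` requirement, for example:

```python
# Minimal sketch (not part of this commit): confirm the installed transformers
# version satisfies the README's new ">=4.36" requirement.
import transformers
from packaging import version

installed = version.parse(transformers.__version__)
assert installed >= version.parse("4.36"), \
    f"transformers {installed} is too old; run: pip install \"transformers>=4.36\""
print(f"transformers {installed} OK")
```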

python/llm/example/CPU/PyTorch-Models/Model/minicpm/README.md (+2)

@@ -19,6 +19,7 @@ conda activate llm
 
 # install the latest ipex-llm nightly build with 'all' option
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
+pip install "transformers>=4.36"
 ```
 
 On Windows:
@@ -28,6 +29,7 @@ conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
+pip install "transformers>=4.36"
 ```
 
 ### 2. Run
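For context, the PyTorch-Models variant of the example loads the model with plain transformers and then applies ipex-llm's `optimize_model()`. A minimal sketch of that pattern, under the same `transformers>=4.36` assumption (the model id and the exact call below are illustrative, not part of this diff):

```python
# Minimal sketch (not from this commit): the PyTorch-Models examples load MiniCPM
# with plain transformers, then apply ipex-llm's optimize_model() for low-bit
# optimization on CPU.
from transformers import AutoModelForCausalLM, AutoTokenizer
from ipex_llm import optimize_model

model_path = "openbmb/MiniCPM-2B-sft-bf16"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
model = optimize_model(model)  # returns the low-bit optimized model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
```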

python/llm/example/GPU/HuggingFace/LLM/minicpm/README.md (+2)

@@ -14,6 +14,7 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install "transformers>=4.36"
 ```
 
 #### 1.2 Installation on Windows
@@ -24,6 +25,7 @@ conda activate llm
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install "transformers>=4.36"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
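As an aside (not part of this commit), after installing `ipex-llm[xpu]` one way to confirm that the Intel GPU backend is visible before running the example is a check along these lines; importing `intel_extension_for_pytorch` to register the `xpu` device is an assumption about this stack:

```python
# Minimal sketch (not part of this commit): after installing ipex-llm[xpu],
# confirm that PyTorch can see the Intel GPU ("xpu") device.
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (registers the xpu backend)

print("XPU available:", torch.xpu.is_available())
```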

python/llm/example/GPU/PyTorch-Models/Model/minicpm/README.md (+2)

@@ -14,6 +14,7 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install "transformers>=4.36"
 ```
 
 #### 1.2 Installation on Windows
@@ -24,6 +25,7 @@ conda activate llm
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install "transformers>=4.36"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux

python/llm/src/ipex_llm/transformers/convert.py (+13 -5)

@@ -1673,19 +1673,27 @@ def safe_bmm_fwd(*args, **kwargs):
             stablelm_model_forward
         )
     elif model.config.model_type == 'minicpm':
-        from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
-        from ipex_llm.transformers.models.minicpm import minicpm_model_forward
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        if version.parse(trans_version) >= version.parse("4.39.0"):
+            from ipex_llm.transformers.models.minicpm import minicpm_attention_forward_4_39
+            convert_forward(model,
+                            module.MiniCPMAttention,
+                            minicpm_attention_forward_4_39)
+        else:
+            from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
+            convert_forward(model,
+                            module.MiniCPMAttention,
+                            minicpm_attention_forward)
+        from ipex_llm.transformers.models.minicpm import minicpm_model_forward
+
         convert_forward(model,
                         module.MiniCPMMLP,
                         llama_mlp_forward)
         convert_forward(model,
                         module.MiniCPMRMSNorm,
                         llama_rms_norm_forward)
-        convert_forward(model,
-                        module.MiniCPMAttention,
-                        minicpm_attention_forward)
+
         convert_forward(model,
                         module.MiniCPMModel,
                         minicpm_model_forward)
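The convert.py change boils down to a version gate: at optimization time, ipex-llm picks the attention forward that matches the installed transformers release, since the 4.39 path accounts for cache API differences (e.g. the `_seen_tokens` attribute used in the new code). A stripped-down sketch of the same dispatch pattern, for illustration only — the standalone helper below is hypothetical, while `convert_forward`, `minicpm_attention_forward` and `minicpm_attention_forward_4_39` are the real names from the diff:

```python
# Illustrative sketch of the version-gated dispatch used in convert.py above;
# wrapping it in a standalone helper is only for illustration.
from packaging import version
import transformers

trans_version = transformers.__version__

def pick_minicpm_attention_forward():
    if version.parse(trans_version) >= version.parse("4.39.0"):
        # newer transformers releases need the dedicated 4.39 code path
        from ipex_llm.transformers.models.minicpm import minicpm_attention_forward_4_39
        return minicpm_attention_forward_4_39
    from ipex_llm.transformers.models.minicpm import minicpm_attention_forward
    return minicpm_attention_forward
```

In the diff itself, whichever forward is selected gets patched onto `module.MiniCPMAttention` via `convert_forward`.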

python/llm/src/ipex_llm/transformers/models/minicpm.py (+299)

@@ -767,3 +767,302 @@ def minicpm_model_forward_internal(
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
+
+
+def minicpm_attention_forward_original_4_39(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
+            "Please make sure use `attention_mask` instead.`"
+        )
+
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    no_tp = not self.config.pretraining_tp > 1
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
+
+    # single batch decoding fast path
+    # forward_qkv takes will perform QKV projection, rotary position embedding
+    # and save the key/value states to cache, then return query states and the
+    # extended key/value cache
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        cache_k = past_key_value.key_cache[self.layer_idx]
+        cache_v = past_key_value.value_cache[self.layer_idx]
+        kv_seq_len = cache_k.shape[-2]
+        import xe_linear
+        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
+                                                                       self.q_proj.weight,
+                                                                       self.k_proj.weight,
+                                                                       self.v_proj.weight,
+                                                                       position_ids,
+                                                                       cache_k, cache_v,
+                                                                       self.q_proj.weight.qtype,
+                                                                       self.v_proj.weight.qtype,
+                                                                       kv_seq_len,
+                                                                       self.head_dim,
+                                                                       self.rotary_emb.base,)
+        kv_seq_len += 1
+        # update past_key_value's seem_tokens and kv caches.
+        if self.layer_idx == 0:
+            past_key_value._seen_tokens = kv_seq_len
+        past_key_value.key_cache[self.layer_idx] = key_states
+        past_key_value.value_cache[self.layer_idx] = value_states
+
+    else:
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
+                                 self.config.pretraining_tp)
+            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
+                                                    // self.config.pretraining_tp, dim=0)
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i])
+                            for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i])
+                          for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i])
+                            for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+        else:
+            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+                    hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
+                # only use mm_qkv_out on pvc for llama-7b
+                if not hasattr(self, "qkv_proj_weight"):
+                    self.qkv_proj_weight = torch.stack([self.q_proj.weight,
+                                                        self.k_proj.weight,
+                                                        self.v_proj.weight]).contiguous()
+                    self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
+                    self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
+                    self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
+                    torch.xpu.empty_cache()
+                query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
+                key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                         dtype=hidden_states.dtype, device=hidden_states.device)
+                value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
+                torch.ops.torch_ipex.mm_qkv_out(
+                    hidden_states, self.qkv_proj_weight, None,
+                    query_states, key_states, value_states
+                )
+            else:
+                if should_use_xetla_mm_qkv(self, device):
+                    if not hasattr(self, "qkv_proj_qweight"):
+                        self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                                      self.k_proj,
+                                                                      self.v_proj,
+                                                                      self.q_proj.weight.qtype,)
+                    import xe_linear
+                    q_out_len = self.q_proj.out_len
+                    k_out_len = self.k_proj.out_len
+                    v_out_len = self.v_proj.out_len
+                    qkv_states = xe_linear.mm_xetla(hidden_states,
+                                                    self.qkv_proj_qweight,
+                                                    self.q_proj.weight.qtype)
+                    query_states = qkv_states[:, :, :q_out_len]
+                    key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+                    value_states = qkv_states[:, :, q_out_len + k_out_len:]
+                else:
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len,
+                                         self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                invalidInputError(False,
+                                  "The cache structure has changed since version v4.36. "
+                                  f"If you are using {self.__class__.__name__} for "
+                                  "auto-regressive decodingwith k/v caching, please make sure "
+                                  "to initialize the attention class with a layer index.")
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        if use_fuse_rope:
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
+        else:
+            if cache_position is not None:
+                # for transformers 4.38.0
+                cos, sin = self.rotary_emb(value_states, position_ids)
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                                cos, sin, position_ids, "llama2")
+            else:
+                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                                cos, sin, position_ids, "llama")
+
+        if past_key_value is not None:
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value._seen_tokens += key_states.shape[-2]
+
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
+
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
+
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
+
+    if cache_position is not None:
+        new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
+    else:
+        new_attention_mask = attention_mask
+
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, new_attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # now only use flash attention for first token
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_weights = None
+    elif not self.training and not hidden_states.requires_grad and \
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import xe_addons
+        attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                    new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # otherwise, use native attention
+        if query_states.device.type == "xpu":
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   new_attention_mask, cache_position,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)
+        else:
+            # CPU path
+            if not output_attentions:
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=new_attention_mask,
+                    dropout_p=self.attention_dropout if self.training else 0.0,
+                    # The q_len > 1 is necessary to match with
+                    # AttentionMaskConverter.to_causal_4d that
+                    # does not create a causal mask in case q_len == 1.
+                    is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
+                )
+            else:
+                attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                       new_attention_mask, cache_position,
+                                                       bsz, q_len, kv_seq_len,
+                                                       self.head_dim,
+                                                       self.num_heads, output_attentions)
+
+    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+    if attn_output.size() != attn_output_size:
+        invalidInputError(False,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
+                           for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def minicpm_attention_forward_4_39(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    if use_quantize_kv_cache(self.q_proj, hidden_states):
+        forward_function = minicpm_attention_forward_quantized
+    else:
+        forward_function = minicpm_attention_forward_original_4_39
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        kwargs=kwargs
+    )
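With the two functions above in place, MiniCPM keeps working when transformers>=4.39 is installed. For context, the README examples touched by this commit load the model through ipex-llm's Hugging Face-style API roughly as follows (a minimal sketch; the model id, prompt, and generation settings are placeholders, not taken from this commit):

```python
# Minimal sketch (not from this commit) of how the MiniCPM examples load the
# model via ipex-llm's Hugging Face-style API with low-bit optimization.
import torch
from transformers import AutoTokenizer
from ipex_llm.transformers import AutoModelForCausalLM

model_path = "openbmb/MiniCPM-2B-sft-bf16"  # placeholder model id
model = AutoModelForCausalLM.from_pretrained(model_path,
                                             load_in_4bit=True,
                                             trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

with torch.inference_mode():
    input_ids = tokenizer.encode("What is AI?", return_tensors="pt")
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))
```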
