Commit d647119
Fix cohere model on transformers>=4.41 (#11575)
* fix cohere model for 4-41
1 parent: 5b6eb85

File tree

6 files changed: +151 -12 lines changed

python/llm/example/CPU/HF-Transformers-AutoModels/Model/cohere/README.md (+2 -2)

@@ -17,7 +17,7 @@ conda activate llm
 
 # install ipex-llm with 'all' option
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 ```
 
 On Windows:
@@ -27,7 +27,7 @@ conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 ```
 
 ### 2. Run
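The quoting in the updated command is deliberate: in POSIX shells an unquoted `>` in `transformers>=4.40.0` is parsed as output redirection, so pip would only see `transformers`. After installing, the relaxed pin can be sanity-checked with a minimal snippet (only `transformers` and its `packaging` dependency are assumed to be present):

```python
# Verify the relaxed pin is satisfied; `packaging` ships as a
# transformers dependency, so both imports should be available.
from packaging import version
import transformers

installed = version.parse(transformers.__version__)
assert installed >= version.parse("4.40.0"), \
    f"transformers {installed} is too old; re-run the pip install step above"
print(f"transformers {installed} OK")
```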

python/llm/example/CPU/PyTorch-Models/Model/cohere/README.md (+2 -2)

@@ -18,7 +18,7 @@ conda activate llm
 
 # install the latest ipex-llm nightly build with 'all' option
 pip install --pre --upgrade ipex-llm[all] --extra-index-url https://download.pytorch.org/whl/cpu
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 ```
 
 On Windows:
@@ -28,7 +28,7 @@ conda create -n llm python=3.11
 conda activate llm
 
 pip install --pre --upgrade ipex-llm[all]
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 ```
 
 ### 2. Run

python/llm/example/GPU/HuggingFace/LLM/cohere/README.md (+2 -2)

@@ -17,7 +17,7 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 ```
 
@@ -29,7 +29,7 @@ conda activate llm
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux
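The same version note applies to the GPU environments. As a hedged sketch (it assumes the `ipex-llm[xpu]` install above pulled in a working intel_extension_for_pytorch build; `torch.xpu.is_available()` is only meaningful once that import succeeds), the setup can be smoke-tested before running the example:

```python
# Smoke-test the XPU environment created by the steps above.
# Assumption: ipex-llm[xpu] installed a matching torch + intel_extension_for_pytorch.
import torch
import intel_extension_for_pytorch  # noqa: F401 -- registers the "xpu" device
from packaging import version
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.40.0")
print("transformers:", transformers.__version__)
print("XPU available:", torch.xpu.is_available())
```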

python/llm/example/GPU/PyTorch-Models/Model/cohere/README.md (+2 -2)

@@ -17,7 +17,7 @@ conda create -n llm python=3.11
 conda activate llm
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
 ```
 
@@ -29,7 +29,7 @@ conda activate llm
 
 # below command will install intel_extension_for_pytorch==2.1.10+xpu as default
 pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-pip install transformers==4.40.0
+pip install "transformers>=4.40.0"
 ```
 
 ### 2. Configures OneAPI environment variables for Linux

python/llm/src/ipex_llm/transformers/convert.py (+14 -4)

@@ -1382,13 +1382,23 @@ def _optimize_post(model, lightweight_bmm=False):
                         qwen2_attention_forward)
     elif model.config.model_type == "cohere":
         # for CohereForAI/c4ai-command-r-v01
+        invalidInputError(version.parse(trans_version) >= version.parse("4.40.0"),
+                          "Please upgrade transformers to 4.40.0 or higher version "
+                          "to run Cohere models.")
         modeling_module_name = model.__class__.__module__
         module = importlib.import_module(modeling_module_name)
+        if version.parse(trans_version) >= version.parse("4.41.0"):
+            from ipex_llm.transformers.models.cohere import cohere_model_forward_4_41
+            convert_forward(model,
+                            module.CohereModel,
+                            cohere_model_forward_4_41)
+        else:
+            from ipex_llm.transformers.models.cohere import cohere_model_forward
+            convert_forward(model,
+                            module.CohereModel,
+                            cohere_model_forward)
+
         from ipex_llm.transformers.models.cohere import cohere_attention_forward
-        from ipex_llm.transformers.models.cohere import cohere_model_forward
-        convert_forward(model,
-                        module.CohereModel,
-                        cohere_model_forward)
         convert_forward(model,
                         module.CohereAttention,
                         cohere_attention_forward)
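This hunk is the core of the fix: transformers 4.41 changed `CohereModel.forward`'s cache handling, so `_optimize_post` now picks the replacement forward by installed version instead of assuming 4.40 semantics. A minimal, self-contained sketch of this version-gated patching pattern (the `patched_forward_*` functions and `Toy` module are hypothetical stand-ins; `packaging.version`, `transformers.__version__`, and the `__get__` rebinding are real):

```python
# Minimal sketch of version-gated forward replacement, mirroring convert.py.
from packaging import version
import torch
import transformers

def patched_forward_pre_4_41(self, x):  # hypothetical: pre-4.41 calling convention
    return self.linear(x)

def patched_forward_4_41(self, x):      # hypothetical: 4.41+ calling convention
    return self.linear(x)

def convert_forward(model, target_cls, new_forward):
    # Rebind `forward` on every submodule of the target class
    # (the same idea as ipex-llm's convert_forward, simplified).
    for module in model.modules():
        if isinstance(module, target_cls):
            module.forward = new_forward.__get__(module, target_cls)

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

model = torch.nn.Sequential(Toy())
chosen = (patched_forward_4_41
          if version.parse(transformers.__version__) >= version.parse("4.41.0")
          else patched_forward_pre_4_41)
convert_forward(model, Toy, chosen)
```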

python/llm/src/ipex_llm/transformers/models/cohere.py (+129)

@@ -191,6 +191,135 @@ def cohere_model_forward(
     )
 
 
+def cohere_model_forward_4_41(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+):
+    use_cache = use_cache if use_cache is not None \
+        else self.config.use_cache
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both input_ids and inputs_embeds at the same time")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        invalidInputError(False,
+                          "`use_cache=True` is incompatible "
+                          "with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    return_legacy_cache = False
+    # kept for BC (non `Cache` `past_key_values` inputs)
+    if use_cache and not isinstance(past_key_values, Cache):
+        return_legacy_cache = True
+        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            # ipex-llm changes
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if causal_mask is not None:
+                causal_mask = causal_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # ipex-llm changes end
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if return_legacy_cache:
+        next_cache = next_cache.to_legacy_cache()
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
 def cohere_attention_forward(
     self,
     hidden_states: torch.Tensor,
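Relative to the pre-4.41 path, the new forward mainly adapts to transformers 4.41's cache bookkeeping: legacy tuple caches are wrapped in a `DynamicCache` on entry (with `return_legacy_cache` remembering to unwrap on exit), and `cache_position` explicitly indexes the new tokens. A small illustrative round trip, assuming transformers>=4.41 (`DynamicCache` and its methods are real transformers APIs; the tensor shapes are made up for the example):

```python
# Round trip between legacy tuple caches and DynamicCache, as in the
# forward above. Shapes (batch, num_heads, seq_len, head_dim) are illustrative.
import torch
from transformers.cache_utils import DynamicCache

legacy = ((torch.zeros(1, 8, 5, 64), torch.zeros(1, 8, 5, 64)),)  # one layer

cache = DynamicCache.from_legacy_cache(legacy)
past_seen_tokens = cache.get_seq_length()                  # -> 5

# cache_position indexes the tokens being fed now, after the cached ones.
new_tokens = 3
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_tokens)
position_ids = cache_position.unsqueeze(0)                 # shape (1, 3)

# Callers that passed a tuple get a tuple back (return_legacy_cache path).
legacy_again = cache.to_legacy_cache()
print(past_seen_tokens, cache_position.tolist(), legacy_again[0][0].shape)
```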
