@@ -191,6 +191,135 @@ def cohere_model_forward(
    )


+def cohere_model_forward_4_41(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+):
+    use_cache = use_cache if use_cache is not None \
+        else self.config.use_cache
+    if use_cache and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids):
+        if not isinstance(past_key_values, DynamicFp8Cache):
+            past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
+    output_attentions = output_attentions if output_attentions is not None \
+        else self.config.output_attentions
+    output_hidden_states = (
+        output_hidden_states if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    use_cache = use_cache if use_cache is not None else self.config.use_cache
+    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+    if input_ids is not None and inputs_embeds is not None:
+        invalidInputError(False,
+                          "You cannot specify both input_ids and inputs_embeds at the same time")
+
+    if self.gradient_checkpointing and self.training and use_cache:
+        invalidInputError(False,
+                          "`use_cache=True` is incompatible "
+                          "with gradient checkpointing. Setting `use_cache=False`.")
+        use_cache = False
+
+    if inputs_embeds is None:
+        inputs_embeds = self.embed_tokens(input_ids)
+
+    past_seen_tokens = 0
+    return_legacy_cache = False
+    # kept for BC (non `Cache` `past_key_values` inputs)
+    if use_cache and not isinstance(past_key_values, Cache):
+        return_legacy_cache = True
+        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+
+    if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
+
+    if position_ids is None:
+        position_ids = cache_position.unsqueeze(0)
+
+    causal_mask = self._update_causal_mask(
+        attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+    )
+
+    # embed positions
+    hidden_states = inputs_embeds
+
+    # decoder layers
+    all_hidden_states = () if output_hidden_states else None
+    all_self_attns = () if output_attentions else None
+    next_decoder_cache = None
+
+    for decoder_layer in self.layers:
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if self.gradient_checkpointing and self.training:
+            layer_outputs = self._gradient_checkpointing_func(
+                decoder_layer.__call__,
+                hidden_states,
+                causal_mask,
+                position_ids,
+                past_key_values,
+                output_attentions,
+                use_cache,
+                cache_position,
+            )
+        else:
+            # ipex-llm changes
+            curr_device = decoder_layer.input_layernorm.weight.device
+            if causal_mask is not None:
+                causal_mask = causal_mask.to(curr_device)
+            if position_ids is not None:
+                position_ids = position_ids.to(curr_device)
+            # ipex-llm changes end
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+            )
+
+        hidden_states = layer_outputs[0]
+
+        if use_cache:
+            next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+        if output_attentions:
+            all_self_attns += (layer_outputs[1],)
+
+    hidden_states = self.norm(hidden_states)
+
+    # add hidden states from the last decoder layer
+    if output_hidden_states:
+        all_hidden_states += (hidden_states,)
+
+    next_cache = next_decoder_cache if use_cache else None
+    if return_legacy_cache:
+        next_cache = next_cache.to_legacy_cache()
+    if not return_dict:
+        return tuple(v for v in [hidden_states, next_cache,
+                                 all_hidden_states, all_self_attns] if v is not None)
+    return BaseModelOutputWithPast(
+        last_hidden_state=hidden_states,
+        past_key_values=next_cache,
+        hidden_states=all_hidden_states,
+        attentions=all_self_attns,
+    )
+
+
def cohere_attention_forward(
    self,
    hidden_states: torch.Tensor,
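For context, the new forward above only takes effect once it is bound onto Hugging Face's `CohereModel`; that wiring is not part of this hunk. The snippet below is a minimal sketch of the usual monkey-patching pattern, gated on the `transformers` version the function name targets. The `replace_cohere_forward` helper and the version check are illustrative assumptions, not code from this commit.

# Sketch only (assumption): version-gated binding of the patched forward.
from packaging import version
import transformers


def replace_cohere_forward():
    # CohereModel lives in transformers.models.cohere.modeling_cohere
    from transformers.models.cohere.modeling_cohere import CohereModel

    trans_version = version.parse(transformers.__version__)
    if trans_version >= version.parse("4.41.0"):
        # transformers >= 4.41 passes `cache_position` and uses the new Cache API,
        # which is what cohere_model_forward_4_41 handles.
        CohereModel.forward = cohere_model_forward_4_41
    else:
        CohereModel.forward = cohere_model_forward

The design point is that the patched forward keeps the upstream 4.41 control flow (Cache/DynamicCache handling, `cache_position`, `_update_causal_mask`) and only inserts the ipex-llm pieces: swapping in `DynamicFp8Cache` when quantized KV cache is enabled, and moving the mask and position ids to each decoder layer's device.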