@@ -767,3 +767,302 @@ def minicpm_model_forward_internal(
         hidden_states=all_hidden_states,
         attentions=all_self_attns,
     )
+
+
+def minicpm_attention_forward_original_4_39(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    if "padding_mask" in kwargs:
+        warnings.warn(
+            "Passing `padding_mask` is deprecated and will be removed in v4.37. "
+            "Please make sure to use `attention_mask` instead."
+        )
+
+    bsz, q_len, hidden_size = hidden_states.size()
+    device = hidden_states.device
+    # for flash attention
+    original_dtype = hidden_states.dtype
+
+    use_fuse_rope = should_use_fuse_rope(self, hidden_states, position_ids)
+    enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx, seq_len=q_len)
+    no_tp = not self.config.pretraining_tp > 1
+    decoding_fast_path = use_decoding_fast_path(self.q_proj,
+                                                use_fuse_rope,
+                                                enough_kv_room,
+                                                bsz * q_len,
+                                                llama_decoding_fast_path_qtype_check) and no_tp
+
+    # single batch decoding fast path
+    # forward_qkv performs the QKV projection and rotary position embedding,
+    # saves the key/value states to the cache, and returns the query states
+    # together with the extended key/value cache
+    if decoding_fast_path:
+        hidden_states = hidden_states.view(1, -1)
+        cache_k = past_key_value.key_cache[self.layer_idx]
+        cache_v = past_key_value.value_cache[self.layer_idx]
+        kv_seq_len = cache_k.shape[-2]
+        import xe_linear
+        query_states, key_states, value_states = xe_linear.forward_qkv(hidden_states,
+                                                                       self.q_proj.weight,
+                                                                       self.k_proj.weight,
+                                                                       self.v_proj.weight,
+                                                                       position_ids,
+                                                                       cache_k, cache_v,
+                                                                       self.q_proj.weight.qtype,
+                                                                       self.v_proj.weight.qtype,
+                                                                       kv_seq_len,
+                                                                       self.head_dim,
+                                                                       self.rotary_emb.base,)
+        kv_seq_len += 1
825
+ # update past_key_value's seem_tokens and kv caches.
826
+ if self .layer_idx == 0 :
827
+ past_key_value ._seen_tokens = kv_seq_len
828
+ past_key_value .key_cache [self .layer_idx ] = key_states
829
+ past_key_value .value_cache [self .layer_idx ] = value_states
+
+    else:
+        if self.config.pretraining_tp > 1:
+            key_value_slicing = ((self.num_key_value_heads * self.head_dim) //
+                                 self.config.pretraining_tp)
+            query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim)
+                                                    // self.config.pretraining_tp, dim=0)
+            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
+            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
+
+            query_states = [F.linear(hidden_states, query_slices[i])
+                            for i in range(self.config.pretraining_tp)]
+            query_states = torch.cat(query_states, dim=-1)
+
+            key_states = [F.linear(hidden_states, key_slices[i])
+                          for i in range(self.config.pretraining_tp)]
+            key_states = torch.cat(key_states, dim=-1)
+
+            value_states = [F.linear(hidden_states, value_slices[i])
+                            for i in range(self.config.pretraining_tp)]
+            value_states = torch.cat(value_states, dim=-1)
+        else:
+            if fp16_fusion_check(self.q_proj, hidden_states, self.training) and \
+                    hidden_size == 4096 and self.q_proj.out_features == self.k_proj.out_features:
+                # only use mm_qkv_out on pvc for llama-7b
+                if not hasattr(self, "qkv_proj_weight"):
+                    self.qkv_proj_weight = torch.stack([self.q_proj.weight,
+                                                        self.k_proj.weight,
+                                                        self.v_proj.weight]).contiguous()
+                    self.q_proj.weight.data = self.qkv_proj_weight[0, :, :]
+                    self.k_proj.weight.data = self.qkv_proj_weight[1, :, :]
+                    self.v_proj.weight.data = self.qkv_proj_weight[2, :, :]
+                    torch.xpu.empty_cache()
+                query_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
+                key_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                         dtype=hidden_states.dtype, device=hidden_states.device)
+                value_states = torch.empty(bsz, q_len, self.qkv_proj_weight.shape[-1],
+                                           dtype=hidden_states.dtype, device=hidden_states.device)
+                torch.ops.torch_ipex.mm_qkv_out(
+                    hidden_states, self.qkv_proj_weight, None,
+                    query_states, key_states, value_states
+                )
+            else:
+                if should_use_xetla_mm_qkv(self, device):
+                    if not hasattr(self, "qkv_proj_qweight"):
+                        self.qkv_proj_qweight = fuse_qkv_weight_xetla(self.q_proj,
+                                                                      self.k_proj,
+                                                                      self.v_proj,
+                                                                      self.q_proj.weight.qtype,)
+                    import xe_linear
+                    q_out_len = self.q_proj.out_len
+                    k_out_len = self.k_proj.out_len
+                    v_out_len = self.v_proj.out_len
+                    qkv_states = xe_linear.mm_xetla(hidden_states,
+                                                    self.qkv_proj_qweight,
+                                                    self.q_proj.weight.qtype)
+                    query_states = qkv_states[:, :, :q_out_len]
+                    key_states = qkv_states[:, :, q_out_len:q_out_len + k_out_len]
+                    value_states = qkv_states[:, :, q_out_len + k_out_len:]
+                else:
+                    query_states = self.q_proj(hidden_states)
+                    key_states = self.k_proj(hidden_states)
+                    value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len,
+                                         self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len,
+                                     self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len,
+                                         self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                invalidInputError(False,
+                                  "The cache structure has changed since version v4.36. "
+                                  f"If you are using {self.__class__.__name__} for "
+                                  "auto-regressive decoding with k/v caching, please make sure "
+                                  "to initialize the attention class with a layer index.")
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        if use_fuse_rope:
+            import xe_addons
+            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                           query_states, key_states)
+        else:
+            if cache_position is not None:
+                # for transformers 4.38.0
+                cos, sin = self.rotary_emb(value_states, position_ids)
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                                cos, sin, position_ids, "llama2")
+            else:
+                cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+                query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                                cos, sin, position_ids, "llama")
+
+        if past_key_value is not None:
+            # update the number of seen tokens
+            if self.layer_idx == 0:
+                past_key_value._seen_tokens += key_states.shape[-2]
+
+            # reuse k, v, self_attention
+            # update `past_key_value` with `key_states` and `value_states` for layer `layer_idx`
+            if len(past_key_value.key_cache) <= self.layer_idx:
+                past_key_value.key_cache.append(key_states)
+                past_key_value.value_cache.append(value_states)
+            else:
+                cache_k = past_key_value.key_cache[self.layer_idx]
+                cache_v = past_key_value.value_cache[self.layer_idx]
+
+                if not enough_kv_room:
+                    # allocate new
+                    new_c_k, new_c_v = extend_kv_cache(bsz,
+                                                       self.num_key_value_heads,  # Support GQA
+                                                       self.head_dim,
+                                                       cache_k.size(2),
+                                                       kv_seq_len + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                                                       dtype=cache_k.dtype,
+                                                       device=device)
+
+                    new_c_k[:] = cache_k
+                    new_c_v[:] = cache_v
+                    cache_k = new_c_k
+                    cache_v = new_c_v
+
+                key_states, value_states = append_kv_cache(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
+
+                # update past_key_value
+                past_key_value.key_cache[self.layer_idx] = key_states
+                past_key_value.value_cache[self.layer_idx] = value_states
+
+    if cache_position is not None:
+        new_attention_mask = attention_mask[:, :, kv_seq_len - q_len:kv_seq_len, 0:kv_seq_len]
+    else:
+        new_attention_mask = attention_mask
+
+    if not self.training and not hidden_states.requires_grad and \
+            use_flash_attention(query_states, key_states, new_attention_mask):
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # now only use flash attention for first token
+        attn_output = F.scaled_dot_product_attention(query_states.to(device, dtype=torch.float16),
+                                                     key_states.to(device, dtype=torch.float16),
+                                                     value_states.to(device, dtype=torch.float16),
+                                                     is_causal=True)
+        attn_weights = None
+    elif not self.training and not hidden_states.requires_grad and \
+            use_sdp(q_len, key_states.shape[2], self.head_dim, query_states):
+        import xe_addons
+        attn_output = xe_addons.sdp(query_states, key_states, value_states,
+                                    new_attention_mask)
+        attn_output = attn_output.view(query_states.shape)
+        attn_weights = None
+    else:
+        # repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+        # otherwise, use native attention
+        if query_states.device.type == "xpu":
+            attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                   new_attention_mask, cache_position,
+                                                   bsz, q_len, kv_seq_len,
+                                                   self.head_dim, self.num_heads, output_attentions)
+        else:
+            # CPU path
+            if not output_attentions:
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=new_attention_mask,
+                    dropout_p=self.attention_dropout if self.training else 0.0,
+                    # The q_len > 1 is necessary to match with
+                    # AttentionMaskConverter.to_causal_4d that
+                    # does not create a causal mask in case q_len == 1.
+                    is_causal=self.is_causal and new_attention_mask is None and q_len > 1,
+                )
+            else:
+                attn_output, attn_weights = native_sdp(query_states, key_states, value_states,
+                                                       new_attention_mask, cache_position,
+                                                       bsz, q_len, kv_seq_len,
+                                                       self.head_dim,
+                                                       self.num_heads, output_attentions)
+
+    attn_output_size = (bsz, self.num_heads, q_len, self.head_dim)
+    if attn_output.size() != attn_output_size:
+        invalidInputError(False,
+                          f"`attn_output` should be of size {attn_output_size},"
+                          f" but is {attn_output.size()}")
+
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+    if self.config.pretraining_tp > 1:
+        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
+        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp,
+                                                 dim=1)
+        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i])
+                           for i in range(self.config.pretraining_tp)])
+    else:
+        attn_output = self.o_proj(attn_output)
+
+    if not output_attentions:
+        attn_weights = None
+
+    return attn_output.to(original_dtype), attn_weights, past_key_value
+
+
+def minicpm_attention_forward_4_39(
+    self,
+    hidden_states: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_value: Optional[List[torch.FloatTensor]] = None,
+    output_attentions: bool = False,
+    use_cache: bool = False,
+    cache_position: Optional[torch.LongTensor] = None,
+    **kwargs
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.FloatTensor]]]:
+    if use_quantize_kv_cache(self.q_proj, hidden_states):
+        forward_function = minicpm_attention_forward_quantized
+    else:
+        forward_function = minicpm_attention_forward_original_4_39
+    return forward_function(
+        self=self,
+        hidden_states=hidden_states,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_value=past_key_value,
+        output_attentions=output_attentions,
+        use_cache=use_cache,
+        cache_position=cache_position,
+        kwargs=kwargs
+    )
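
A minimal sketch of the pre-allocate-then-append KV-cache strategy used in the hunk above (the `not enough_kv_room` branch around extend_kv_cache / append_kv_cache and KV_CACHE_ALLOC_BLOCK_LENGTH). The helper names and the block size below are illustrative stand-ins, not the ipex-llm implementations:

    import torch

    KV_ALLOC_BLOCK = 4  # illustrative stand-in for KV_CACHE_ALLOC_BLOCK_LENGTH

    def grow_if_needed(cache, seq_len, n_new):
        # If the pre-allocated buffer cannot hold n_new more tokens, allocate a
        # larger one (current need plus one block) and copy the valid prefix over.
        bsz, n_kv_heads, capacity, head_dim = cache.shape
        if seq_len + n_new <= capacity:
            return cache
        new_cache = torch.zeros(bsz, n_kv_heads, seq_len + n_new + KV_ALLOC_BLOCK, head_dim,
                                dtype=cache.dtype, device=cache.device)
        new_cache[:, :, :seq_len] = cache[:, :, :seq_len]
        return new_cache

    def append(cache, seq_len, new_states):
        # Write the new key/value states after the valid prefix; attention then
        # reads the view cache[:, :, :returned_length].
        n_new = new_states.shape[2]
        cache[:, :, seq_len:seq_len + n_new] = new_states
        return seq_len + n_new

    # toy decode loop: one new token per step
    k_cache = torch.zeros(1, 2, KV_ALLOC_BLOCK, 8)
    k_len = 0
    for _ in range(10):
        new_k = torch.randn(1, 2, 1, 8)
        k_cache = grow_if_needed(k_cache, k_len, 1)
        k_len = append(k_cache, k_len, new_k)
    print(k_cache.shape, k_len)  # torch.Size([1, 2, 14, 8]) 10

Growing by a block instead of by one token amortizes reallocation and copy cost across many decode steps, which is the point of KV_CACHE_ALLOC_BLOCK_LENGTH in the patch.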