 import torch.nn.functional as F
 from typing import Optional, Tuple
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.chatglm2 import glm_sdpa
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 
 
 def rotate_half(x):
@@ -39,6 +39,49 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
 
+
+def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
+    if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
+                                    device=query.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        elif is_causal:
+            L, S = query.size(-2), key.size(-2)
+            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(key.dtype)
+        else:
+            attn_bias = None
+        if use_sdp(query.shape[2], key.shape[2],
+                   query.shape[-1], query):
+            import xe_addons
+            attn_output = xe_addons.sdp(query, key, value, attn_bias)
+            context_layer = attn_output.view(query.shape)
+        else:
+            head_dim = query.size(-1)
+            attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
+                                key.transpose(2, 3))
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value.dtype)
+            context_layer = torch.matmul(attn, value)
+    return context_layer
+
+
 import os
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
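
For reference, below is a minimal standalone sketch (not part of the commit) of the pure-PyTorch fallback path that the added `glm_sdpa` helper takes when neither flash attention nor the `xe_addons` SDP kernel is used. The helper name `sdpa_fallback` and the dummy tensor shapes are illustrative assumptions; inputs are assumed to be laid out as `[batch, heads, seq_len, head_dim]`.

```python
import math
import torch
import torch.nn.functional as F


def sdpa_fallback(query, key, value, attn_bias=None):
    # Same math as the final branch of glm_sdpa: scaled QK^T, optional additive
    # bias, softmax computed in float32, then a weighted sum over the values.
    head_dim = query.size(-1)
    attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
                        key.transpose(2, 3))
    if attn_bias is not None:
        attn += attn_bias
    attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(value.dtype)
    return torch.matmul(attn, value)


# Dummy tensors shaped [batch, heads, seq_len, head_dim]
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
out = sdpa_fallback(q, k, v)
print(out.shape)  # torch.Size([1, 2, 4, 8])
```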