Commit 91965b5

add glm_sdpa back to fix chatglm-6b (#11313)
1 parent 7f65836 commit 91965b5

File tree (1 file changed: +44 / -1 lines)

  • python/llm/src/ipex_llm/transformers/models/chatglm.py
python/llm/src/ipex_llm/transformers/models/chatglm.py

Lines changed: 44 additions & 1 deletion
@@ -23,7 +23,7 @@
 import torch.nn.functional as F
 from typing import Optional, Tuple
 from ipex_llm.transformers.models.utils import init_kv_cache, extend_kv_cache, append_kv_cache
-from ipex_llm.transformers.models.chatglm2 import glm_sdpa
+from ipex_llm.transformers.models.utils import use_flash_attention, use_sdp
 
 
 def rotate_half(x):
@@ -39,6 +39,49 @@ def apply_rotary_pos_emb_index(q, k, cos, sin, position_id):
     q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
     return q, k
 
+
+def glm_sdpa(query, key, value, attention_mask=None, is_causal=False):
+    if use_flash_attention(query, key, attention_mask) or query.device.type == 'cpu':
+        context_layer = F.scaled_dot_product_attention(query.to(key.dtype),
+                                                       key,
+                                                       value,
+                                                       attention_mask,
+                                                       is_causal=is_causal).to(key.dtype)
+    else:
+        # attention_mask is not None only when past_key_value is not None and q_len > 1
+        if attention_mask is not None:
+            attn_bias = torch.zeros(attention_mask.shape, dtype=query.dtype,
+                                    device=query.device)
+            attention_mask = ~attention_mask
+            if attention_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attention_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attention_mask
+        elif is_causal:
+            L, S = query.size(-2), key.size(-2)
+            attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
+            temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
+            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
+            attn_bias.to(key.dtype)
+        else:
+            attn_bias = None
+        if use_sdp(query.shape[2], key.shape[2],
+                   query.shape[-1], query):
+            import xe_addons
+            attn_output = xe_addons.sdp(query, key, value, attn_bias)
+            context_layer = attn_output.view(query.shape)
+        else:
+            head_dim = query.size(-1)
+            attn = torch.matmul(query.to(key.dtype) / math.sqrt(head_dim),
+                                key.transpose(2, 3))
+            if attn_bias is not None:
+                attn += attn_bias
+            attn = F.softmax(attn, dim=-1,
+                             dtype=torch.float32).to(value.dtype)
+            context_layer = torch.matmul(attn, value)
+    return context_layer
+
+
 import os
 
 KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
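
For reference, a minimal usage sketch of the re-added glm_sdpa (not part of the commit): the tensor shapes and the CPU device are illustrative assumptions, and it presumes an ipex_llm build that includes this commit. On CPU the function takes the F.scaled_dot_product_attention branch shown in the diff.

# Illustrative sketch only: drive glm_sdpa with small random tensors on CPU.
import torch
from ipex_llm.transformers.models.chatglm import glm_sdpa  # location per this commit

# Assumed layout for illustration: [batch, num_heads, seq_len, head_dim]
query = torch.randn(1, 32, 8, 128)
key = torch.randn(1, 32, 8, 128)
value = torch.randn(1, 32, 8, 128)

context_layer = glm_sdpa(query, key, value, attention_mask=None, is_causal=True)
print(context_layer.shape)  # expected: torch.Size([1, 32, 8, 128])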

0 commit comments
