
Commit 96dee4e

wip: qk attention locations
Signed-off-by: Kyle Sayers <[email protected]>
1 parent d77bcef commit 96dee4e

File tree

1 file changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from transformers import AttentionInterface
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


# TODO: HF acknowledgement
def transformable_attention(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Attention hook which applies any attached transform submodules to the query
    and key tensors before dispatching to the default sdpa implementation
    """
    from compressed_tensors.transform import TransformBase, TransformLocation

    for submodule in module.children():
        if isinstance(submodule, TransformBase):
            if submodule.args.location == TransformLocation.Q_ATTN:
                query = submodule(query)

            if submodule.args.location == TransformLocation.K_CACHE:
                key = submodule(key)

    return ALL_ATTENTION_FUNCTIONS["sdpa"](
        module, query, key, value, attention_mask, dropout=dropout, scaling=scaling, **kwargs
    )  # TODO: use original setting from config


# Another way to do this, which gets around an ad-hoc AttentionInterface register
# and messing with the config, would be to just patch ALL_ATTENTION_FUNCTIONS;
# we already have to patch the config's attn_implementation anyways

AttentionInterface.register("transformable_attention", transformable_attention)
# model.config.attn_implementation = "transformable_attention"
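
For context, a minimal usage sketch of the registration route above (not part of this commit). The checkpoint name is a placeholder, and it assumes a transformers version where custom attention functions registered through AttentionInterface can be selected via the attn_implementation argument.

from transformers import AutoModelForCausalLM

# importing the module that defines transformable_attention runs the
# AttentionInterface.register(...) call shown in the diff
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # placeholder checkpoint, any decoder-only model works
    attn_implementation="transformable_attention",
)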
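
And a rough sketch of the alternative floated in the comment block near the end of the file: patch the "sdpa" entry of ALL_ATTENTION_FUNCTIONS in place, so models already configured for sdpa pick up the transform hook without registering a new implementation or touching the config. This assumes ALL_ATTENTION_FUNCTIONS supports item assignment like a mapping; the wrapper name is illustrative only.

from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# keep a handle to the original implementation so the wrapper does not recurse
_original_sdpa = ALL_ATTENTION_FUNCTIONS["sdpa"]

def _sdpa_with_transforms(module, query, key, value, attention_mask, dropout=0.0, scaling=None, **kwargs):
    from compressed_tensors.transform import TransformBase, TransformLocation

    # apply attached query/key transforms, mirroring transformable_attention above
    for submodule in module.children():
        if isinstance(submodule, TransformBase):
            if submodule.args.location == TransformLocation.Q_ATTN:
                query = submodule(query)
            if submodule.args.location == TransformLocation.K_CACHE:
                key = submodule(key)

    return _original_sdpa(module, query, key, value, attention_mask, dropout=dropout, scaling=scaling, **kwargs)

# assumption: the registry behaves like a mutable mapping
ALL_ATTENTION_FUNCTIONS["sdpa"] = _sdpa_with_transforms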
