While testing the register_attention_for_kv_quant function of modelopt, I found the following issue.
A simple PyTorch attention module looks like this:
class CommonAttention(nn.Module):
    def forward(self, qkv: tuple, extra_arg=None):
        # NOTE: Add unused argument y with default value to test that replaced attention retain original defaults
        q, k, v = qkv
        attn = q @ k
        attn = F.softmax(attn)
        attn = attn @ v
        return attn

    @classmethod
    def get_input(cls, device: str = "cpu"):
        q = torch.randn(1, 4, 8, device=device)
        k = torch.randn(1, 4, 8, device=device)
        v = torch.randn(1, 4, 8, device=device)
        return (q, k, v),
After applying register_attention_for_kv_quant to CommonAttention, the generated _QuantCommonAttention is as follows:
class _QuantCommonAttention(nn.Module):
    def forward(self, qkv: tuple, extra_arg=None):
        q, k, v = qkv
        k = k.transpose(-2, -1)
        attn = q @ self.v_bmm_quantizer(k)
        attn = F.softmax(attn)
        attn = self.q_bmm_quantizer(attn) @ torch.transpose(self.k_bmm_quantizer(torch.transpose(v, -1, -2)), -1, -2)
        return attn

    @classmethod
    def get_input(cls, device: str = "cpu"):
        q = torch.randn(1, 4, 8, device=device)
        k = torch.randn(1, 4, 8, device=device)
        v = torch.randn(1, 4, 8, device=device)
        return (q, k, v)
Is this correct? In the generated code the quantizers appear to be attached to the wrong tensors: v_bmm_quantizer is applied to k, k_bmm_quantizer is applied to v, and q_bmm_quantizer is applied to the softmax output instead of q.
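For comparison, here is the quantizer placement I would expect, sketched as a runnable toy. nn.Identity stands in for modelopt's TensorQuantizer (an assumption made only so the example is self-contained), and I add a transpose on k so the (1, 4, 8) shapes from get_input compose; the point is only which tensor each quantizer wraps, not modelopt's actual generated code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpectedQuantAttention(nn.Module):
    """Illustrative sketch: each *_bmm_quantizer wraps the tensor it is
    named after. nn.Identity is a stand-in for modelopt's TensorQuantizer
    (assumption, so the example runs without modelopt installed)."""

    def __init__(self):
        super().__init__()
        self.q_bmm_quantizer = nn.Identity()
        self.k_bmm_quantizer = nn.Identity()
        self.v_bmm_quantizer = nn.Identity()

    def forward(self, qkv: tuple, extra_arg=None):
        q, k, v = qkv
        # q goes through the q-quantizer, k through the k-quantizer,
        # and v through the v-quantizer.
        attn = self.q_bmm_quantizer(q) @ self.k_bmm_quantizer(k).transpose(-2, -1)
        attn = F.softmax(attn, dim=-1)
        return attn @ self.v_bmm_quantizer(v)


qkv = tuple(torch.randn(1, 4, 8) for _ in range(3))
out = ExpectedQuantAttention()(qkv)
print(out.shape)  # torch.Size([1, 4, 8])
```

With identity quantizers this reduces to plain softmax(q @ k^T) @ v, which makes it easy to check that the rewrite only inserts quantizers and does not change the math.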