
Bug for register_attention_for_kv_quant #1064

@AIgiraffe

Description


When I tested the register_attention_for_kv_quant function of modelopt, I found the following issue.
A simple PyTorch attention module is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CommonAttention(nn.Module):
    def forward(self, qkv: tuple, extra_arg=None):
        # NOTE: unused argument extra_arg with a default value, to test that the
        # replaced attention retains the original defaults
        q, k, v = qkv

        attn = q @ k
        attn = F.softmax(attn)

        attn = attn @ v
        return attn

    @classmethod
    def get_input(cls, device: str = "cpu"):
        q = torch.randn(1, 4, 8, device=device)
        k = torch.randn(1, 4, 8, device=device)
        v = torch.randn(1, 4, 8, device=device)
        return ((q, k, v),)
```

After applying register_attention_for_kv_quant to CommonAttention, the generated class is as follows:

```python
class _QuantCommonAttention(nn.Module):

    def forward(self, qkv: tuple, extra_arg=None):
        q, k, v = qkv
        k = k.transpose(-2, -1)
        attn = q @ self.v_bmm_quantizer(k)
        attn = F.softmax(attn)
        attn = self.q_bmm_quantizer(attn) @ torch.transpose(self.k_bmm_quantizer(torch.transpose(v, -1, -2)), -1, -2)
        return attn

    @classmethod
    def get_input(cls, device: str = 'cpu'):
        q = torch.randn(1, 4, 8, device=device)
        k = torch.randn(1, 4, 8, device=device)
        v = torch.randn(1, 4, 8, device=device)
        return (q, k, v)
```
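To isolate the anomaly: if the quantizers are treated as pass-throughs (an assumption; disabled TensorQuantizer modules are stand-ins here via plain identity), the generated forward above is numerically equivalent to softmax(q @ k^T) @ v, since the double transpose on v cancels. So the rewrite itself computes a sensible attention; the question is only which quantizer wraps which tensor:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)

# Generated forward with every *_bmm_quantizer treated as identity (assumption):
kt = k.transpose(-2, -1)                 # v_bmm_quantizer(kt) -> kt
attn = q @ kt
attn = F.softmax(attn, dim=-1)           # q_bmm_quantizer(attn) -> attn
# k_bmm_quantizer wraps a transposed v; the two transposes cancel out
generated = attn @ torch.transpose(torch.transpose(v, -1, -2), -1, -2)

# Plain reference: softmax(q @ k^T) @ v
reference = F.softmax(q @ k.transpose(-2, -1), dim=-1) @ v

print(torch.allclose(generated, reference))  # True
```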

Is this the intended behavior? In the generated forward, v_bmm_quantizer is applied to k, k_bmm_quantizer is applied to (a transposed) v, and q_bmm_quantizer is applied to the softmax output rather than to q.
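For comparison, the placement I would have expected (an assumption about the intended KV-quantization semantics, with nn.Identity as a stand-in for the real TensorQuantizer modules) wraps each tensor in its namesake quantizer:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ExpectedQuantAttention(nn.Module):
    """Hypothetical expected result: each quantizer wraps its namesake tensor."""

    def __init__(self):
        super().__init__()
        # Identity stand-ins for the modelopt quantizer modules (assumption).
        self.q_bmm_quantizer = nn.Identity()
        self.k_bmm_quantizer = nn.Identity()
        self.v_bmm_quantizer = nn.Identity()

    def forward(self, qkv: tuple, extra_arg=None):
        q, k, v = qkv
        # q @ k^T so the inner dimensions agree for the (1, 4, 8) test inputs
        attn = self.q_bmm_quantizer(q) @ self.k_bmm_quantizer(k).transpose(-2, -1)
        attn = F.softmax(attn, dim=-1)
        return attn @ self.v_bmm_quantizer(v)


q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
v = torch.randn(1, 4, 8)
out = ExpectedQuantAttention()((q, k, v))
print(out.shape)  # torch.Size([1, 4, 8])
```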
