Make MHA even more flexible #884

Open
@netw0rkf10w

🚀 Feature

Make the new MHA implementation more modular, so that different attention layers are easy to implement.

Motivation

The new MHA container implementation is already much more flexible than the one in core PyTorch. However, in the current version, implementing a new attention layer (other than ScaledDotProduct) requires repeating part of the code of ScaledDotProduct, which is not ideal. An attention layer consists of two steps:

  • Computation of the attention weights.

  • The aggregation of the values based on the computed weights.

Different attention functions may differ only in the first step, or in the second step, or both.
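The two steps can be illustrated with plain tensor operations. The sketch below is only an illustration: the function names `scaled_dot_weights`, `general_weights` and `aggregate` are hypothetical, and it uses a batch-first `(B, L, E)` layout rather than the `(L, N * H, E / H)` layout of the code further down. The two scoring functions differ only in step 1, and both reuse the same step 2.

```python
import torch


def scaled_dot_weights(query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
    # Step 1 (variant A): scaled dot-product scores, softmax over the source axis.
    # query: (B, L, E), key: (B, S, E) -> weights: (B, L, S)
    scores = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1)


def general_weights(query: torch.Tensor, key: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    # Step 1 (variant B): bilinear "general" scores q^T W k.
    # query: (B, L, E), W: (E, K), key: (B, S, K) -> weights: (B, L, S)
    scores = torch.matmul(torch.matmul(query, W), key.transpose(-2, -1))
    return torch.softmax(scores, dim=-1)


def aggregate(weights: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Step 2 (shared): weighted sum of the values.
    # weights: (B, L, S), value: (B, S, V) -> output: (B, L, V)
    return torch.matmul(weights, value)


torch.manual_seed(0)
q = torch.randn(2, 4, 3)
k = torch.randn(2, 6, 3)
v = torch.randn(2, 6, 5)
W = torch.randn(3, 3)

out_a = aggregate(scaled_dot_weights(q, k), v)
out_b = aggregate(general_weights(q, k, W), v)
print(out_a.shape, out_b.shape)  # torch.Size([2, 4, 5]) torch.Size([2, 4, 5])
```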

Pitch

I can think of two solutions:

  1. Let the attention layers (e.g. ScaledDotProduct) return only the attention weights, and perform the aggregation of the values in the main MHA container.

  2. Keep the MHA container unchanged by using a general template class for all attention layers, and let each specific attention layer inherit from this class.
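For comparison, solution 1 could look roughly like the sketch below. It is a hypothetical illustration, not torchtext code: `ScaledDotWeights` and `WeightsOnlyMHA` are made-up names, the weight layer receives tensors in a `(N * H, T, E / H)` layout, and the container uses plain `nn.Linear` projections. The attention layer only computes the weights (step 1), and the container performs the aggregation (step 2).

```python
import torch
from torch import nn


class ScaledDotWeights(nn.Module):
    """Hypothetical attention layer that returns only the attention weights (step 1)."""

    def forward(self, query: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        # query: (N * H, L, E / H), key: (N * H, S, E / H) -> (N * H, L, S)
        scores = torch.bmm(query, key.transpose(1, 2)) / query.size(-1) ** 0.5
        return torch.softmax(scores, dim=-1)


class WeightsOnlyMHA(nn.Module):
    """Hypothetical container for solution 1: it performs step 2 (aggregation) itself."""

    def __init__(self, embed_dim: int, nhead: int, weight_layer: nn.Module):
        super().__init__()
        assert embed_dim % nhead == 0, "embed_dim must be divisible by nhead"
        self.nhead = nhead
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.weight_layer = weight_layer

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        # (T, N, E) -> (T, N * H, E / H) -> (N * H, T, E / H)
        t, n, e = x.shape
        return x.reshape(t, n * self.nhead, e // self.nhead).transpose(0, 1)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        # query: (L, N, E), key/value: (S, N, E)
        tgt_len, bsz, embed_dim = query.shape
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))
        weights = self.weight_layer(q, k)  # step 1, in the layer: (N * H, L, S)
        out = torch.bmm(weights, v)        # step 2, in the container: (N * H, L, E / H)
        out = out.transpose(0, 1).reshape(tgt_len, bsz, embed_dim)
        return self.out_proj(out), weights


torch.manual_seed(0)
mha = WeightsOnlyMHA(embed_dim=8, nhead=2, weight_layer=ScaledDotWeights())
x = torch.randn(5, 3, 8)    # (L, N, E)
mem = torch.randn(7, 3, 8)  # (S, N, E)
out, w = mha(x, mem, mem)
print(out.shape, w.shape)  # torch.Size([5, 3, 8]) torch.Size([6, 5, 7])
```

In this design, a layer that needs a different step 2 cannot be added without modifying the container, which is one plausible reason to prefer the template approach.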

I've tried both and found the second solution much cleaner. Below is an example in which I re-implemented ScaledDotProduct using this approach and added another attention layer, GeneralDotProduct (denoted "general" in Section 3.1 of this paper). (If you try adding another attention layer such as GeneralDotProduct to the current implementation, you will see the issue.)

from typing import Optional, Tuple

import torch


class GeneralAttention(torch.nn.Module):

    def __init__(self, dropout=0.):
        r"""General template for attention layers.

        Args:
            dropout (float): probability of dropping an attention weight.
        """
        super().__init__()
        self.dropout = dropout

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None,
                        ) -> torch.Tensor:
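        r"""Return the attention weights of shape :math:`(N * H, L, S)`.
        Must be implemented by subclasses."""
        raise NotImplementedError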

    def compute_outputs(self, value: torch.Tensor,
                        weights: torch.Tensor,
                        ) -> torch.Tensor:
        r"""Compute the attention outputs from the values and the attention weights.

        Args:
            value (Tensor): Projected value
            weights (Tensor): Attention weights
        Shape:
            - value: :math:`(S, N * H, E / H)`
            - weights: :math:`(N * H, L, S)`

            - Output: :math:`(L, N * H, E / H)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        # Transpose: (S, N*H, E/H) --> (N*H, S, E/H)
        value = value.transpose(-2, -3)
        # (N*H, L, S) times (N*H, S, E/H) --> (N*H, L, E/H)
        attn_output = torch.matmul(weights, value)
        # Back to (L, N*H, E/H)
        attn_output = attn_output.transpose(-2, -3)

        return attn_output


    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                bias_k: Optional[torch.Tensor] = None,
                bias_v: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Attention forward pass.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            value (Tensor): Projected value
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
            bias_k and bias_v (Tensor, optional): one additional key and value
                sequence to be appended along the sequence dim (dim=-3). These
                are used for incremental decoding. Both arguments must be
                non-None to activate them.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - value: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
                allowed to attend while ``False`` values will be unchanged.
            - bias_k and bias_v: :math:`(1, N * H, E / H)`

            - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        if bias_k is not None and bias_v is not None:
            assert (key.size(-1) == bias_k.size(-1) and
                    key.size(-2) == bias_k.size(-2) and
                    bias_k.size(-3) == 1), "Shape of bias_k is not supported"
            assert (value.size(-1) == bias_v.size(-1) and
                    value.size(-2) == bias_v.size(-2) and
                    bias_v.size(-3) == 1), "Shape of bias_v is not supported"
            key = torch.cat([key, bias_k])
            value = torch.cat([value, bias_v])
            if attn_mask is not None:
                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))

        # Compute attention weights
        attn_weights = self.compute_weights(query, key, attn_mask=attn_mask)
        # Add dropout
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        # Then compute the attention outputs
        attn_output = self.compute_outputs(value, weights=attn_weights)
        
        return attn_output, attn_weights
        

class ScaledDotProduct(GeneralAttention):
    r"""Processes a projected query and key-value pair to apply
        scaled dot product attention.

        Examples::
            >>> SDP = torchtext.modules.ScaledDotProduct(dropout=0.1)
            >>> q = torch.randn(256, 21, 3)
            >>> k = v = torch.randn(256, 21, 3)
            >>> attn_output, attn_weights = SDP(q, k, v)
            >>> print(attn_output.shape, attn_weights.shape)
            torch.Size([256, 21, 3]) torch.Size([21, 256, 256])
        """

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None
                        ) -> torch.Tensor:
        r"""Compute the attention weights as the scaled dot product of the projected
        query and key, followed by a softmax over the source dimension.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
                allowed to attend while ``False`` values will be unchanged.

            - Output: :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        tgt_len, head_dim = query.size(-3), query.size(-1)
        assert query.size(-1) == key.size(-1), "Feature dims of query and key must equal."
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # Scale query
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)
        query = query * (float(head_dim) ** -0.5)

        # Attention weights: dot product of q, k
        attn_weights = torch.matmul(query, key.transpose(-2, -1))
        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
               (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')

            attn_weights.masked_fill_(attn_mask, -1e8,)
        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        
        return attn_weights


class GeneralDotProduct(GeneralAttention):

    def __init__(self, embed_dim, kdim=None, dropout=0.):
        r"""Processes a projected query and key-value pair to apply the general
        vector-matrix-vector product attention.

        Examples::
            >>> embed_dim, kdim = 5, 3
            >>> GDP = torchtext.modules.GeneralDotProduct(embed_dim, kdim=kdim, dropout=0.1)
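            >>> q = torch.randn(21, 256, embed_dim)
            >>> k = v = torch.randn(12, 256, kdim)
            >>> attn_output, attn_weights = GDP(q, k, v)
            >>> print(attn_output.shape, attn_weights.shape)
            torch.Size([21, 256, 3]) torch.Size([256, 21, 12])

        Args:
            embed_dim (int): feature dimension of the projected query (per head).
            kdim (int, optional): feature dimension of the projected key (per head).
                Default: ``embed_dim``.
            dropout (float): probability of dropping an attention weight.
        """
        super().__init__(dropout=dropout)
        kdim = embed_dim if kdim is None else kdim
        # torch.empty leaves W uninitialized, so initialize it explicitly
        self.W = torch.nn.Parameter(torch.empty(embed_dim, kdim))
        torch.nn.init.xavier_uniform_(self.W)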

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None
                        ) -> torch.Tensor:
        r"""Compute the attention weights with the general (bilinear) score
        :math:`q^\top W k` between the projected query and key.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, K / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
                allowed to attend while ``False`` values will be unchanged.

            - Output: :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, E is the embedding dimension,
            and K is the key dimension.
        """
        tgt_len = query.size(-3)
        assert (query.size(-1) == self.W.shape[0] and
                key.size(-1) == self.W.shape[1]), "Feature dims do not match W."
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # (L, N * H, E/H) --> (N * H, L, E/H), (S, N * H, K/H) --> (N * H, S, K/H)
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)

        # Attention weights: bilinear score q^T W k
        # W is (E/H, K/H): (N*H, L, E/H) @ (E/H, K/H) --> (N*H, L, K/H)
        attn_weights = torch.matmul(query, self.W)
        # (N*H, L, K/H) @ (N*H, K/H, S) --> (N*H, L, S)
        attn_weights = torch.matmul(attn_weights, key.transpose(-2, -1))

        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
               (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')

            attn_weights.masked_fill_(attn_mask, -1e8,)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        
        return attn_weights
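To illustrate the benefit of the template, the sketch below adds a hypothetical third layer, `UnscaledDotProduct`, which scores with a plain dot product and no scaling. Only `compute_weights` is overridden; dropout and the value aggregation come from the base class. To keep the block self-contained, the base class is a condensed copy of `GeneralAttention` above with the `bias_k`/`bias_v` handling and mask validation omitted, so this is a sketch rather than a drop-in replacement.

```python
from typing import Optional, Tuple

import torch


class GeneralAttention(torch.nn.Module):
    # Condensed copy of the template above (no bias_k/bias_v, no mask checks).
    def __init__(self, dropout: float = 0.0):
        super().__init__()
        self.dropout = dropout

    def compute_weights(self, query, key, attn_mask=None):
        raise NotImplementedError

    def compute_outputs(self, value: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # (S, N*H, E/H) -> (N*H, S, E/H); (N*H, L, S) @ (N*H, S, E/H) -> (N*H, L, E/H)
        out = torch.matmul(weights, value.transpose(-2, -3))
        return out.transpose(-2, -3)  # back to (L, N*H, E/H)

    def forward(self, query, key, value, attn_mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        weights = self.compute_weights(query, key, attn_mask=attn_mask)
        weights = torch.nn.functional.dropout(weights, p=self.dropout, training=self.training)
        return self.compute_outputs(value, weights), weights


class UnscaledDotProduct(GeneralAttention):
    """Hypothetical layer: softmax(q . k) without the 1/sqrt(d) scaling."""

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # (L, N*H, E/H) -> (N*H, L, E/H) and (S, N*H, E/H) -> (N*H, S, E/H)
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)
        scores = torch.matmul(query, key.transpose(-2, -1))  # (N*H, L, S)
        if attn_mask is not None:
            # Same masking value as the layers above
            scores = scores.masked_fill(attn_mask, -1e8)
        return torch.softmax(scores, dim=-1)


torch.manual_seed(0)
layer = UnscaledDotProduct(dropout=0.0)
q = torch.randn(4, 6, 3)      # (L, N*H, E/H)
k = v = torch.randn(5, 6, 3)  # (S, N*H, E/H)
out, w = layer(q, k, v)
print(out.shape, w.shape)  # torch.Size([4, 6, 3]) torch.Size([6, 4, 5])
```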

@zhangguanheng66 Are you interested in such a PR?