🚀 Feature
Make the new MHA container implementation even more modular, so that different attention layers are easy to implement.
Motivation
The new MHA container implementation is already much more flexible than the one in core PyTorch. However, in the current version, implementing a new attention layer (other than `ScaledDotProduct`) requires repeating some of `ScaledDotProduct`'s code, which is not optimal. Concretely, every attention layer performs two steps:

- Computation of the attention weights.
- Aggregation of the values based on the computed weights.

Different attention functions may differ in the first step only, in the second step only, or in both.
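
To make the shared structure concrete, here is a simplified sketch (not the actual torchtext code; `score_fn` is just a placeholder for the layer-specific scoring rule) of what every attention layer currently has to repeat around its own scoring step:

```python
import torch

def attention_skeleton(query, key, value, score_fn, dropout=0.0, training=True):
    # Step 1: layer-specific scoring rule -- the only part that differs between layers.
    attn_weights = torch.nn.functional.softmax(score_fn(query, key), dim=-1)
    attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=training)
    # Step 2: aggregation of the values -- identical for every attention layer.
    return torch.matmul(attn_weights, value), attn_weights
```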
Pitch
I can think of two solutions:

1. Let the attention layers (e.g. `ScaledDotProduct`) return only the attention weights, and move the aggregation of the values into the main MHA container.
2. Keep the MHA container unchanged by using a general template class for all the attention layers, and let each specific attention layer inherit from this class.
I've tried both and found that the second solution is much cleaner. Below is an example in which I re-implemented `ScaledDotProduct` using this approach and, in addition, added another attention layer called `GeneralDotProduct` (denoted "general" in Section 3.1 of this paper). (Try adding another attention layer such as `GeneralDotProduct` to the current implementation yourself and you will see the issue.)
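
For reference (assuming the linked paper is the Luong et al. attention paper, which lists these scoring functions), the two layers below differ only in the scoring step: `ScaledDotProduct` scores a query-key pair as qᵀk / √(E/H), while `GeneralDotProduct` scores it as qᵀWk with a learned matrix W of shape (E/H, K/H); the aggregation of the values is identical.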
```python
from typing import Optional, Tuple

import torch


class GeneralAttention(torch.nn.Module):
    def __init__(self, dropout=0.):
        r"""General template for attention layers.

        Args:
            dropout (float): probability of dropping an attention weight.
        """
        super().__init__()
        self.dropout = dropout

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None,
                        ) -> torch.Tensor:
        r"""Computes the attention weights from query and key (implemented by subclasses)."""
        raise NotImplementedError

    def compute_outputs(self, value: torch.Tensor,
                        weights: torch.Tensor,
                        ) -> torch.Tensor:
        r"""Computes the attention outputs from the value and the attention weights.

        Args:
            value (Tensor): Projected value
            weights (Tensor): Attention weights

        Shape:
            - value: :math:`(S, N * H, E / H)`
            - weights: :math:`(N * H, L, S)`
            - Output: :math:`(L, N * H, E / H)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        # Transpose: (S, N*H, E/H) --> (N*H, S, E/H)
        value = value.transpose(-2, -3)
        # (N*H, L, S) times (N*H, S, E/H) --> (N*H, L, E/H)
        attn_output = torch.matmul(weights, value)
        # Back to (L, N*H, E/H)
        attn_output = attn_output.transpose(-2, -3)
        return attn_output

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                attn_mask: Optional[torch.Tensor] = None,
                bias_k: Optional[torch.Tensor] = None,
                bias_v: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Attention forward pass.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            value (Tensor): Projected value
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.
            bias_k and bias_v (Tensor, optional): one more key and value
                sequence to be added at the sequence dim (dim=-3). These are
                used for incremental decoding. Users should provide non-None
                values for both arguments in order to activate them.

        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - value: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
              allowed to attend while ``False`` values will be unchanged.
            - bias_k and bias_v: :math:`(1, N * H, E / H)`
            - Output: :math:`(L, N * H, E / H)`, :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        if bias_k is not None and bias_v is not None:
            assert (key.size(-1) == bias_k.size(-1) and
                    key.size(-2) == bias_k.size(-2) and
                    bias_k.size(-3) == 1), "Shape of bias_k is not supported"
            assert (value.size(-1) == bias_v.size(-1) and
                    value.size(-2) == bias_v.size(-2) and
                    bias_v.size(-3) == 1), "Shape of bias_v is not supported"
            key = torch.cat([key, bias_k])
            value = torch.cat([value, bias_v])
            if attn_mask is not None:
                attn_mask = torch.nn.functional.pad(attn_mask, (0, 1))

        # Step 1: compute the attention weights (layer-specific).
        attn_weights = self.compute_weights(query, key, attn_mask=attn_mask)
        # Apply dropout to the attention weights.
        attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
        # Step 2: compute the attention outputs (shared aggregation).
        attn_output = self.compute_outputs(value, weights=attn_weights)
        return attn_output, attn_weights


class ScaledDotProduct(GeneralAttention):
    r"""Processes a projected query and key-value pair to apply
    scaled dot product attention.

    Examples::
        >>> SDP = torchtext.modules.ScaledDotProduct(dropout=0.1)
        >>> q = torch.randn(21, 256, 3)
        >>> k = v = torch.randn(21, 256, 3)
        >>> attn_output, attn_weights = SDP(q, k, v)
        >>> print(attn_output.shape, attn_weights.shape)
        torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
    """

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None
                        ) -> torch.Tensor:
        r"""Uses a scaled dot product between the projected query and key to
        compute the attention weights.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.

        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, E / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
              allowed to attend while ``False`` values will be unchanged.
            - Output: :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, and E is the embedding dimension.
        """
        tgt_len, head_dim = query.size(-3), query.size(-1)
        assert query.size(-1) == key.size(-1), "Feature dims of query and key must be equal."
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # (L, N*H, E/H) --> (N*H, L, E/H), (S, N*H, E/H) --> (N*H, S, E/H)
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)
        # Scale the query
        query = query * (float(head_dim) ** -0.5)

        # Attention weights: dot product of q, k
        attn_weights = torch.matmul(query, key.transpose(-2, -1))

        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
                    (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')
            attn_weights.masked_fill_(attn_mask, -1e8)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        return attn_weights


class GeneralDotProduct(GeneralAttention):
    r"""Processes a projected query and key-value pair to apply the general
    vector-matrix-vector product attention.

    Examples::
        >>> embed_dim, kdim = 5, 3
        >>> GDP = torchtext.modules.GeneralDotProduct(embed_dim, kdim=kdim, dropout=0.1)
        >>> q = torch.randn(21, 256, embed_dim)
        >>> k = v = torch.randn(12, 256, kdim)
        >>> attn_output, attn_weights = GDP(q, k, v)
        >>> print(attn_output.shape, attn_weights.shape)
        torch.Size([21, 256, 3]) torch.Size([256, 21, 12])
    """

    def __init__(self, embed_dim, kdim=None, dropout=0.):
        r"""
        Args:
            embed_dim (int): feature dimension of the projected query.
            kdim (int, optional): feature dimension of the projected key
                (defaults to ``embed_dim``).
            dropout (float): probability of dropping an attention weight.
        """
        super().__init__(dropout=dropout)
        kdim = embed_dim if kdim is None else kdim
        self.W = torch.nn.Parameter(torch.empty(embed_dim, kdim))
        # Initialize W (torch.empty leaves the parameter uninitialized).
        torch.nn.init.xavier_uniform_(self.W)

    def compute_weights(self, query: torch.Tensor, key: torch.Tensor,
                        attn_mask: Optional[torch.Tensor] = None
                        ) -> torch.Tensor:
        r"""Uses the general (bilinear) product of the projected query, the
        weight matrix W, and the projected key to compute the attention weights.

        Args:
            query (Tensor): Projected query
            key (Tensor): Projected key
            attn_mask (BoolTensor, optional): 3D mask that prevents attention
                to certain positions.

        Shape:
            - query: :math:`(L, N * H, E / H)`
            - key: :math:`(S, N * H, K / H)`
            - attn_mask: :math:`(N * H, L, S)`, positions with ``True`` are not
              allowed to attend while ``False`` values will be unchanged.
            - Output: :math:`(N * H, L, S)`

            where L is the target length, S is the source length, H is the number
            of attention heads, N is the batch size, E is the embedding dimension,
            and K is the key dimension.
        """
        tgt_len = query.size(-3)
        assert (query.size(-1) == self.W.shape[0] and
                key.size(-1) == self.W.shape[1]), "Feature dims do not match."
        src_len = key.size(-3)
        batch_heads = max(query.size(-2), key.size(-2))

        # (L, N*H, E/H) --> (N*H, L, E/H), (S, N*H, K/H) --> (N*H, S, K/H)
        query, key = query.transpose(-2, -3), key.transpose(-2, -3)

        # Attention weights: q W k^T, where W is (E/H, K/H)
        attn_weights = torch.matmul(query, self.W)
        attn_weights = torch.matmul(attn_weights, key.transpose(-2, -1))

        if attn_mask is not None:
            if attn_mask.dim() != 3:
                raise RuntimeError('attn_mask must be a 3D tensor.')
            if (attn_mask.size(-1) != src_len) or (attn_mask.size(-2) != tgt_len) or \
                    (attn_mask.size(-3) != 1 and attn_mask.size(-3) != batch_heads):
                raise RuntimeError('The size of the attn_mask is not correct.')
            if attn_mask.dtype != torch.bool:
                raise RuntimeError('Only bool tensor is supported for attn_mask')
            attn_weights.masked_fill_(attn_mask, -1e8)

        attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
        return attn_weights
```
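
For completeness, here is a small standalone usage sketch of the two subclasses above (shapes and hyperparameters are arbitrary; the MHA container itself is untouched, since the attention layers keep the same `forward` signature):

```python
# Quick standalone check of the two attention layers defined above.
embed_dim, kdim, batch_heads, tgt_len, src_len = 8, 6, 16, 5, 7

q = torch.randn(tgt_len, batch_heads, embed_dim)   # (L, N*H, E/H)
k = torch.randn(src_len, batch_heads, kdim)        # (S, N*H, K/H)
v = torch.randn(src_len, batch_heads, kdim)        # (S, N*H, K/H)

# GeneralDotProduct handles mismatched query/key feature dims through W.
gdp = GeneralDotProduct(embed_dim, kdim=kdim, dropout=0.1)
attn_output, attn_weights = gdp(q, k, v)
print(attn_output.shape, attn_weights.shape)
# torch.Size([5, 16, 6]) torch.Size([16, 5, 7])

# ScaledDotProduct works exactly as before, only via the shared template.
sdp = ScaledDotProduct(dropout=0.1)
q2 = k2 = v2 = torch.randn(tgt_len, batch_heads, embed_dim)
attn_output2, attn_weights2 = sdp(q2, k2, v2)
print(attn_output2.shape, attn_weights2.shape)
# torch.Size([5, 16, 8]) torch.Size([16, 5, 5])
```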
@zhangguanheng66 Are you interested in such a PR?