from typing import List, Optional, Union

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .create_conv2d import create_conv2d
from .helpers import to_2tuple
from .pool2d_same import create_pool2d

class MultiQueryAttentionV2(nn.Module):
    """Multi Query Attention.

    Fast Transformer Decoding: One Write-Head is All You Need
    https://arxiv.org/pdf/1911.02150.pdf

    This is an accelerator-optimized version - removing multiple unnecessary
    tensor transposes by re-arranging indices according to the following rules: 1)
    contracted indices are at the end, 2) other indices have the same order in the
    input and output tensors.

    Compared to V1, this gives a 3x speed up.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: int = 64,
            value_dim: int = 64,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
    ):
        """Initializer."""
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.scale = key_dim ** -0.5

        self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim]))
        self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim]))
        self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim]))
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim]))
        self.proj_drop = nn.Dropout(proj_drop)

    def _reshape_input(self, t):
        """Reshapes a tensor to three dimensions [B, N, C], keeping batch and moving channels last."""
        s = t.shape
        # Propagate the shape statically where possible.
        #num = t.shape[1:-1].numel()
        #return t.reshape(s[0], num, s[-1])
        return t.reshape(s[0], s[1], -1).transpose(1, 2)

    def forward(self, x, m: Optional[torch.Tensor] = None):
        """Run layer computation."""
        s = x.shape
        # Self-attention when no separate memory tensor is provided.
        m = m if m is not None else x

        reshaped_x = self._reshape_input(x)
        reshaped_m = self._reshape_input(m)

        # Per-head query projection; a single key/value projection shared across heads.
        q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
        k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
        o = torch.einsum('bnhm,bmv->bnhv', attn, v)
        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
        result = self.proj_drop(result)
        return result.reshape(s)

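
# Usage sketch (illustrative, not part of the source): with the default dim_out == dim,
# MultiQueryAttentionV2 maps an NCHW tensor back to its input shape. The sizes below
# are assumptions for the example only.
#
#   mqa = MultiQueryAttentionV2(dim=64, num_heads=8, key_dim=16, value_dim=16)
#   x = torch.randn(2, 64, 14, 14)    # [B, C, H, W], C == dim
#   y = mqa(x)                        # -> torch.Size([2, 64, 14, 14])
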
class MultiQueryAttention2d(nn.Module):
    """Multi Query Attention with spatial downsampling.

    Two parameters are introduced for the spatial downsampling:
    1. kv_stride: downsampling factor on Key and Values only.
    2. query_strides: horizontal & vertical strides on Query only.

    This is an optimized version.
    1. Projections in the attention are explicitly written out as 1x1 Conv2d layers.
    2. Additional reshapes are introduced to bring up to a 3x speed up.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: Optional[int] = None,
            value_dim: Optional[int] = None,
            query_strides: int = 1,
            kv_stride: int = 1,
            dw_kernel_size: int = 3,
            dilation: int = 1,
            padding: Union[str, int, List[int]] = '',
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.BatchNorm2d,
            use_bias: bool = False,
    ):
        """Initializer.

        Args:
            num_heads: Number of attention heads.
            key_dim: Size of the attention key dimension.
            value_dim: Size of the attention value dimension.
            query_strides: Horizontal & vertical stride size for query only.
            kv_stride: Key and value stride size.
            dw_kernel_size: Spatial dimension of the depthwise kernel.
        """
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim or dim // num_heads
        self.value_dim = value_dim or dim // num_heads
        self.query_strides = to_2tuple(query_strides)
        self.kv_stride = kv_stride
        self.has_query_strides = any([s > 1 for s in self.query_strides])
        self.scale = self.key_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.drop = attn_drop

        self.query = nn.Sequential()
        if self.has_query_strides:
            # FIXME dilation
            self.query.add_module('down_pool', create_pool2d(
                'avg',
                kernel_size=self.query_strides,
                padding=padding,
            ))
            self.query.add_module('norm', norm_layer(dim))
        self.query.add_module('proj', create_conv2d(
            dim,
            self.num_heads * self.key_dim,
            kernel_size=1,
            bias=use_bias,
        ))

        self.key = nn.Sequential()
        if kv_stride > 1:
            self.key.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
            ))
            self.key.add_module('norm', norm_layer(dim))
        self.key.add_module('proj', create_conv2d(
            dim,
            self.key_dim,
            kernel_size=1,
            padding=padding,
            bias=use_bias,
        ))

        self.value = nn.Sequential()
        if kv_stride > 1:
            self.value.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
            ))
            self.value.add_module('norm', norm_layer(dim))
        self.value.add_module('proj', create_conv2d(
            dim,
            self.value_dim,
            kernel_size=1,
            bias=use_bias,
        ))

        self.attn_drop = nn.Dropout(attn_drop)

        self.output = nn.Sequential()
        if self.has_query_strides:
            self.output.add_module('upsample', nn.Upsample(scale_factor=self.query_strides, mode='bilinear', align_corners=False))
        self.output.add_module('proj', create_conv2d(
            self.value_dim * self.num_heads,
            dim_out,
            kernel_size=1,
            bias=use_bias,
        ))
        self.output.add_module('drop', nn.Dropout(proj_drop))

        self.einsum = False

    def _reshape_input(self, t: torch.Tensor):
        """Reshapes a tensor to three dimensions, keeping the batch and channels."""
        s = t.shape
        t = t.reshape(s[0], s[1], -1).transpose(1, 2)
        if self.einsum:
            return t
        else:
            return t.unsqueeze(1).contiguous()

    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
        """Reshapes projected query: [b, h x k, n, n] -> [b, n x n, h, k]."""
        s = t.shape
        t = t.reshape(s[0], num_heads, key_dim, -1)
        if self.einsum:
            return t.permute(0, 3, 1, 2).contiguous()
        else:
            return t.transpose(-1, -2).contiguous()

    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
        """Reshape output: [b, n x n, h, k] -> [b, h x k, n, n]."""
        s = t.shape
        feat_dim = s[-1] * num_heads
        if not self.einsum:
            t = t.transpose(1, 2)
        return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Run layer computation."""
        B, C, H, W = s = x.shape

        q = self.query(x)
        # q shape: [b, n x n, h, k] (einsum path) or [b, h, n x n, k] (sdpa / matmul path)
        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)

        k = self.key(x)
        # k shape: [b, m x m, k] (einsum path) or [b, 1, m x m, k]
        k = self._reshape_input(k)

        v = self.value(x)
        # v shape: [b, m x m, v] (einsum path) or [b, 1, m x m, v]
        v = self._reshape_input(v)

        # desired q shape: [b, n x n, h, k]
        # desired k shape: [b, m x m, k]
        # desired logits shape: [b, n x n, h, m x m]
        if self.einsum:
            attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            o = torch.einsum('blhp,bpk->blhk', attn, v)
        else:
            if self.fused_attn:
                o = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.
                )
            else:
                q = q * self.scale
                attn = q @ k.transpose(-1, -2)
                if attn_mask is not None:
                    # NOTE: assumes mask is float and in correct shape
                    attn = attn + attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                o = attn @ v

        # reshape o into [b, h x k, n, n]
        o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
        x = self.output(o)
        return x

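
# Usage sketch (illustrative, not part of the source): keys and values are computed on a
# 2x-downsampled grid while queries stay at full resolution. The sizes below are
# assumptions for the example only.
#
#   mqa2d = MultiQueryAttention2d(dim=64, num_heads=8, kv_stride=2)
#   x = torch.randn(2, 64, 16, 16)    # [B, C, H, W]
#   y = mqa2d(x)                      # -> torch.Size([2, 64, 16, 16])
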
class Attention2d(nn.Module):
    """Multi-head attention for 2D NCHW tensors."""
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.scale = num_heads ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape

        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
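
# Usage sketch (illustrative, not part of the source): a drop-in NCHW multi-head
# attention block; dim must be divisible by num_heads. The sizes below are
# assumptions for the example only.
#
#   attn = Attention2d(dim=128, num_heads=4)
#   x = torch.randn(2, 128, 16, 16)   # [B, C, H, W]
#   y = attn(x)                       # -> torch.Size([2, 128, 16, 16])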