Commit 5265984

Merge pull request #2196 from huggingface/mega_merge
Mega merge
2 parents 5dce710 + 7ccb10e commit 5265984

19 files changed: +2712, -315 lines

tests/test_models.py

Lines changed: 2 additions & 2 deletions
@@ -52,15 +52,15 @@
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
     'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
-    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit',
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*'
+    'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

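The entries in these lists are shell-style wildcard patterns matched against model names. As a rough illustration (not necessarily how test_models.py wires this up internally), such filters can be applied with Python's fnmatch:

from fnmatch import fnmatch

NON_STD_FILTERS = ['hiera_*', 'vitamin*']   # small subset of the list above, for illustration

def is_non_std(model_name: str) -> bool:
    # A model counts as "non-standard" if it matches any wildcard filter.
    return any(fnmatch(model_name, pat) for pat in NON_STD_FILTERS)

print(is_non_std('vitamin_small_224'))   # True, caught by 'vitamin*' (hypothetical model name)
print(is_non_std('resnet50'))            # False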

timm/layers/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,9 +1,10 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
+from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
-from .blur_pool import BlurPool2d
+from .blur_pool import BlurPool2d, create_aa
 from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
 from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, use_fused_attn, \
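
With the export added above, the new attention modules (and the create_aa helper) should be importable directly from timm.layers; a minimal sketch, assuming a timm build that includes this merge:

# Assumes a timm version containing this merge is installed.
from timm.layers import Attention2d, MultiQueryAttention2d, MultiQueryAttentionV2, create_aa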

timm/layers/attention2d.py

Lines changed: 337 additions & 0 deletions
@@ -0,0 +1,337 @@
from typing import List, Optional, Union

import torch
from torch import nn as nn
from torch.nn import functional as F

from .config import use_fused_attn
from .create_conv2d import create_conv2d
from .helpers import to_2tuple
from .pool2d_same import create_pool2d


class MultiQueryAttentionV2(nn.Module):
    """Multi Query Attention.

    Fast Transformer Decoding: One Write-Head is All You Need
    https://arxiv.org/pdf/1911.02150.pdf

    This is an accelerator-optimized version - removing multiple unnecessary
    tensor transposes by re-arranging indices according to the following rules: 1)
    contracted indices are at the end, 2) other indices have the same order in the
    input and output tensors.

    Compared to V1, this gives a 3x speed up.
    """

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: int = 64,
            value_dim: int = 64,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
    ):
        """Initializer."""
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.value_dim = value_dim
        self.scale = key_dim ** -0.5

        self.query_proj = nn.Parameter(torch.randn([self.num_heads, self.key_dim, dim]))
        self.key_proj = nn.Parameter(torch.randn([dim, self.key_dim]))
        self.value_proj = nn.Parameter(torch.randn([dim, self.value_dim]))
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Parameter(torch.randn([dim_out, self.num_heads, self.value_dim]))
        self.proj_drop = nn.Dropout(proj_drop)

    def _reshape_input(self, t):
        """Reshapes a tensor to three dimensions, keeping the first and last."""
        s = t.shape
        # Propagate the shape statically where possible.
        #num = t.shape[1:-1].numel()
        #return t.reshape(s[0], num, s[-1])
        return t.reshape(s[0], s[1], -1).transpose(1, 2)

    def forward(self, x, m: Optional[torch.Tensor] = None):
        """Run layer computation."""
        s = x.shape
        # Fall back to self-attention when no separate memory tensor is given.
        m = m if m is not None else x

        reshaped_x = self._reshape_input(x)
        reshaped_m = self._reshape_input(m)

        q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
        k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)

        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
        o = torch.einsum('bnhm,bmv->bnhv', attn, v)
        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
        result = self.proj_drop(result)
        return result.reshape(s)

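A rough usage sketch for the layer above (illustrative values, not from the commit): with dim_out left at its default of dim, the attention result is reshaped straight back to the input's shape.

import torch

mqa = MultiQueryAttentionV2(dim=64, num_heads=4, key_dim=16, value_dim=16)
x = torch.randn(2, 64, 8, 8)   # NCHW input; spatial dims are flattened internally
y = mqa(x)                     # no memory tensor m given, so keys/values come from x
print(y.shape)                 # torch.Size([2, 64, 8, 8])
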
class MultiQueryAttention2d(nn.Module):
    """Multi Query Attention with spatial downsampling.

    Two parameters are introduced for the spatial downsampling:
    1. kv_stride: downsampling factor on Key and Values only.
    2. query_strides: horizontal & vertical strides on Query only.

    This is an optimized version.
    1. Projections in Attention are explicitly written out as 1x1 Conv2D.
    2. Additional reshapes are introduced to bring up to a 3x speed up.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 8,
            key_dim: Optional[int] = None,
            value_dim: Optional[int] = None,
            query_strides: int = 1,
            kv_stride: int = 1,
            dw_kernel_size: int = 3,
            dilation: int = 1,
            padding: Union[str, int, List[int]] = '',
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.BatchNorm2d,
            use_bias: bool = False,
    ):
        """Initializer.

        Args:
            num_heads: Number of attention heads.
            key_dim: Size of the attention key dimension.
            value_dim: Size of the attention value dimension.
            query_strides: Horizontal & vertical stride size for query only.
            kv_stride: Key and value stride size.
            dw_kernel_size: Spatial dimension of the depthwise kernel.
        """
        super().__init__()
        dim_out = dim_out or dim
        self.num_heads = num_heads
        self.key_dim = key_dim or dim // num_heads
        self.value_dim = value_dim or dim // num_heads
        self.query_strides = to_2tuple(query_strides)
        self.kv_stride = kv_stride
        self.has_query_strides = any([s > 1 for s in self.query_strides])
        self.scale = self.key_dim ** -0.5
        self.fused_attn = use_fused_attn()
        self.drop = attn_drop

        self.query = nn.Sequential()
        if self.has_query_strides:
            # FIXME dilation
            self.query.add_module('down_pool', create_pool2d(
                'avg',
                kernel_size=self.query_strides,
                padding=padding,
            ))
            self.query.add_module('norm', norm_layer(dim))
        self.query.add_module('proj', create_conv2d(
            dim,
            self.num_heads * self.key_dim,
            kernel_size=1,
            bias=use_bias,
        ))

        self.key = nn.Sequential()
        if kv_stride > 1:
            self.key.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
            ))
            self.key.add_module('norm', norm_layer(dim))
        self.key.add_module('proj', create_conv2d(
            dim,
            self.key_dim,
            kernel_size=1,
            padding=padding,
            bias=use_bias,
        ))

        self.value = nn.Sequential()
        if kv_stride > 1:
            self.value.add_module('down_conv', create_conv2d(
                dim,
                dim,
                kernel_size=dw_kernel_size,
                stride=kv_stride,
                dilation=dilation,
                padding=padding,
                depthwise=True,
            ))
            self.value.add_module('norm', norm_layer(dim))
        self.value.add_module('proj', create_conv2d(
            dim,
            self.value_dim,
            kernel_size=1,
            bias=use_bias,
        ))

        self.attn_drop = nn.Dropout(attn_drop)

        self.output = nn.Sequential()
        if self.has_query_strides:
            # Upsample by the query stride factor so output spatial dims match the input.
            self.output.add_module('upsample', nn.Upsample(
                scale_factor=self.query_strides, mode='bilinear', align_corners=False))
        self.output.add_module('proj', create_conv2d(
            self.value_dim * self.num_heads,
            dim_out,
            kernel_size=1,
            bias=use_bias,
        ))
        self.output.add_module('drop', nn.Dropout(proj_drop))

        self.einsum = False

    def _reshape_input(self, t: torch.Tensor):
        """Reshapes a tensor to three dimensions, keeping the batch and channels."""
        s = t.shape
        t = t.reshape(s[0], s[1], -1).transpose(1, 2)
        if self.einsum:
            return t
        else:
            return t.unsqueeze(1).contiguous()

    def _reshape_projected_query(self, t: torch.Tensor, num_heads: int, key_dim: int):
        """Reshapes projected query: [b, n, n, h x k] -> [b, n x n, h, k]."""
        s = t.shape
        t = t.reshape(s[0], num_heads, key_dim, -1)
        if self.einsum:
            return t.permute(0, 3, 1, 2).contiguous()
        else:
            return t.transpose(-1, -2).contiguous()

    def _reshape_output(self, t: torch.Tensor, num_heads: int, h_px: int, w_px: int):
        """Reshape output: [b, n x n x h, k] -> [b, n, n, hk]."""
        s = t.shape
        feat_dim = s[-1] * num_heads
        if not self.einsum:
            t = t.transpose(1, 2)
        return t.reshape(s[0], h_px, w_px, feat_dim).permute(0, 3, 1, 2).contiguous()

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Run layer computation."""
        B, C, H, W = s = x.shape

        q = self.query(x)
        # desired q shape: [b, h, k, n x n] - [b, l, h, k]
        q = self._reshape_projected_query(q, self.num_heads, self.key_dim)

        k = self.key(x)
        # output shape of k: [b, k, p], p = m x m
        k = self._reshape_input(k)

        v = self.value(x)
        # output shape of v: [b, p, k], p = m x m
        v = self._reshape_input(v)

        # desired q shape: [b, n x n, h, k]
        # desired k shape: [b, m x m, k]
        # desired logits shape: [b, n x n, h, m x m]
        if self.einsum:
            attn = torch.einsum('blhk,bpk->blhp', q, k) * self.scale
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            o = torch.einsum('blhp,bpk->blhk', attn, v)
        else:
            if self.fused_attn:
                o = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=attn_mask,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                q = q * self.scale
                attn = q @ k.transpose(-1, -2)
                if attn_mask is not None:
                    # NOTE: assumes mask is float and in correct shape
                    attn = attn + attn_mask
                attn = attn.softmax(dim=-1)
                attn = self.attn_drop(attn)
                o = attn @ v

        # reshape o into [b, hk, n, n]
        o = self._reshape_output(o, self.num_heads, H // self.query_strides[0], W // self.query_strides[1])
        x = self.output(o)
        return x

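A rough usage sketch for the layer above (illustrative values, not from the commit), using the common configuration where only keys and values are spatially downsampled:

import torch

mqa2d = MultiQueryAttention2d(
    dim=96,
    num_heads=3,
    key_dim=32,
    value_dim=32,
    kv_stride=2,    # keys/values downsampled 2x by a strided depthwise conv; queries keep full resolution
)
x = torch.randn(1, 96, 32, 32)
print(mqa2d(x).shape)   # torch.Size([1, 96, 32, 32])
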
class Attention2d(nn.Module):
    """ multi-head attention for 2D NCHW tensors"""
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            dim_out: Optional[int] = None,
            num_heads: int = 32,
            bias: bool = True,
            expand_first: bool = False,
            head_first: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = num_heads
        self.dim_head = dim_attn // num_heads
        self.head_first = head_first
        self.scale = num_heads ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        B, C, H, W = x.shape

        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if attn_mask is not None:
                # NOTE: assumes mask is float and in correct shape
                attn = attn + attn_mask
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
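
A rough usage sketch for the layer above (illustrative values, not from the commit): plain multi-head attention applied directly to an NCHW feature map.

import torch

attn = Attention2d(dim=128, num_heads=4)
x = torch.randn(2, 128, 14, 14)
print(attn(x).shape)   # torch.Size([2, 128, 14, 14])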
