
Commit f8342a0

Merge pull request #2213 from huggingface/florence2
Fix #2212: map the Florence2 image tower to DaViT, with a few changes
2 parents e7b4ab6 + 02d0f27 · commit f8342a0

timm/models/davit.py (1 file changed): 143 additions, 18 deletions

@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
     def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()

-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
         self.act = nn.GELU() if act else nn.Identity()

     def forward(self, x: Tensor):
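
The change above is purely stylistic: the depthwise convolution inside ConvPosEnc keeps the same hyperparameters, only spelled out as keyword arguments. A minimal sketch of the equivalent layer and its shape-preserving behavior (values are illustrative):

import torch
import torch.nn as nn

dim, k = 96, 3
# Same depthwise conv ConvPosEnc builds: kernel k, stride 1, padding k // 2, groups=dim.
proj = nn.Conv2d(dim, dim, kernel_size=k, stride=1, padding=k // 2, groups=dim)
x = torch.randn(2, dim, 56, 56)
print(proj(x).shape)  # torch.Size([2, 96, 56, 56]) -- spatial size preserved
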
@@ -72,8 +79,9 @@ def __init__(

     def forward(self, x: Tensor):
         B, C, H, W = x.shape
-        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
-        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
+        pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
+        pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
+        x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         x = self.norm(x)
         return x
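
The two F.pad calls in Stem.forward are fused into one here; the arithmetic is unchanged. A small worked check of the padding math, assuming the default stem stride of 4 (input sizes are illustrative):

import torch
import torch.nn.functional as F

stride = (4, 4)
x = torch.randn(1, 3, 768, 770)                    # H divisible by 4, W is not
B, C, H, W = x.shape
pad_r = (stride[1] - W % stride[1]) % stride[1]    # (4 - 770 % 4) % 4 == 2
pad_b = (stride[0] - H % stride[0]) % stride[0]    # (4 - 768 % 4) % 4 == 0; outer % avoids a full-stride pad
x = F.pad(x, (0, pad_r, 0, pad_b))                 # pad right/bottom only
print(x.shape)                                     # torch.Size([1, 3, 768, 772])
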
@@ -84,30 +92,66 @@ def __init__(
             self,
             in_chs,
             out_chs,
+            kernel_size=3,
             norm_layer=LayerNorm2d,
     ):
         super().__init__()
         self.in_chs = in_chs
         self.out_chs = out_chs

         self.norm = norm_layer(in_chs)
+        self.even_k = kernel_size % 2 == 0
         self.conv = nn.Conv2d(
             in_chs,
             out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
             stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
         )

     def forward(self, x: Tensor):
         B, C, H, W = x.shape
         x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            pad_r = (k_w - W % k_w) % k_w
+            pad_b = (k_h - H % k_h) % k_h
+            x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         return x


+class ChannelAttentionV2(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
+        super().__init__()
+        self.groups = num_heads
+        self.head_dim = dim // num_heads
+        self.dynamic_scale = dynamic_scale
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        if self.dynamic_scale:
+            q = q * N ** -0.5
+        else:
+            q = q * self.head_dim ** -0.5
+        attn = q.transpose(-1, -2) @ k
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
 class ChannelAttention(nn.Module):

     def __init__(self, dim, num_heads=8, qkv_bias=False):
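
Two behavioral additions land in this hunk: Downsample now accepts an odd kernel_size (padding k // 2, no explicit F.pad), and ChannelAttentionV2 implements channel-wise attention, optionally scaling the query by N ** -0.5 (token count) instead of head_dim ** -0.5. A hedged sketch exercising both, assuming a timm checkout that contains this commit:

import torch
from timm.models.davit import ChannelAttentionV2, Downsample

# Odd kernel (the Florence2 setting, down_kernel_size=3): even_k is False, so the conv
# uses padding=1 with stride 2 and spatial dims are halved (rounding up) without F.pad.
down = Downsample(64, 128, kernel_size=3)
print(down(torch.randn(1, 64, 57, 57)).shape)    # torch.Size([1, 128, 29, 29])

# Channel attention: q/k/v are (B, heads, N, head_dim) and the attention matrix is
# head_dim x head_dim, i.e. attention runs across channels rather than tokens.
attn = ChannelAttentionV2(dim=128, num_heads=8)  # dynamic_scale=True by default
tokens = torch.randn(1, 49, 128)                 # (B, N, C) token layout
print(attn(tokens).shape)                        # torch.Size([1, 49, 128])
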
@@ -147,13 +191,19 @@ def __init__(
             norm_layer=nn.LayerNorm,
             ffn=True,
             cpe_act=False,
+            v2=False,
     ):
         super().__init__()

         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)

@@ -372,21 +422,24 @@ def __init__(
             attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
             qkv_bias=True,
             drop_path_rates=(0, 0),
             norm_layer=LayerNorm2d,
             norm_layer_cl=nn.LayerNorm,
             ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
     ):
         super().__init__()

         self.grad_checkpointing = False

         # downsample embedding layer at the beginning of each stage
         if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
         else:
             self.downsample = nn.Identity()

@@ -399,10 +452,11 @@ def __init__(
         '''
         stage_blocks = []
         for block_idx in range(depth):
+            from collections import OrderedDict
             dual_attention_block = []
             for attn_idx, attn_type in enumerate(attn_types):
                 if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -412,19 +466,23 @@ def __init__(
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
-                    ))
+                    )))
                 elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[block_idx],
                         norm_layer=norm_layer_cl,
                         ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
         self.blocks = nn.Sequential(*stage_blocks)

     @torch.jit.ignore
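
The named_blocks flag only changes module naming, not computation: wrapping the (name, module) pairs in an OrderedDict gives the sub-blocks the string names 'spatial_block' and 'channel_block', which is what the converted Florence2 state-dict keys refer to, while the default path keeps the old numeric children ('0', '1') so existing DaViT checkpoints still line up. A small illustration with placeholder modules standing in for SpatialBlock / ChannelBlock:

from collections import OrderedDict
import torch.nn as nn

pairs = [('spatial_block', nn.Identity()), ('channel_block', nn.Identity())]

named = nn.Sequential(OrderedDict(pairs))        # named_blocks=True path
plain = nn.Sequential(*[b[1] for b in pairs])    # default path
print([n for n, _ in named.named_children()])    # ['spatial_block', 'channel_block']
print([n for n, _ in plain.named_children()])    # ['0', '1']
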
@@ -473,6 +531,9 @@ def __init__(
             attn_types=('spatial', 'channel'),
             ffn=True,
             cpe_act=False,
+            down_kernel_size=2,
+            channel_attn_v2=False,
+            named_blocks=False,
             drop_rate=0.,
             drop_path_rate=0.,
             num_classes=1000,
@@ -512,6 +573,9 @@ def __init__(
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
                 cpe_act=cpe_act,
+                down_kernel_size=down_kernel_size,
+                channel_attn_v2=channel_attn_v2,
+                named_blocks=named_blocks,
             )
             in_chs = out_chs
             stages.append(stage)
@@ -589,6 +653,34 @@ def forward(self, x):
         return x


+def _convert_florence2(state_dict, model, prefix='vision_tower.'):
+    import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k.startswith(prefix):
+            k = k.replace(prefix, '')
+        else:
+            continue
+        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
+        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
+        #k = k.replace('head.', 'head.fc.')
+        #k = k.replace('norms.', 'head.norm.')
+        k = k.replace('window_attn.norm.', 'norm1.')
+        k = k.replace('window_attn.fn.', 'attn.')
+        k = k.replace('channel_attn.norm.', 'norm1.')
+        k = k.replace('channel_attn.fn.', 'attn.')
+        k = k.replace('ffn.norm.', 'norm2.')
+        k = k.replace('ffn.fn.net.', 'mlp.')
+        k = k.replace('conv1.fn.dw', 'cpe1.proj')
+        k = k.replace('conv2.fn.dw', 'cpe2.proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.fc.weight' in state_dict:
@@ -597,6 +689,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']

+    if 'vision_tower.convs.0.proj.weight' in state_dict:
+        return _convert_florence2(state_dict, model)
+
     import re
     out_dict = {}
     for k, v in state_dict.items():
@@ -615,13 +710,17 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-
+    strict = True
+    if variant.endswith('_fl'):
+        # FIXME cleaner approach to missing head norm?
+        strict = False
     model = build_model_with_cfg(
         DaVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=strict,
         **kwargs)

     return model
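
For the '_fl' variants the converted Florence2 tower does not provide the classifier head norm (see the FIXME above), so the loader is told not to require an exact key match. A hedged, generic illustration of what pretrained_strict=False amounts to at the torch level (toy module, not the DaViT head):

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))
sd = {'0.weight': model[0].weight.detach(), '0.bias': model[0].bias.detach()}  # norm params absent
missing, unexpected = model.load_state_dict(sd, strict=False)
print(missing)      # ['1.weight', '1.bias'] -- reported instead of raising
print(unexpected)   # []
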
@@ -650,6 +749,12 @@ def _cfg(url='', **kwargs):
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
+    'davit_base_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-base',
+        num_classes=0, input_size=(3, 768, 768)),
+    'davit_huge_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-large',
+        num_classes=0, input_size=(3, 768, 768)),
 })


@@ -687,3 +792,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
 def davit_giant(pretrained=False, **kwargs) -> DaVit:
     model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
     return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
+    # NOTE: huge image tower used in 'large' Florence2 model
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
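
A hedged usage sketch for the new variants, assuming a timm build that includes this commit; pretrained=True would pull weights from the microsoft/Florence-2-* Hub repos per the configs above, which is not exercised here:

import timm
import torch

model = timm.create_model('davit_base_fl', pretrained=False, num_classes=0)
model.eval()

x = torch.randn(1, 3, 768, 768)           # matches the 768 x 768 input_size in the pretrained cfg
with torch.no_grad():
    feats = model.forward_features(x)
print(feats.shape)                        # expected (1, 1024, 24, 24): stride-32 features at embed_dim 1024
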
