@@ -34,7 +34,14 @@ class ConvPosEnc(nn.Module):
     def __init__(self, dim: int, k: int = 3, act: bool = False):
         super(ConvPosEnc, self).__init__()
 
-        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+        self.proj = nn.Conv2d(
+            dim,
+            dim,
+            kernel_size=k,
+            stride=1,
+            padding=k // 2,
+            groups=dim,
+        )
         self.act = nn.GELU() if act else nn.Identity()
 
     def forward(self, x: Tensor):
@@ -72,8 +79,9 @@ def __init__(
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
-        x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
-        x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
+        pad_r = (self.stride[1] - W % self.stride[1]) % self.stride[1]
+        pad_b = (self.stride[0] - H % self.stride[0]) % self.stride[0]
+        x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         x = self.norm(x)
         return x
@@ -84,30 +92,66 @@ def __init__(
             self,
             in_chs,
             out_chs,
+            kernel_size=3,
             norm_layer=LayerNorm2d,
     ):
         super().__init__()
         self.in_chs = in_chs
         self.out_chs = out_chs
 
         self.norm = norm_layer(in_chs)
+        self.even_k = kernel_size % 2 == 0
         self.conv = nn.Conv2d(
             in_chs,
             out_chs,
-            kernel_size=2,
+            kernel_size=kernel_size,
             stride=2,
-            padding=0,
+            padding=0 if self.even_k else kernel_size // 2,
         )
 
     def forward(self, x: Tensor):
         B, C, H, W = x.shape
         x = self.norm(x)
-        x = F.pad(x, (0, (2 - W % 2) % 2))
-        x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
+        if self.even_k:
+            k_h, k_w = self.conv.kernel_size
+            pad_r = (k_w - W % k_w) % k_w
+            pad_b = (k_h - H % k_h) % k_h
+            x = F.pad(x, (0, pad_r, 0, pad_b))
         x = self.conv(x)
         return x
 
 
+class ChannelAttentionV2(nn.Module):
+
+    def __init__(self, dim, num_heads=8, qkv_bias=True, dynamic_scale=True):
+        super().__init__()
+        self.groups = num_heads
+        self.head_dim = dim // num_heads
+        self.dynamic_scale = dynamic_scale
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.groups, C // self.groups).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+
+        if self.dynamic_scale:
+            q = q * N ** -0.5
+        else:
+            q = q * self.head_dim ** -0.5
+        attn = q.transpose(-1, -2) @ k
+        attn = attn.softmax(dim=-1)
+        x = (attn @ v.transpose(-1, -2)).transpose(-1, -2)
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        return x
+
+
+
 class ChannelAttention(nn.Module):
 
     def __init__(self, dim, num_heads=8, qkv_bias=False):
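The ChannelAttentionV2 block added above is a transposed self-attention: the attention map is formed across channels (head_dim x head_dim per head) rather than across the N tokens, and with dynamic_scale=True the queries are scaled by N ** -0.5 instead of head_dim ** -0.5. A minimal shape-check sketch, assuming torch and the class definition from this diff; the dim/head/token counts below are arbitrary values chosen for illustration:

import torch

attn = ChannelAttentionV2(dim=96, num_heads=8, dynamic_scale=True)
x = torch.randn(2, 14 * 14, 96)   # (B, N, C) token sequence
y = attn(x)                       # per-head attention is head_dim x head_dim, independent of N
print(y.shape)                    # torch.Size([2, 196, 96])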
@@ -147,13 +191,19 @@ def __init__(
             norm_layer=nn.LayerNorm,
             ffn=True,
             cpe_act=False,
+            v2=False,
     ):
         super().__init__()
 
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.ffn = ffn
         self.norm1 = norm_layer(dim)
-        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        attn_layer = ChannelAttentionV2 if v2 else ChannelAttention
+        self.attn = attn_layer(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+        )
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
 
@@ -372,21 +422,24 @@ def __init__(
             attn_types=('spatial', 'channel'),
             num_heads=3,
             window_size=7,
-            mlp_ratio=4,
+            mlp_ratio=4.,
             qkv_bias=True,
             drop_path_rates=(0, 0),
             norm_layer=LayerNorm2d,
             norm_layer_cl=nn.LayerNorm,
             ffn=True,
-            cpe_act=False
+            cpe_act=False,
+            down_kernel_size=2,
+            named_blocks=False,
+            channel_attn_v2=False,
     ):
         super().__init__()
 
         self.grad_checkpointing = False
 
         # downsample embedding layer at the beginning of each stage
         if downsample:
-            self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
+            self.downsample = Downsample(in_chs, out_chs, kernel_size=down_kernel_size, norm_layer=norm_layer)
         else:
             self.downsample = nn.Identity()
 
@@ -399,10 +452,11 @@ def __init__(
         '''
         stage_blocks = []
         for block_idx in range(depth):
+            from collections import OrderedDict
            dual_attention_block = []
             for attn_idx, attn_type in enumerate(attn_types):
                 if attn_type == 'spatial':
-                    dual_attention_block.append(SpatialBlock(
+                    dual_attention_block.append(('spatial_block', SpatialBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
@@ -412,19 +466,23 @@ def __init__(
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
-                    ))
+                    )))
                 elif attn_type == 'channel':
-                    dual_attention_block.append(ChannelBlock(
+                    dual_attention_block.append(('channel_block', ChannelBlock(
                         dim=out_chs,
                         num_heads=num_heads,
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[block_idx],
                         norm_layer=norm_layer_cl,
                         ffn=ffn,
-                        cpe_act=cpe_act
-                    ))
-            stage_blocks.append(nn.Sequential(*dual_attention_block))
+                        cpe_act=cpe_act,
+                        v2=channel_attn_v2,
+                    )))
+            if named_blocks:
+                stage_blocks.append(nn.Sequential(OrderedDict(dual_attention_block)))
+            else:
+                stage_blocks.append(nn.Sequential(*[b[1] for b in dual_attention_block]))
         self.blocks = nn.Sequential(*stage_blocks)
 
     @torch.jit.ignore
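With named_blocks=True each dual-attention pair is wrapped in an OrderedDict, so its submodules appear in the state_dict as 'spatial_block.*' / 'channel_block.*' instead of '0.*' / '1.*', which is what lets the remapped Florence-2 keys further down in this diff line up with the timm module tree. A toy sketch of the difference, using nn.Linear stand-ins rather than the real blocks:

from collections import OrderedDict
import torch.nn as nn

pairs = [('spatial_block', nn.Linear(4, 4)), ('channel_block', nn.Linear(4, 4))]
named = nn.Sequential(OrderedDict(pairs))
plain = nn.Sequential(*[b[1] for b in pairs])
print(list(named.state_dict())[0])   # 'spatial_block.weight'
print(list(plain.state_dict())[0])   # '0.weight'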
@@ -473,6 +531,9 @@ def __init__(
             attn_types=('spatial', 'channel'),
             ffn=True,
             cpe_act=False,
+            down_kernel_size=2,
+            channel_attn_v2=False,
+            named_blocks=False,
             drop_rate=0.,
             drop_path_rate=0.,
             num_classes=1000,
@@ -512,6 +573,9 @@ def __init__(
                 norm_layer_cl=norm_layer_cl,
                 ffn=ffn,
                 cpe_act=cpe_act,
+                down_kernel_size=down_kernel_size,
+                channel_attn_v2=channel_attn_v2,
+                named_blocks=named_blocks,
             )
             in_chs = out_chs
             stages.append(stage)
@@ -589,6 +653,34 @@ def forward(self, x):
         return x
 
 
+def _convert_florence2(state_dict, model, prefix='vision_tower.'):
+    import re
+    out_dict = {}
+
+    for k, v in state_dict.items():
+        if k.startswith(prefix):
+            k = k.replace(prefix, '')
+        else:
+            continue
+        k = re.sub(r'convs.([0-9]+)', r'stages.\1.downsample', k)
+        k = re.sub(r'blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('downsample.proj', 'downsample.conv')
+        k = k.replace('stages.0.downsample', 'stem')
+        #k = k.replace('head.', 'head.fc.')
+        #k = k.replace('norms.', 'head.norm.')
+        k = k.replace('window_attn.norm.', 'norm1.')
+        k = k.replace('window_attn.fn.', 'attn.')
+        k = k.replace('channel_attn.norm.', 'norm1.')
+        k = k.replace('channel_attn.fn.', 'attn.')
+        k = k.replace('ffn.norm.', 'norm2.')
+        k = k.replace('ffn.fn.net.', 'mlp.')
+        k = k.replace('conv1.fn.dw', 'cpe1.proj')
+        k = k.replace('conv2.fn.dw', 'cpe2.proj')
+        out_dict[k] = v
+
+    return out_dict
+
+
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
     if 'head.fc.weight' in state_dict:
@@ -597,6 +689,9 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
 
+    if 'vision_tower.convs.0.proj.weight' in state_dict:
+        return _convert_florence2(state_dict, model)
+
     import re
     out_dict = {}
     for k, v in state_dict.items():
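As a concrete illustration of the remapping, the detection key checked above converts as follows. The model argument is unused by _convert_florence2, so None suffices for a dry run, and the zero tensor is just a placeholder value:

import torch

sd = {'vision_tower.convs.0.proj.weight': torch.zeros(1)}
out = _convert_florence2(sd, model=None)
print(list(out))   # ['stem.conv.weight']
# 'convs.0' -> 'stages.0.downsample', '.proj' -> '.conv', then 'stages.0.downsample' -> 'stem'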
@@ -615,13 +710,17 @@ def checkpoint_filter_fn(state_dict, model):
 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
-
+    strict = True
+    if variant.endswith('_fl'):
+        # FIXME cleaner approach to missing head norm?
+        strict = False
     model = build_model_with_cfg(
         DaVit,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=strict,
         **kwargs)
 
     return model
@@ -650,6 +749,12 @@ def _cfg(url='', **kwargs):
     'davit_large': _cfg(),
     'davit_huge': _cfg(),
     'davit_giant': _cfg(),
+    'davit_base_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-base',
+        num_classes=0, input_size=(3, 768, 768)),
+    'davit_huge_fl.msft_florence2': _cfg(
+        hf_hub_id='microsoft/Florence-2-large',
+        num_classes=0, input_size=(3, 768, 768)),
 })
 
 
@@ -687,3 +792,23 @@ def davit_huge(pretrained=False, **kwargs) -> DaVit:
 def davit_giant(pretrained=False, **kwargs) -> DaVit:
     model_args = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96))
     return _create_davit('davit_giant', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+
+@register_model
+def davit_base_fl(pretrained=False, **kwargs) -> DaVit:
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_base_fl', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def davit_huge_fl(pretrained=False, **kwargs) -> DaVit:
+    # NOTE: huge image tower used in 'large' Florence2 model
+    model_args = dict(
+        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64),
+        window_size=12, down_kernel_size=3, channel_attn_v2=True, named_blocks=True,
+    )
+    return _create_davit('davit_huge_fl', pretrained=pretrained, **dict(model_args, **kwargs))
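A hedged usage sketch for the new registrations, assuming a timm install that includes this change: with pretrained=True the weights would come from the microsoft/Florence-2-* hub repos per the cfgs above, while pretrained=False builds a randomly initialized tower.

import timm
import torch

model = timm.create_model('davit_base_fl', pretrained=False, num_classes=0)
x = torch.randn(1, 3, 768, 768)   # default cfg input size for the _fl variants
feats = model(x)
print(feats.shape)                # pooled features, expected torch.Size([1, 1024]) for the base tower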