@@ -674,10 +674,19 @@ def forward_sep(self, x, tokenlabeling=False, H=0, W=0):
             C = x.shape[-1]
             if i == 0:
                 x_down = self.pool(x.reshape(B, H, W, C).permute(0, 3, 1, 2))
+                x_down_H, x_down_W = x_down.shape[2:]
                 x_down = x_down.view(B, C, -1).permute(0, 2, 1)
                 kv = self.kv(x_down).view(B, -1, 2, C).permute(2, 0, 1, 3)
                 k, v = kv[0], kv[1]  # B, N, C
-                attn = (self.q_embed(self.q) @ k.transpose(-1, -2)) * self.scale  # q: 1, M, C, k: B, N, C -> B, M, N
+
+                if x_down.shape[1] == self.q.shape[0]:
+                    self_q = self.q
+                else:
+                    self_q = self.q.reshape(8, 8, -1).permute(2, 0, 1)
+                    self_q = F.interpolate(self_q.unsqueeze(0), size=(x_down_H, x_down_W), mode='bicubic').squeeze(0).permute(1, 2, 0)
+                    self_q = self_q.reshape(-1, self_q.shape[-1])
+
+                attn = (self.q_embed(self_q) @ k.transpose(-1, -2)) * self.scale  # q: 1, M, C, k: B, N, C -> B, M, N
                 attn = attn.softmax(-1)  # B, M, N
                 semantics = attn @ v  # B, M, C
                 semantics = semantics.view(B, -1, C)
@@ -803,7 +812,7 @@ def dualvit_s(pretrained=False, **kwargs):
         stem_hidden_dim = 32,
         embed_dims = [64, 128, 320, 448],
         num_heads = [2, 4, 10, 14],
-        mlp_ratios = [8, 8, 4, 3],
+        mlp_ratios = [8, 8, 4, 3, 2],
         norm_layer = partial(nn.LayerNorm, eps=1e-6),
         depths = [3, 4, 6, 3],
         **kwargs)
@@ -816,7 +825,7 @@ def dualvit_b(pretrained=False, **kwargs):
         stem_hidden_dim = 64,
         embed_dims = [64, 128, 320, 512],
         num_heads = [2, 4, 10, 16],
-        mlp_ratios = [8, 8, 4, 3],
+        mlp_ratios = [8, 8, 4, 3, 2],
         norm_layer = partial(nn.LayerNorm, eps=1e-6),
         depths = [3, 4, 15, 3],
         **kwargs)
@@ -829,7 +838,7 @@ def dualvit_l(pretrained=False, **kwargs):
         stem_hidden_dim = 64,
         embed_dims = [96, 192, 384, 512],
         num_heads = [3, 6, 12, 16],
-        mlp_ratios = [8, 8, 4, 3],
+        mlp_ratios = [8, 8, 4, 3, 2],
         norm_layer = partial(nn.LayerNorm, eps=1e-6),
         depths = [3, 6, 21, 3],
         **kwargs)