Skip to content

Commit 7dbda29

Browse files
authored
Update dualvit.py
1 parent add9bec commit 7dbda29

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

classification/dualvit.py

+13 -4
Original file line number | Diff line number | Diff line change
@@ -674,10 +674,19 @@ def forward_sep(self, x, tokenlabeling=False, H=0, W=0):
674674
C = x.shape[-1]
675675
if i == 0:
676676
x_down = self.pool(x.reshape(B, H, W, C).permute(0, 3, 1, 2))
677+
x_down_H, x_down_W = x_down.shape[2:]
677678
x_down = x_down.view(B, C, -1).permute(0, 2, 1)
678679
kv = self.kv(x_down).view(B, -1, 2, C).permute(2, 0, 1, 3)
679680
k, v = kv[0], kv[1] # B, N, C
680-
attn = (self.q_embed(self.q) @ k.transpose(-1, -2)) * self.scale # q: 1, M, C, k: B, N, C -> B, M, N
681+
682+
if x_down.shape[1] == self.q.shape[0]:
683+
self_q = self.q
684+
else:
685+
self_q = self.q.reshape(8, 8, -1).permute(2, 0, 1)
686+
self_q = F.interpolate(self_q.unsqueeze(0), size=(x_down_H, x_down_W), mode='bicubic').squeeze(0).permute(1, 2, 0)
687+
self_q = self_q.reshape(-1, self_q.shape[-1])
688+
689+
attn = (self.q_embed(self_q) @ k.transpose(-1, -2)) * self.scale # q: 1, M, C, k: B, N, C -> B, M, N
681690
attn = attn.softmax(-1) # B, M, N
682691
semantics = attn @ v # B, M, C
683692
semantics = semantics.view(B, -1, C)
@@ -803,7 +812,7 @@ def dualvit_s(pretrained=False, **kwargs):
803812
stem_hidden_dim = 32,
804813
embed_dims = [64, 128, 320, 448],
805814
num_heads = [2, 4, 10, 14],
806-
mlp_ratios = [8, 8, 4, 3],
815+
mlp_ratios = [8, 8, 4, 3, 2],
807816
norm_layer = partial(nn.LayerNorm, eps=1e-6),
808817
depths = [3, 4, 6, 3],
809818
**kwargs)
@@ -816,7 +825,7 @@ def dualvit_b(pretrained=False, **kwargs):
816825
stem_hidden_dim = 64,
817826
embed_dims = [64, 128, 320, 512],
818827
num_heads = [2, 4, 10, 16],
819-
mlp_ratios = [8, 8, 4, 3],
828+
mlp_ratios = [8, 8, 4, 3, 2],
820829
norm_layer = partial(nn.LayerNorm, eps=1e-6),
821830
depths = [3, 4, 15, 3],
822831
**kwargs)
@@ -829,7 +838,7 @@ def dualvit_l(pretrained=False, **kwargs):
829838
stem_hidden_dim = 64,
830839
embed_dims = [96, 192, 384, 512],
831840
num_heads = [3, 6, 12, 16],
832-
mlp_ratios = [8, 8, 4, 3],
841+
mlp_ratios = [8, 8, 4, 3, 2],
833842
norm_layer = partial(nn.LayerNorm, eps=1e-6),
834843
depths = [3, 6, 21, 3],
835844
**kwargs)

0 commit comments

Comments
 (0)