@@ -21,7 +21,7 @@ def make_head_layer(cnv_dim, curr_dim, out_dim, head_name=None):
         # nn.BatchNorm2d(curr_dim, eps=1e-3, momentum=0.01),
         nn.ReLU(inplace=True),
         nn.Conv2d(curr_dim, out_dim, kernel_size=3, stride=1, padding=1),
-    )  # kernel=1, padding=0, bias=True
+    )

     for l in fc.modules():
         if isinstance(l, nn.Conv2d):
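The full `make_head_layer` body is not shown in this hunk. For context, a minimal sketch of what such a conv head plausibly looks like; the first conv layer and the init scheme inside the `isinstance` branch are assumptions, since the diff cuts off before them:

```python
import torch.nn as nn

def make_head_layer(cnv_dim, curr_dim, out_dim, head_name=None):
    fc = nn.Sequential(
        nn.Conv2d(cnv_dim, curr_dim, kernel_size=3, padding=1, bias=True),  # assumed first layer
        nn.ReLU(inplace=True),
        nn.Conv2d(curr_dim, out_dim, kernel_size=3, stride=1, padding=1),
    )
    # Assumed init scheme: small-std weights, zero bias for every conv.
    for l in fc.modules():
        if isinstance(l, nn.Conv2d):
            nn.init.normal_(l.weight, std=0.01)
            if l.bias is not None:
                nn.init.constant_(l.bias, 0)
    return fc
```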
@@ -72,7 +72,6 @@ def forward(self, x0, x1, x0_mask=None, x1_mask=None, flag=False):
         if x0_mask != None and x1_mask != None:
             x0_mask, x1_mask = x0_mask.flatten(-2), x1_mask.flatten(-2)

-        save_feat = []
         if flag is False:
             for i, (layer, name) in enumerate(zip(self.layers, self.layer_names)):
                 if name == "self":
@@ -85,9 +84,6 @@ def forward(self, x0, x1, x0_mask=None, x1_mask=None, flag=False):
                     raise KeyError
                 x0 = layer(x0, src0, x0_mask, src0_mask)
                 x1 = layer(x1, src1, x1_mask, src1_mask)
-                if i == 1:  # i==len(self.layer_names)//2-1:
-                    # print(i, len(self.layer_names))
-                    save_feat.append((x0, x1))
         elif flag == 1:  # origin
             for layer, name in zip(self.layers, self.layer_names):
                 if name == "self":
@@ -109,11 +105,7 @@ def forward(self, x0, x1, x0_mask=None, x1_mask=None, flag=False):
             else:
                 raise KeyError

-        # return feat0, feat1
-        if len(save_feat) > 0:
-            return x0, x1, save_feat
-        else:
-            return x0, x1
+        return x0, x1


 class SegmentationModule(nn.Module):
@@ -129,22 +121,15 @@ def __init__(self, d_model, num_query):
     def forward(self, x, hs, mask=None):
         # x:[n, 256, h, w] hs:[n, num_q, 256]

-        # TODO: BN
         if mask is not None:
-            # hs = self.encoderlayer(hs, x3_flatten, None, mask_flatten)
             attn_mask = torch.einsum("mqc,mchw->mqhw", hs, x)
-            # attn_mask = self.bn(attn_mask)
-            # attn_mask = attn_mask * self.gamma
             attn_mask = attn_mask.sigmoid() * mask.unsqueeze(1)
             classification = self.block(x * attn_mask + x).sigmoid().squeeze(1) * mask
         else:
-            # hs = self.encoderlayer(hs, x3_flatten)
             attn_mask = torch.einsum("mqc,mchw->mqhw", hs, x)
-            # attn_mask = self.bn(attn_mask)
-            # attn_mask = attn_mask * self.gamma
             attn_mask = attn_mask.sigmoid()
             classification = self.block(x * attn_mask + x).sigmoid().squeeze(1)
-        return classification  # , attn_mask # , mask_feat
+        return classification


 class FICAS(nn.Module):
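The einsum in `SegmentationModule.forward` is a per-query dot product between the query embeddings and every spatial feature: `hs` is `[n, num_q, c]` and `x` is `[n, c, h, w]`, so `"mqc,mchw->mqhw"` yields one `[h, w]` response map per query. A self-contained shape check, with dummy sizes assumed:

```python
import torch

n, q, c, h, w = 2, 4, 256, 32, 32   # assumed dummy sizes
hs = torch.randn(n, q, c)           # query embeddings
x = torch.randn(n, c, h, w)         # feature map

# Dot product of each query with the feature vector at every pixel.
attn_mask = torch.einsum("mqc,mchw->mqhw", hs, x)
assert attn_mask.shape == (n, q, h, w)

# Equivalent batched-matmul formulation.
ref = (hs @ x.flatten(-2)).view(n, q, h, w)
assert torch.allclose(attn_mask, ref, atol=1e-4)
```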
@@ -166,7 +151,7 @@ def __init__(self, layer_num=4, d_model=256):
         self.layer_names1 = [
             "self",
             "cross",
-        ]  # ['self', 'cross', 'cross'] # ['self', 'cross'] origin for eccv
+        ]
         self.layers1 = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names1))]
         )
@@ -186,7 +171,7 @@ def __init__(self, layer_num=4, d_model=256):
         self.layer_names3 = [
             "self",
             "cross",
-        ]  # ['self', 'cross', 'cross'] # ['self', 'cross'] origin for eccv
+        ]
         self.layers3 = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names3))]
         )
@@ -216,8 +201,7 @@ def transformer(self, x0, x1, x0_mask, x1_mask, layer_name, layer):
             and src1_mask is not None
             and not self.training
             and 0
-        ):  # \
-        # and layer_name == 'self' and 0:
+        ):
             temp_x = layer(
                 torch.cat([x0, x1], dim=0),
                 torch.cat([src0, src1], dim=0),
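The `and 0` keeps this branch permanently disabled; when enabled it would push both images through a single attention call instead of two. For any layer that treats batch elements independently the two forms are equivalent, as this sketch with a hypothetical stand-in `layer` illustrates:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the attention layer; any module that treats
# batch elements independently behaves the same way here.
layer = nn.Linear(256, 256)
x0, x1 = torch.randn(2, 100, 256), torch.randn(2, 100, 256)

# Two separate calls vs. one batched call over the concatenated inputs.
sep0, sep1 = layer(x0), layer(x1)
batched = layer(torch.cat([x0, x1], dim=0))
cat0, cat1 = batched.chunk(2, dim=0)

assert torch.allclose(sep0, cat0, atol=1e-5)
assert torch.allclose(sep1, cat1, atol=1e-5)
```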
@@ -252,8 +236,7 @@ def feature_interaction(self, x0, x1, x0_mask=None, x1_mask=None):
         feature_embed1 = self.feature_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
         tgt0 = torch.zeros_like(feature_embed0)
         tgt1 = torch.zeros_like(feature_embed1)
-        # hs0 = self.decoder(tgt0, x0, tgt_mask=None, memory_mask=x0_mask)
-        # hs1 = self.decoder(tgt1, x1, tgt_mask=None, memory_mask=x1_mask)
+
         if (
             0
         ):  # x0.shape==x1.shape and x0_mask is not None and x0_mask.shape==x1_mask.shape:
@@ -331,10 +314,7 @@ def forward(self, x0, x1, x0_mask=None, x1_mask=None, use_cas=True):
         out0, out1, hs0, hs1, x0_mid, x1_mid = self.feature_interaction(
             x0, x1, x0_mask, x1_mask
         )
-        # out0 = rearrange(out0, 'n (h w) c -> n c h w',
-        #                  h=h0, w=w0).contiguous()
-        # out1 = rearrange(out1, 'n (h w) c -> n c h w',
-        #                  h=h1, w=w1).contiguous()
+
         if use_cas:
             x0_mid = rearrange(x0_mid, "n (h w) c -> n c h w", h=h0, w=w0).contiguous()
             x1_mid = rearrange(x1_mid, "n (h w) c -> n c h w", h=h1, w=w1).contiguous()
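The `rearrange` here just undoes the token flattening: a `[n, h*w, c]
]` transformer sequence is folded back into a `[n, c, h, w]` feature map. Equivalent in plain PyTorch, with dummy sizes assumed:

```python
import torch
from einops import rearrange

n, h, w, c = 2, 32, 32, 256          # assumed dummy sizes
tokens = torch.randn(n, h * w, c)    # flattened transformer tokens

grid = rearrange(tokens, "n (h w) c -> n c h w", h=h, w=w).contiguous()

# Same thing without einops: reshape to the grid, then move channels first.
ref = tokens.view(n, h, w, c).permute(0, 3, 1, 2).contiguous()
assert torch.equal(grid, ref)
```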