from ..distances import CosineSimilarity
from ..reducers import AvgNonZeroReducer
from ..utils import common_functions as c_f
- from ..utils import multilabel_loss_and_miner_utils as mlmu
from ..utils import loss_and_miner_utils as lmu
+ from ..utils.module_with_records import ModuleWithRecords
from .generic_pair_loss import GenericPairLoss
-
+ from .base_loss_wrapper import BaseLossWrapper

# adapted from https://github.com/HobbitLong/SupContrast
+ # modified for multi-supcon
class MultiSupConLoss(GenericPairLoss):
-     def __init__(self, num_classes, temperature=0.1, **kwargs):
+     def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
        super().__init__(mat_based_loss=True, **kwargs)
        self.temperature = temperature
        self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
        self.num_classes = num_classes
+         self.threshold = threshold

-     def _compute_loss(self, mat, pos_mask, neg_mask):
+     def _compute_loss(self, mat, pos_mask, neg_mask, multi_val):
        if pos_mask.bool().any() and neg_mask.bool().any():
            # if dealing with actual distances, use negative distances
            if not self.distance.is_inverted:
@@ -29,7 +31,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask):
                mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
            )
            log_prob = mat - denominator
-             mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
+             mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / (
                pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
            )

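# Editor's note: the modified line above weights each positive pair's
# log-probability by its jaccard similarity (multi_val) before averaging over
# the anchor's positives. With single-label (one-hot) data, the jaccard
# similarity of every same-class pair is 1, so this reduces to standard SupCon.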
@@ -48,16 +50,22 @@ def get_default_reducer(self):
    def get_default_distance(self):
        return CosineSimilarity()

+     # ==== class methods below are overridden for adaptability to multi-supcon ====
+
    def mat_based_loss(self, mat, indices_tuple):
-         a1, p, a2, n = indices_tuple
+         a1, p, a2, n, jaccard_mat = indices_tuple
        pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
        pos_mask[a1, p] = 1
        neg_mask[a2, n] = 1
-         return self._compute_loss(mat, pos_mask, neg_mask)
+         return self._compute_loss(mat, pos_mask, neg_mask, jaccard_mat)

    def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
        c_f.labels_or_indices_tuple_required(labels, indices_tuple)
-         indices_tuple = mlmu.convert_to_pairs(indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device)
+         indices_tuple = convert_to_pairs(
+             indices_tuple,
+             labels,
+             ref_labels,
+             threshold=self.threshold)
        if all(len(x) <= 1 for x in indices_tuple):
            return self.zero_losses()
        mat = self.distance(embeddings, ref_emb)
@@ -76,11 +84,276 @@ def forward(
        Returns: the loss
        """
        self.reset_stats()
-         mlmu.check_shapes_multilabels(embeddings, labels)
-         ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels)
+         check_shapes_multilabels(embeddings, labels)
+         ref_emb, ref_labels = set_ref_emb(embeddings, labels, ref_emb, ref_labels)
        loss_dict = self.compute_loss(
            embeddings, labels, indices_tuple, ref_emb, ref_labels
        )
        self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
        return self.reducer(loss_dict, embeddings, labels)

+     # =========================================================================
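# --- Editor's note: illustrative usage sketch, not part of this commit. ---
# Assumes multi-hot labels of shape (batch_size, num_classes); `encoder` is a
# hypothetical model producing (batch_size, embedding_size) embeddings.
#
#     import torch
#
#     loss_func = MultiSupConLoss(num_classes=10, temperature=0.1, threshold=0.3)
#     embeddings = encoder(images)               # e.g. shape (32, 128)
#     labels = torch.randint(0, 2, (32, 10))     # multi-hot, one row per sample
#     loss = loss_func(embeddings, labels)       # scalar tensor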
+
+
+ # ================== cross batch memory for multi-supcon ==================
+ class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords):
+     def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs):
+         super().__init__(loss=loss, **kwargs)
+         self.loss = loss
+         self.miner = miner
+         self.embedding_size = embedding_size
+         self.memory_size = memory_size
+         self.num_classes = loss.num_classes
+         self.reset_queue()
+         self.add_to_recordable_attributes(
+             list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False
+         )
+
+     @staticmethod
+     def supported_losses():
+         return [
+             "MultiSupConLoss"
+         ]
+
+     @classmethod
+     def check_loss_support(cls, loss_name):
+         if loss_name not in cls.supported_losses():
+             raise Exception(f"CrossBatchMemory not supported for {loss_name}")
+
+     def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None):
+         if indices_tuple is not None and enqueue_mask is not None:
+             raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
+         if enqueue_mask is not None:
+             assert len(enqueue_mask) == len(embeddings)
+         else:
+             assert len(embeddings) <= len(self.embedding_memory)
+         self.reset_stats()
+         device = embeddings.device
+         labels = c_f.to_device(labels, device=device)
+         self.embedding_memory = c_f.to_device(
+             self.embedding_memory, device=device, dtype=embeddings.dtype
+         )
+         self.label_memory = c_f.to_device(
+             self.label_memory, device=device, dtype=labels.dtype
+         )
+
+         if enqueue_mask is not None:
+             emb_for_queue = embeddings[enqueue_mask]
+             labels_for_queue = labels[enqueue_mask]
+             embeddings = embeddings[~enqueue_mask]
+             labels = labels[~enqueue_mask]
+             do_remove_self_comparisons = False
+         else:
+             emb_for_queue = embeddings
+             labels_for_queue = labels
+             do_remove_self_comparisons = True
+
+         queue_batch_size = len(emb_for_queue)
+         self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)
+
+         if not self.has_been_filled:
+             E_mem = self.embedding_memory[: self.queue_idx]
+             L_mem = self.label_memory[: self.queue_idx]
+         else:
+             E_mem = self.embedding_memory
+             L_mem = self.label_memory
+
+         indices_tuple = self.create_indices_tuple(
+             embeddings,
+             labels,
+             E_mem,
+             L_mem,
+             indices_tuple,
+             do_remove_self_comparisons,
+         )
+         loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
+         return loss
+
+     def add_to_memory(self, embeddings, labels, batch_size):
+         self.curr_batch_idx = (
+             torch.arange(
+                 self.queue_idx, self.queue_idx + batch_size, device=labels.device
+             )
+             % self.memory_size
+         )
+         self.embedding_memory[self.curr_batch_idx] = embeddings.detach()
+         self.label_memory[self.curr_batch_idx] = labels.detach()
+         prev_queue_idx = self.queue_idx
+         self.queue_idx = (self.queue_idx + batch_size) % self.memory_size
+         if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx):
+             self.has_been_filled = True
+
+     def create_indices_tuple(
+         self,
+         embeddings,
+         labels,
+         E_mem,
+         L_mem,
+         input_indices_tuple,
+         do_remove_self_comparisons,
+     ):
+         if self.miner:
+             indices_tuple = self.miner(embeddings, labels, E_mem, L_mem)
+         else:
+             indices_tuple = get_all_pairs_indices(labels, L_mem)
+
+         if do_remove_self_comparisons:
+             indices_tuple = remove_self_comparisons(
+                 indices_tuple, self.curr_batch_idx, self.memory_size
+             )
+
+         if input_indices_tuple is not None:
+             if len(input_indices_tuple) == 3 and len(indices_tuple) == 4:
+                 input_indices_tuple = convert_to_pairs(input_indices_tuple, labels)
+             elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3:
+                 input_indices_tuple = convert_to_triplets(
+                     input_indices_tuple, labels
+                 )
+             indices_tuple = c_f.concatenate_indices_tuples(
+                 indices_tuple, input_indices_tuple
+             )
+
+         return indices_tuple
+
+     def reset_queue(self):
+         self.register_buffer(
+             "embedding_memory", torch.zeros(self.memory_size, self.embedding_size)
+         )
+         self.register_buffer(
+             "label_memory", torch.zeros(self.memory_size, self.num_classes)
+         )
+         self.has_been_filled = False
+         self.queue_idx = 0
+
+     # =========================================================================
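# --- Editor's note: illustrative usage sketch, not part of this commit. ---
# The wrapper compares each batch against a queue of past embeddings, which
# act as ref_emb/ref_labels inside the loss; `encoder` is hypothetical.
#
#     loss_func = CrossBatchMemory4MultiLabel(
#         loss=MultiSupConLoss(num_classes=10),
#         embedding_size=128,
#         memory_size=1024,
#     )
#     for images, labels in loader:              # hypothetical training loop
#         loss = loss_func(encoder(images), labels)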
+
+ # compute jaccard similarity
+ def jaccard(labels, ref_labels=None):
+     if ref_labels is None:
+         ref_labels = labels
+
+     labels1 = labels.float()
+     labels2 = ref_labels.float()
+
+     # compute jaccard similarity
+     # jaccard = intersection / union = |A ∩ B| / (|A| + |B| - |A ∩ B|)
+     labels1_union = labels1.sum(-1)
+     labels2_union = labels2.sum(-1)
+     union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0)
+     intersection = torch.mm(labels1, labels2.T)
+     jaccard_matrix = intersection / (union - intersection)
+
+     # return the full pairwise jaccard similarity matrix
+     return jaccard_matrix
+
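# Editor's note, worked example: for labels [1, 1, 0] and [0, 1, 1],
# intersection = 1 and union = 2 + 2 - 1 = 3, so jaccard = 1/3 ≈ 0.33.
# With the default threshold of 0.3, such a pair counts as a positive.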
+ # ====== methods below are overridden for adaptability to multi-supcon ======
+
+ # use jaccard similarity to get matches
+ def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3):
+     if ref_labels is None:
+         ref_labels = labels
+     jaccard_matrix = jaccard(labels, ref_labels)
+     matches = torch.where(jaccard_matrix > threshold, 1, 0)
+     diffs = matches ^ 1
+     if ref_labels is labels:
+         matches.fill_diagonal_(0)
+     return matches, diffs, jaccard_matrix
+
+ def check_shapes_multilabels(embeddings, labels):
+     if labels is not None and embeddings.shape[0] != labels.shape[0]:
+         raise ValueError("Number of embeddings must equal number of labels")
+     if labels is not None and labels.ndim != 2:
+         raise ValueError("labels must be a 2D tensor of shape (batch_size, num_classes)")
+
+
+ def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
+     if ref_emb is None:
+         ref_emb, ref_labels = embeddings, labels
+     check_shapes_multilabels(ref_emb, ref_labels)
+     return ref_emb, ref_labels
+
+
+ def convert_to_pairs(indices_tuple, labels, ref_labels=None, threshold=0.3):
+     """
+     This returns anchor-positive and anchor-negative indices,
+     regardless of what the input indices_tuple is.
+     Args:
+         indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices
+             within a batch
+         labels: a tensor which has the label for each element in a batch
+     """
+     if indices_tuple is None:
+         return get_all_pairs_indices(labels, ref_labels, threshold=threshold)
+     elif len(indices_tuple) == 5:
+         return indices_tuple
+     else:
+         a, p, n, jaccard_mat = indices_tuple
+         return a, p, a, n, jaccard_mat
+
+
+ def get_all_pairs_indices(labels, ref_labels=None, threshold=0.3):
+     """
+     Given a tensor of multi-hot labels, this will return 5 tensors.
+     The first 2 tensors are the indices which form all positive pairs.
+     The second 2 tensors are the indices which form all negative pairs.
+     The 5th tensor is the pairwise jaccard similarity matrix.
+     """
+     matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, threshold=threshold)
+     a1_idx, p_idx = torch.where(matches)
+     a2_idx, n_idx = torch.where(diffs)
+     return a1_idx, p_idx, a2_idx, n_idx, multi_val
+
+
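# Editor's note, worked example (two samples sharing one of their two labels):
#
#     labels = torch.tensor([[1, 1, 0], [0, 1, 1]])   # jaccard = 1/3 > 0.3
#     a1, p, a2, n, jac = get_all_pairs_indices(labels)
#     # a1 = [0, 1], p = [1, 0]; a2 and n are empty; jac is the 2x2 matrix
#     # that _compute_loss later uses to weight each positive pair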
+ def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100):
+     """
+     This returns anchor-positive-negative triplets,
+     regardless of what the input indices_tuple is.
+     """
+     if indices_tuple is None:
+         if t_per_anchor == "all":
+             return get_all_triplets_indices(labels, ref_labels)
+         else:
+             return lmu.get_random_triplet_indices(
+                 labels, ref_labels, t_per_anchor=t_per_anchor
+             )
+     elif len(indices_tuple) == 3:
+         return indices_tuple
+     else:
+         a1, p, a2, n = indices_tuple
+         p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2)
+         return a1[p_idx], p[p_idx], n[n_idx]
+
+
+ def get_all_triplets_indices(labels, ref_labels=None):
+     matches, diffs, _ = get_matches_and_diffs(labels, ref_labels)  # jaccard matrix unused here
+     triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
+     return torch.where(triplets)
+
+
+ def remove_self_comparisons(
+     indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False
+ ):
+     # remove self-comparisons
+     assert len(indices_tuple) in [4, 5]
+     s, e = curr_batch_idx[0], curr_batch_idx[-1]
+     if len(indices_tuple) == 4:
+         a, p, n, jaccard_mat = indices_tuple
+         keep_mask = lmu.not_self_comparisons(
+             a, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+         )
+         a = a[keep_mask]
+         p = p[keep_mask]
+         n = n[keep_mask]
+         assert len(a) == len(p) == len(n)
+         return a, p, n, jaccard_mat
+     elif len(indices_tuple) == 5:
+         a1, p, a2, n, jaccard_mat = indices_tuple
+         keep_mask = lmu.not_self_comparisons(
+             a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+         )
+         a1 = a1[keep_mask]
+         p = p[keep_mask]
+         assert len(a1) == len(p)
+         assert len(a2) == len(n)
+         return a1, p, a2, n, jaccard_mat
+
+ # =========================================================================