
Commit 949a45f

Add multi-supcon
Add cross-batch memory for multi-supcon
Add test cases
1 parent 4c6f2ca commit 949a45f
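
For orientation, a minimal usage sketch of the two classes this commit adds, based on the signatures shown in the diff below. The batch size, embedding size, class count, and the way the random multi-hot labels are generated are illustrative placeholders, not part of the commit:

import torch
from pytorch_metric_learning.losses import MultiSupConLoss, CrossBatchMemory4MultiLabel

num_classes, embedding_size, batch_size = 10, 128, 32

# multi-label SupCon: labels are multi-hot vectors; pairs whose label Jaccard
# similarity exceeds `threshold` are treated as positives, and that similarity
# also weights each positive pair's contribution to the loss
loss_fn = MultiSupConLoss(num_classes=num_classes, temperature=0.1, threshold=0.3)

# optional cross-batch memory: keeps a queue of past embeddings and labels so
# positives and negatives can also come from earlier batches
xbm_loss_fn = CrossBatchMemory4MultiLabel(
    loss=loss_fn, embedding_size=embedding_size, memory_size=1024
)

embeddings = torch.randn(batch_size, embedding_size)        # model output for one batch
labels = (torch.rand(batch_size, num_classes) < 0.3).float()  # multi-hot labels
labels[torch.arange(batch_size), torch.randint(0, num_classes, (batch_size,))] = 1.0  # every sample gets at least one class

loss = xbm_loss_fn(embeddings, labels)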


5 files changed: +413 -274 lines changed


src/pytorch_metric_learning/losses/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -35,5 +35,4 @@
 from .triplet_margin_loss import TripletMarginLoss
 from .tuplet_margin_loss import TupletMarginLoss
 from .vicreg_loss import VICRegLoss
-from .multilabel_supcon_loss import MultiSupConLoss
-from .xbm_multilabel import CrossBatchMemory4MultiLabel
+from .multilabel_supcon_loss import MultiSupConLoss, CrossBatchMemory4MultiLabel

src/pytorch_metric_learning/losses/multilabel_supcon_loss.py

Lines changed: 283 additions & 10 deletions
@@ -3,20 +3,22 @@
 from ..distances import CosineSimilarity
 from ..reducers import AvgNonZeroReducer
 from ..utils import common_functions as c_f
-from ..utils import multilabel_loss_and_miner_utils as mlmu
 from ..utils import loss_and_miner_utils as lmu
+from ..utils.module_with_records import ModuleWithRecords
 from .generic_pair_loss import GenericPairLoss
-
+from .base_loss_wrapper import BaseLossWrapper
 
 # adapted from https://github.com/HobbitLong/SupContrast
+# modified for multi-supcon
 class MultiSupConLoss(GenericPairLoss):
-    def __init__(self, num_classes, temperature=0.1, **kwargs):
+    def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
         super().__init__(mat_based_loss=True, **kwargs)
         self.temperature = temperature
         self.add_to_recordable_attributes(list_of_names=["temperature"], is_stat=False)
         self.num_classes = num_classes
+        self.threshold = threshold
 
-    def _compute_loss(self, mat, pos_mask, neg_mask):
+    def _compute_loss(self, mat, pos_mask, neg_mask, multi_val):
         if pos_mask.bool().any() and neg_mask.bool().any():
             # if dealing with actual distances, use negative distances
             if not self.distance.is_inverted:
@@ -29,7 +31,7 @@ def _compute_loss(self, mat, pos_mask, neg_mask):
                 mat, keep_mask=(pos_mask + neg_mask).bool(), add_one=False, dim=1
             )
             log_prob = mat - denominator
-            mean_log_prob_pos = (pos_mask * log_prob).sum(dim=1) / (
+            mean_log_prob_pos = (multi_val * log_prob * pos_mask).sum(dim=1) / (
                 pos_mask.sum(dim=1) + c_f.small_val(mat.dtype)
             )
 
@@ -48,16 +50,22 @@ def get_default_reducer(self):
     def get_default_distance(self):
         return CosineSimilarity()
 
+    # ==== class methods below are overridden for adaptability to multi-supcon ====
+
     def mat_based_loss(self, mat, indices_tuple):
-        a1, p, a2, n = indices_tuple
+        a1, p, a2, n, jaccard_mat = indices_tuple
         pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
         pos_mask[a1, p] = 1
         neg_mask[a2, n] = 1
-        return self._compute_loss(mat, pos_mask, neg_mask)
+        return self._compute_loss(mat, pos_mask, neg_mask, jaccard_mat)
 
     def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
         c_f.labels_or_indices_tuple_required(labels, indices_tuple)
-        indices_tuple = mlmu.convert_to_pairs(indices_tuple, labels, self.num_classes, ref_labels, device=embeddings.device)
+        indices_tuple = convert_to_pairs(
+            indices_tuple,
+            labels,
+            ref_labels,
+            threshold=self.threshold)
         if all(len(x) <= 1 for x in indices_tuple):
             return self.zero_losses()
         mat = self.distance(embeddings, ref_emb)
@@ -76,11 +84,276 @@ def forward(
         Returns: the loss
         """
         self.reset_stats()
-        mlmu.check_shapes_multilabels(embeddings, labels)
-        ref_emb, ref_labels = mlmu.set_ref_emb(embeddings, labels, ref_emb, ref_labels)
+        check_shapes_multilabels(embeddings, labels)
+        ref_emb, ref_labels = set_ref_emb(embeddings, labels, ref_emb, ref_labels)
         loss_dict = self.compute_loss(
             embeddings, labels, indices_tuple, ref_emb, ref_labels
         )
         self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
         return self.reducer(loss_dict, embeddings, labels)
 
+# =========================================================================
+
+
+# ================== cross batch memory for multi-supcon ==================
+class CrossBatchMemory4MultiLabel(BaseLossWrapper, ModuleWithRecords):
+    def __init__(self, loss, embedding_size, memory_size=1024, miner=None, **kwargs):
+        super().__init__(loss=loss, **kwargs)
+        self.loss = loss
+        self.miner = miner
+        self.embedding_size = embedding_size
+        self.memory_size = memory_size
+        self.num_classes = loss.num_classes
+        self.reset_queue()
+        self.add_to_recordable_attributes(
+            list_of_names=["embedding_size", "memory_size", "queue_idx"], is_stat=False
+        )
+
+    @staticmethod
+    def supported_losses():
+        return [
+            "MultiSupConLoss"
+        ]
+
+    @classmethod
+    def check_loss_support(cls, loss_name):
+        if loss_name not in cls.supported_losses():
+            raise Exception(f"CrossBatchMemory not supported for {loss_name}")
+
+    def forward(self, embeddings, labels, indices_tuple=None, enqueue_mask=None):
+        if indices_tuple is not None and enqueue_mask is not None:
+            raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
+        if enqueue_mask is not None:
+            assert len(enqueue_mask) == len(embeddings)
+        else:
+            assert len(embeddings) <= len(self.embedding_memory)
+        self.reset_stats()
+        device = embeddings.device
+        labels = c_f.to_device(labels, device=device)
+        self.embedding_memory = c_f.to_device(
+            self.embedding_memory, device=device, dtype=embeddings.dtype
+        )
+        self.label_memory = c_f.to_device(
+            self.label_memory, device=device, dtype=labels.dtype
+        )
+
+        if enqueue_mask is not None:
+            emb_for_queue = embeddings[enqueue_mask]
+            labels_for_queue = labels[enqueue_mask]
+            embeddings = embeddings[~enqueue_mask]
+            labels = labels[~enqueue_mask]
+            do_remove_self_comparisons = False
+        else:
+            emb_for_queue = embeddings
+            labels_for_queue = labels
+            do_remove_self_comparisons = True
+
+        queue_batch_size = len(emb_for_queue)
+        self.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)
+
+        if not self.has_been_filled:
+            E_mem = self.embedding_memory[: self.queue_idx]
+            L_mem = self.label_memory[: self.queue_idx]
+        else:
+            E_mem = self.embedding_memory
+            L_mem = self.label_memory
+
+        indices_tuple = self.create_indices_tuple(
+            embeddings,
+            labels,
+            E_mem,
+            L_mem,
+            indices_tuple,
+            do_remove_self_comparisons,
+        )
+        loss = self.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
+        return loss
+
+    def add_to_memory(self, embeddings, labels, batch_size):
+        self.curr_batch_idx = (
+            torch.arange(
+                self.queue_idx, self.queue_idx + batch_size, device=labels.device
+            )
+            % self.memory_size
+        )
+        self.embedding_memory[self.curr_batch_idx] = embeddings.detach()
+        self.label_memory[self.curr_batch_idx] = labels.detach()
+        prev_queue_idx = self.queue_idx
+        self.queue_idx = (self.queue_idx + batch_size) % self.memory_size
+        if (not self.has_been_filled) and (self.queue_idx <= prev_queue_idx):
+            self.has_been_filled = True
+
+    def create_indices_tuple(
+        self,
+        embeddings,
+        labels,
+        E_mem,
+        L_mem,
+        input_indices_tuple,
+        do_remove_self_comparisons,
+    ):
+        if self.miner:
+            indices_tuple = self.miner(embeddings, labels, E_mem, L_mem)
+        else:
+            indices_tuple = get_all_pairs_indices(labels, L_mem)
+
+        if do_remove_self_comparisons:
+            indices_tuple = remove_self_comparisons(
+                indices_tuple, self.curr_batch_idx, self.memory_size
+            )
+
+        if input_indices_tuple is not None:
+            if len(input_indices_tuple) == 3 and len(indices_tuple) == 4:
+                input_indices_tuple = convert_to_pairs(input_indices_tuple, labels)
+            elif len(input_indices_tuple) == 4 and len(indices_tuple) == 3:
+                input_indices_tuple = convert_to_triplets(
+                    input_indices_tuple, labels
+                )
+            indices_tuple = c_f.concatenate_indices_tuples(
+                indices_tuple, input_indices_tuple
+            )
+
+        return indices_tuple
+
+    def reset_queue(self):
+        self.register_buffer(
+            "embedding_memory", torch.zeros(self.memory_size, self.embedding_size)
+        )
+        self.register_buffer(
+            "label_memory", torch.zeros(self.memory_size, self.num_classes)
+        )
+        self.has_been_filled = False
+        self.queue_idx = 0
+
+# =========================================================================
+
+# compute jaccard similarity
+def jaccard(labels, ref_labels=None):
+    if ref_labels is None:
+        ref_labels = labels
+
+    labels1 = labels.float()
+    labels2 = ref_labels.float()
+
+    # compute jaccard similarity
+    # jaccard = intersection / union
+    labels1_union = labels1.sum(-1)
+    labels2_union = labels2.sum(-1)
+    union = labels1_union.unsqueeze(1) + labels2_union.unsqueeze(0)
+    intersection = torch.mm(labels1, labels2.T)
+    jaccard_matrix = intersection / (union - intersection)
+
+    # return the full jaccard similarity matrix
+    return jaccard_matrix
+
+# ====== methods below are overridden for adaptability to multi-supcon ======
+
+# use jaccard similarity to get matches
+def get_matches_and_diffs(labels, ref_labels=None, threshold=0.3):
+    if ref_labels is None:
+        ref_labels = labels
+    jaccard_matrix = jaccard(labels, ref_labels)
+    matches = torch.where(jaccard_matrix > threshold, 1, 0)
+    diffs = matches ^ 1
+    if ref_labels is labels:
+        matches.fill_diagonal_(0)
+    return matches, diffs, jaccard_matrix
+
+def check_shapes_multilabels(embeddings, labels):
+    if labels is not None and embeddings.shape[0] != labels.shape[0]:
+        raise ValueError("Number of embeddings must equal number of labels")
+    if labels is not None and labels.ndim != 2:
+        raise ValueError("labels must be a 2D tensor of shape (batch_size, num_classes)")
+
+
+def set_ref_emb(embeddings, labels, ref_emb, ref_labels):
+    if ref_emb is None:
+        ref_emb, ref_labels = embeddings, labels
+    check_shapes_multilabels(ref_emb, ref_labels)
+    return ref_emb, ref_labels
+
+
+def convert_to_pairs(indices_tuple, labels, ref_labels=None, threshold=0.3):
+    """
+    This returns anchor-positive and anchor-negative indices,
+    regardless of what the input indices_tuple is
+    Args:
+        indices_tuple: tuple of tensors. Each tensor is 1d and specifies indices
+            within a batch
+        labels: a tensor which has the label for each element in a batch
+    """
+    if indices_tuple is None:
+        return get_all_pairs_indices(labels, ref_labels, threshold=threshold)
+    elif len(indices_tuple) == 5:
+        return indices_tuple
+    else:
+        a, p, n, jaccard_mat = indices_tuple
+        return a, p, a, n, jaccard_mat
+
+
+def get_all_pairs_indices(labels, ref_labels=None, threshold=0.3):
+    """
+    Given a tensor of labels, this will return 5 tensors.
+    The first 2 tensors are the indices which form all positive pairs
+    The second 2 tensors are the indices which form all negative pairs
+    The last tensor is the jaccard similarity matrix
+    """
+    matches, diffs, multi_val = get_matches_and_diffs(labels, ref_labels, threshold=threshold)
+    a1_idx, p_idx = torch.where(matches)
+    a2_idx, n_idx = torch.where(diffs)
+    return a1_idx, p_idx, a2_idx, n_idx, multi_val
+
+
+def convert_to_triplets(indices_tuple, labels, ref_labels=None, t_per_anchor=100):
+    """
+    This returns anchor-positive-negative triplets
+    regardless of what the input indices_tuple is
+    """
+    if indices_tuple is None:
+        if t_per_anchor == "all":
+            return get_all_triplets_indices(labels, ref_labels)
+        else:
+            return lmu.get_random_triplet_indices(
+                labels, ref_labels, t_per_anchor=t_per_anchor
+            )
+    elif len(indices_tuple) == 3:
+        return indices_tuple
+    else:
+        a1, p, a2, n = indices_tuple
+        p_idx, n_idx = torch.where(a1.unsqueeze(1) == a2)
+        return a1[p_idx], p[p_idx], n[n_idx]
+
+
+def get_all_triplets_indices(labels, ref_labels=None):
+    matches, diffs = get_matches_and_diffs(labels, ref_labels)
+    triplets = matches.unsqueeze(2) * diffs.unsqueeze(1)
+    return torch.where(triplets)
+
+
+def remove_self_comparisons(
+    indices_tuple, curr_batch_idx, ref_size, ref_is_subset=False
+):
+    # remove self-comparisons
+    assert len(indices_tuple) in [4, 5]
+    s, e = curr_batch_idx[0], curr_batch_idx[-1]
+    if len(indices_tuple) == 4:
+        a, p, n, jaccard_mat = indices_tuple
+        keep_mask = lmu.not_self_comparisons(
+            a, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+        )
+        a = a[keep_mask]
+        p = p[keep_mask]
+        n = n[keep_mask]
+        assert len(a) == len(p) == len(n)
+        return a, p, n, jaccard_mat
+    elif len(indices_tuple) == 5:
+        a1, p, a2, n, jaccard_mat = indices_tuple
+        keep_mask = lmu.not_self_comparisons(
+            a1, p, s, e, curr_batch_idx, ref_size, ref_is_subset
+        )
+        a1 = a1[keep_mask]
+        p = p[keep_mask]
+        assert len(a1) == len(p)
+        assert len(a2) == len(n)
+        return a1, p, a2, n, jaccard_mat
+
+# =========================================================================
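
For reference, a small worked example (not part of the commit) of the Jaccard-based matching that the jaccard and get_matches_and_diffs helpers above implement. The label vectors and the 4-class setup are illustrative only:

import torch

# sample 0 has classes {0, 1}; sample 1 has classes {1, 2}
labels = torch.tensor([[1., 1., 0., 0.],
                       [0., 1., 1., 0.]])

intersection = torch.mm(labels, labels.T)                          # [[2, 1], [1, 2]]
union = labels.sum(-1).unsqueeze(1) + labels.sum(-1).unsqueeze(0)  # [[4, 4], [4, 4]]
jaccard_matrix = intersection / (union - intersection)             # [[1.0, 0.33], [0.33, 1.0]]

# with the default threshold of 0.3, the cross pair (Jaccard = 1/3) counts as a
# positive, and its Jaccard value is what later weights that pair's log-probability
# term in MultiSupConLoss._compute_loss; get_matches_and_diffs additionally zeroes
# the diagonal so a sample is never its own positive
matches = torch.where(jaccard_matrix > 0.3, 1, 0)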
