|
11 | 11 | # adapted from https://github.com/HobbitLong/SupContrast
|
12 | 12 | # modified for multi-supcon
|
13 | 13 | class MultiSupConLoss(GenericPairLoss):
|
| 14 | + """ |
| 15 | + Args: |
| 16 | + num_classes: number of classes |
| 17 | + temperature: temperature for scaling the similarity matrix |
| 18 | + threshold: threshold for jaccard similarity |
| 19 | + |
| 20 | + Inputs: |
| 21 | + embeddings: tensor of size (batch_size, embedding_size) |
| 22 | + labels: tensor of size (batch_size, num_classes) |
| 23 | + each row is a binary vector of size num_classes that only has 1s for the positive |
| 24 | + labels, and 0s for the negative labels |
| 25 | + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) |
| 26 | + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) |
| 27 | + Can also be left as None |
| 28 | + ref_emb: tensor of size (batch_size, embedding_size) |
| 29 | + """ |
14 | 30 | def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
|
15 | 31 | super().__init__(mat_based_loss=True, **kwargs)
|
16 | 32 | self.temperature = temperature
|
@@ -77,10 +93,13 @@ def forward(
|
77 | 93 | """
|
78 | 94 | Args:
|
79 | 95 | embeddings: tensor of size (batch_size, embedding_size)
|
80 |
| - labels: tensor of size (batch_size) |
81 |
| - indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives) |
82 |
| - or size 4 for pairs (anchor1, postives, anchor2, negatives) |
| 96 | + labels: tensor of size (batch_size, num_classes) |
| 97 | + each row is a binary vector of size num_classes that only has 1s for the positive |
| 98 | + labels, and 0s for the negative labels |
| 99 | + indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix) |
| 100 | + or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix) |
83 | 101 | Can also be left as None
|
| 102 | + ref_emb: tensor of size (batch_size, embedding_size) |
84 | 103 | Returns: the loss
|
85 | 104 | """
|
86 | 105 | self.reset_stats()
|
|
0 commit comments