Skip to content

Commit 855baa6

Browse files
committed
Add docstring
1 parent 949a45f commit 855baa6

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

src/pytorch_metric_learning/losses/multilabel_supcon_loss.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111
# adapted from https://github.com/HobbitLong/SupContrast
1212
# modified for multi-supcon
1313
class MultiSupConLoss(GenericPairLoss):
14+
"""
15+
Args:
16+
num_classes: number of classes
17+
temperature: temperature for scaling the similarity matrix
18+
threshold: threshold for jaccard similarity
19+
20+
Inputs:
21+
embeddings: tensor of size (batch_size, embedding_size)
22+
labels: tensor of size (batch_size, num_classes)
23+
each row is a binary vector of size num_classes that only has 1s for the positive
24+
labels, and 0s for the negative labels
25+
indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix)
26+
or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix)
27+
Can also be left as None
28+
ref_emb: tensor of size (batch_size, embedding_size)
29+
"""
1430
def __init__(self, num_classes, temperature=0.1, threshold=0.3, **kwargs):
1531
super().__init__(mat_based_loss=True, **kwargs)
1632
self.temperature = temperature
@@ -77,10 +93,13 @@ def forward(
7793
"""
7894
Args:
7995
embeddings: tensor of size (batch_size, embedding_size)
80-
labels: tensor of size (batch_size)
81-
indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
82-
or size 4 for pairs (anchor1, postives, anchor2, negatives)
96+
labels: tensor of size (batch_size, num_classes)
97+
each row is a binary vector of size num_classes that only has 1s for the positive
98+
labels, and 0s for the negative labels
99+
indices_tuple: tuple of size 4 for triplets (anchors, positives, negatives, jaccard_matrix)
100+
or size 5 for pairs (anchor1, postives, anchor2, negativesm, jaccard_matrix)
83101
Can also be left as None
102+
ref_emb: tensor of size (batch_size, embedding_size)
84103
Returns: the loss
85104
"""
86105
self.reset_stats()

0 commit comments

Comments
 (0)