diff --git a/CONTENTS.md b/CONTENTS.md
index 24a48b2f..3822239c 100644
--- a/CONTENTS.md
+++ b/CONTENTS.md
@@ -46,6 +46,7 @@
| [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf)
| [**RankedListLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#rankedlistloss) | [Ranked List Loss for Deep Metric Learning](https://arxiv.org/abs/1903.03238)
| [**SignalToNoiseRatioContrastiveLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#signaltonoiseratiocontrastiveloss) | [Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
+| [**SmoothAPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#smoothaploss) | [Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163)
| [**SoftTripleLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#softtripleloss) | [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf)
| [**SphereFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#spherefaceloss) | [SphereFace: Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/pdf/1704.08063.pdf)
| [**SubCenterArcFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#subcenterarcfaceloss) | [Sub-center ArcFace: Boosting Face Recognition by Large-scale Noisy Web Faces](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf)
diff --git a/README.md b/README.md
index a704340f..5c1f1c42 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,11 @@
## News
+**August 17**: v2.9.0
+- Added [SmoothAPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#smoothaploss).
+- Improved SubCenterArcFaceLoss and GenericPairLoss.
+- Thank you [ir2718](https://github.com/ir2718), [lucamarini22](https://github.com/lucamarini22), and [marcpaga](https://github.com/marcpaga).
+
**December 11**: v2.8.0
- Added the [Datasets](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets) module for easy downloading of common datasets:
- [CUB200](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets/#cub-200-2011)
@@ -26,10 +31,6 @@
- [Stanford Online Products](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets/#stanfordonlineproducts)
- Thank you [ir2718](https://github.com/ir2718).
-**November 2**: v2.7.0
-- Added [ThresholdConsistentMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#thresholdconsistentmarginloss).
-- Thank you [ir2718](https://github.com/ir2718).
-
## Documentation
- [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/)
- [**View the installation instructions here**](https://github.com/KevinMusgrave/pytorch-metric-learning#installation)
@@ -231,7 +232,7 @@ Thanks to the contributors who made pull requests!
|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
- [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
- [DynamicSoftMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#dynamicsoftmarginloss)
- [RankedListLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#rankedlistloss) |
|[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
|[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper|
-| [ir2718](https://github.com/ir2718) | - [ThresholdConsistentMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#thresholdconsistentmarginloss)
- The [Datasets](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets) module |
+| [ir2718](https://github.com/ir2718) | - [ThresholdConsistentMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#thresholdconsistentmarginloss)
- [SmoothAPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#smoothaploss)
- The [Datasets](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets) module |
|[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
| [chingisooinar](https://github.com/chingisooinar) | [SubCenterArcFaceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#subcenterarcfaceloss) |
| [elias-ramzi](https://github.com/elias-ramzi) | [HierarchicalSampler](https://kevinmusgrave.github.io/pytorch-metric-learning/samplers/#hierarchicalsampler) |
@@ -252,6 +253,8 @@ Thanks to the contributors who made pull requests!
| [stompsjo](https://github.com/stompsjo) | Improved documentation for NTXentLoss. |
| [Puzer](https://github.com/Puzer) | Bug fix for PNPLoss. |
| [elisim](https://github.com/elisim) | Developer improvements to DistributedLossWrapper. |
+| [lucamarini22](https://github.com/lucamarini22) | |
+| [marcpaga](https://github.com/marcpaga) | |
| [GaetanLepage](https://github.com/GaetanLepage) | |
| [z1w](https://github.com/z1w) | |
| [thinline72](https://github.com/thinline72) | |
diff --git a/docs/imgs/smooth_ap_approx_equation.png b/docs/imgs/smooth_ap_approx_equation.png
new file mode 100644
index 00000000..0b37c27a
Binary files /dev/null and b/docs/imgs/smooth_ap_approx_equation.png differ
diff --git a/docs/imgs/smooth_ap_loss_equation.png b/docs/imgs/smooth_ap_loss_equation.png
new file mode 100644
index 00000000..25f061c8
Binary files /dev/null and b/docs/imgs/smooth_ap_loss_equation.png differ
diff --git a/docs/imgs/smooth_ap_sigmoid_equation.png b/docs/imgs/smooth_ap_sigmoid_equation.png
new file mode 100644
index 00000000..7153f2a3
Binary files /dev/null and b/docs/imgs/smooth_ap_sigmoid_equation.png differ
diff --git a/docs/losses.md b/docs/losses.md
index b8ced202..7ee48da4 100644
--- a/docs/losses.md
+++ b/docs/losses.md
@@ -1087,6 +1087,37 @@ losses.SignalToNoiseRatioContrastiveLoss(pos_margin=0, neg_margin=1, **kwargs):
* **pos_loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.
* **neg_loss**: The loss per negative pair in the batch. Reduction type is ```"neg_pair"```.
+## SmoothAPLoss
+[Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163){target=_blank}
+
+```python
+losses.SmoothAPLoss(
+ temperature=0.01,
+ **kwargs
+)
+```
+
+**Equations**:
+
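+![smooth_ap_sigmoid_equation](imgs/smooth_ap_sigmoid_equation.png){: style="height:100px"}
+![smooth_ap_approx_equation](imgs/smooth_ap_approx_equation.png){: style="height:100px"}
+![smooth_ap_loss_equation](imgs/smooth_ap_loss_equation.png){: style="height:100px"}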
+
+
+**Parameters**:
+
+* **temperature**: The desired temperature for scaling the sigmoid function. This is denoted by $\tau$ in the first and second equations.
+
+
+**Other info**:
+
+* The loss requires every class in the batch labels to have the same number of elements. An example of valid labels is `[1, 1, 2, 2, 3, 3]`. An example of invalid labels is `[1, 1, 1, 2, 2, 3, 3]`, because the value `1` appears `3` times while the other values appear twice. Equal class counts can be achieved by using [`samplers.MPerClassSampler`](samplers.md/#mperclasssampler) and setting the `batch_size` and `m` hyperparameters.
+
+**Default distance**:
+
+ - [```CosineSimilarity()```](distances.md#cosinesimilarity)
+ - This is the only compatible distance.
+
## SoftTripleLoss
[SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf){target=_blank}
```python
diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py
index b4066b65..43ce13db 100644
--- a/src/pytorch_metric_learning/__init__.py
+++ b/src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
-__version__ = "2.8.1"
+__version__ = "2.9.0"
diff --git a/src/pytorch_metric_learning/datasets/__init__.py b/src/pytorch_metric_learning/datasets/__init__.py
index 67050abe..ee5db708 100644
--- a/src/pytorch_metric_learning/datasets/__init__.py
+++ b/src/pytorch_metric_learning/datasets/__init__.py
@@ -2,4 +2,4 @@
from .cars196 import Cars196
from .cub import CUB
from .inaturalist2018 import INaturalist2018
-from .sop import StanfordOnlineProducts
\ No newline at end of file
+from .sop import StanfordOnlineProducts
diff --git a/src/pytorch_metric_learning/distances/dot_product_similarity.py b/src/pytorch_metric_learning/distances/dot_product_similarity.py
index 2e0b4b01..74be22f5 100644
--- a/src/pytorch_metric_learning/distances/dot_product_similarity.py
+++ b/src/pytorch_metric_learning/distances/dot_product_similarity.py
@@ -9,7 +9,7 @@ def __init__(self, **kwargs):
assert self.is_inverted
def compute_mat(self, query_emb, ref_emb):
- return torch.matmul(query_emb, ref_emb.t())
+ return torch.matmul(query_emb, ref_emb.transpose(-1, -2))
def pairwise_distance(self, query_emb, ref_emb):
return torch.sum(query_emb * ref_emb, dim=1)
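
Review note on the `transpose(-1, -2)` change above: `Tensor.t()` only accepts tensors with at most two dimensions, whereas `SmoothAPLoss` calls `compute_mat` with a 3-D tensor of per-class groups. The following is a minimal sketch in plain PyTorch of the difference; the shapes are arbitrary and chosen only for illustration.

```python
import torch

# Batched embeddings: 3 groups, 4 embeddings each, 8 dimensions (arbitrary shapes).
xs = torch.randn(3, 4, 8)

# Tensor.t() is limited to tensors with at most 2 dimensions.
try:
    xs.t()
    t_failed = False
except RuntimeError:
    t_failed = True

# transpose(-1, -2) swaps only the last two dimensions, so matmul
# produces one similarity matrix per group.
batched_sims = torch.matmul(xs, xs.transpose(-1, -2))

# For ordinary 2-D input the two forms give the same result.
x2d = torch.randn(4, 8)
same_2d = torch.equal(x2d.t(), x2d.transpose(-1, -2))
```

Because the two forms agree on 2-D input, existing callers of `DotProductSimilarity` should be unaffected by the change.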
diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py
index 6bd679a7..096d98a6 100644
--- a/src/pytorch_metric_learning/losses/__init__.py
+++ b/src/pytorch_metric_learning/losses/__init__.py
@@ -30,6 +30,7 @@
from .ranked_list_loss import RankedListLoss
from .self_supervised_loss import SelfSupervisedLoss
from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss
+from .smooth_ap import SmoothAPLoss
from .soft_triple_loss import SoftTripleLoss
from .sphereface_loss import SphereFaceLoss
from .subcenter_arcface_loss import SubCenterArcFaceLoss
diff --git a/src/pytorch_metric_learning/losses/generic_pair_loss.py b/src/pytorch_metric_learning/losses/generic_pair_loss.py
index 6996490f..eaac7507 100644
--- a/src/pytorch_metric_learning/losses/generic_pair_loss.py
+++ b/src/pytorch_metric_learning/losses/generic_pair_loss.py
@@ -28,6 +28,7 @@ def mat_based_loss(self, mat, indices_tuple):
pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat)
pos_mask[a1, p] = 1
neg_mask[a2, n] = 1
+ self._assert_either_pos_or_neg(pos_mask, neg_mask)
return self._compute_loss(mat, pos_mask, neg_mask)
def pair_based_loss(self, mat, indices_tuple):
@@ -38,3 +39,9 @@ def pair_based_loss(self, mat, indices_tuple):
if len(a2) > 0:
neg_pair = mat[a2, n]
return self._compute_loss(pos_pair, neg_pair, indices_tuple)
+
+ @staticmethod
+ def _assert_either_pos_or_neg(pos_mask, neg_mask):
+ assert not torch.any(
+ (pos_mask != 0) & (neg_mask != 0)
+ ), "Each pair should be either positive or negative"
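
Review note on the new assertion: it rejects index tuples in which the same (anchor, other) position is marked as both a positive and a negative pair. A small sketch of the check in isolation, assuming plain PyTorch; `masks_overlap` is a hypothetical helper name used only here, not part of the library.

```python
import torch


def masks_overlap(pos_mask, neg_mask):
    # Hypothetical helper: True if any position is non-zero in both masks.
    return bool(torch.any((pos_mask != 0) & (neg_mask != 0)))


pos_mask = torch.zeros(3, 3)
neg_mask = torch.zeros(3, 3)

# Pair (0, 1) is positive and pair (0, 2) is negative: no conflict.
pos_mask[0, 1] = 1
neg_mask[0, 2] = 1
disjoint_case = masks_overlap(pos_mask, neg_mask)

# Marking (0, 1) as negative too creates the conflict the assertion guards against.
neg_mask[0, 1] = 1
conflicting_case = masks_overlap(pos_mask, neg_mask)
```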
diff --git a/src/pytorch_metric_learning/losses/smooth_ap.py b/src/pytorch_metric_learning/losses/smooth_ap.py
new file mode 100644
index 00000000..b0e441f3
--- /dev/null
+++ b/src/pytorch_metric_learning/losses/smooth_ap.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn.functional as F
+
+from ..distances import CosineSimilarity
+from ..utils import common_functions as c_f
+from ..utils import loss_and_miner_utils as lmu
+from .base_metric_loss_function import BaseMetricLossFunction
+
+
+class SmoothAPLoss(BaseMetricLossFunction):
+ """
+ Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163
+ """
+
+ def __init__(self, temperature=0.01, **kwargs):
+ super().__init__(**kwargs)
+ c_f.assert_distance_type(self, CosineSimilarity)
+ self.temperature = temperature
+
+ def get_default_distance(self):
+ return CosineSimilarity()
+
+ # Implementation is based on the original repository:
+ # https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87
+ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
+ # The loss expects labels such that there is the same number of elements for each class
+ # The number of classes is not important, nor their order, but the number of elements must be the same, e.g.
+ #
+ # The following label is valid:
+ # [ A,A,A, B,B,B, C,C,C ]
+ # The following label is NOT valid:
+ # [ B,B,B, A,A,A,A, C,C,C ]
+ #
+ c_f.labels_required(labels)
+ c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels)
+
+ counts = torch.bincount(labels)
+ nonzero_indices = torch.nonzero(counts, as_tuple=True)[0]
+ nonzero_counts = counts[nonzero_indices]
+ if nonzero_counts.unique().size(0) != 1:
+ raise ValueError(
+ "All classes must have the same number of elements in the labels.\n"
+ "The given labels have the following number of elements: {}.\n"
+ "You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format(
+ nonzero_counts.cpu().tolist()
+ )
+ )
+
+ batch_size = embeddings.size(0)
+ num_classes_batch = batch_size // torch.unique(labels).size(0)
+
+ mask = 1.0 - torch.eye(batch_size)
+ mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1)
+
+ sims = self.distance(embeddings)
+
+ sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1)
+ sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1)
+ sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device)
+ sims_ranks = torch.sum(sims_sigm, dim=-1) + 1
+
+ xs = embeddings.view(
+ num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1)
+ )
+ pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch)
+ pos_mask = (
+ pos_mask.unsqueeze(dim=0)
+ .unsqueeze(dim=0)
+ .repeat(num_classes_batch, batch_size // num_classes_batch, 1, 1)
+ )
+
+ # Call maybe_normalize and compute_mat directly to bypass the shape check in the distance's forward method
+ xs_norm = self.distance.maybe_normalize(xs, dim=-1)
+ sims_pos = self.distance.compute_mat(xs_norm, xs_norm)
+
+ sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat(
+ 1, 1, batch_size // num_classes_batch, 1
+ )
+ sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2)
+
+ sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to(
+ sims_diff.device
+ )
+ sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1
+
+ g = batch_size // num_classes_batch
+ ap = torch.zeros(batch_size).to(embeddings.device)
+ for i in range(num_classes_batch):
+ for j in range(g):
+ pos_rank = sims_pos_ranks[i, j]
+ all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g]
+ ap[i * g + j] = torch.sum(pos_rank / all_rank) / g
+
+ miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype)
+ loss = (1 - ap) * miner_weights
+
+ return {
+ "ap_loss": {
+ "losses": loss,
+ "indices": c_f.torch_arange_from_size(loss),
+ "reduction_type": "element",
+ }
+ }
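
Review note on the ranking idea behind `sims_ranks` and `sims_pos_ranks`: each rank is one plus a sum of temperature-scaled sigmoids of similarity differences, and the smoothed AP averages the ratio of "rank among positives" to "rank among all items". The sketch below shows this for a single query in pure Python. It is an illustration under stated assumptions, not a re-implementation of the batched code in this PR (which also keeps the query's self-similarity inside its class group); the function names are hypothetical.

```python
import math


def sigmoid(x):
    # Numerically safe logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def smooth_rank(scores, i, temperature):
    # 1 + soft count of the items whose score exceeds scores[i].
    return 1.0 + sum(
        sigmoid((s - scores[i]) / temperature)
        for j, s in enumerate(scores)
        if j != i
    )


def smooth_ap(scores, positives, temperature=0.01):
    # scores: similarity of each retrieval item to the query.
    # positives: indices of the items that share the query's class.
    pos_scores = [scores[i] for i in positives]
    total = 0.0
    for k, i in enumerate(positives):
        rank_among_positives = smooth_rank(pos_scores, k, temperature)
        rank_among_all = smooth_rank(scores, i, temperature)
        total += rank_among_positives / rank_among_all
    return total / len(positives)


# Positives ranked first: the smoothed AP is close to 1.
ap_good = smooth_ap([0.9, 0.8, 0.3, 0.1], positives=[0, 1])
# Positives ranked last (ranks 3 and 4): close to (1/3 + 2/4) / 2 = 5/12.
ap_bad = smooth_ap([0.1, 0.3, 0.8, 0.9], positives=[0, 1])
```

With a small temperature the sigmoid approaches a step function, so the smoothed value tracks the exact AP; a larger temperature gives smoother gradients at the cost of a looser approximation.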
diff --git a/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py b/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py
index fc1f2210..56af635c 100644
--- a/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py
+++ b/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py
@@ -1,4 +1,5 @@
import math
+from copy import deepcopy
import numpy as np
import torch
@@ -13,9 +14,16 @@ class SubCenterArcFaceLoss(ArcFaceLoss):
"""
def __init__(self, *args, margin=28.6, scale=64, sub_centers=3, **kwargs):
- num_classes, embedding_size = kwargs["num_classes"], kwargs["embedding_size"]
+ num_classes = deepcopy(kwargs["num_classes"])
+ embedding_size = deepcopy(kwargs["embedding_size"])
+ del kwargs["num_classes"]
+ del kwargs["embedding_size"]
super().__init__(
- num_classes * sub_centers, embedding_size, margin=margin, scale=scale
+ num_classes=num_classes * sub_centers,
+ embedding_size=embedding_size,
+ margin=margin,
+ scale=scale,
+ **kwargs
)
self.sub_centers = sub_centers
self.num_classes = num_classes
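
Review note on the `SubCenterArcFaceLoss` fix: the old constructor dropped `**kwargs`, so arguments such as `reducer` never reached the parent class. The sketch below reproduces the forwarding pattern with toy classes in pure Python. `Base`, `SubCentered`, and `weight_rows` are hypothetical names, and `kwargs.pop` is shown as one possible alternative to the `deepcopy` plus `del` used in the diff, not as the PR's code.

```python
class Base:
    # Toy stand-in for a parent loss that accepts extra keyword arguments.
    def __init__(self, num_classes, embedding_size, reducer=None):
        self.weight_rows = num_classes
        self.embedding_size = embedding_size
        self.reducer = reducer


class SubCentered(Base):
    def __init__(self, *args, sub_centers=3, **kwargs):
        # pop() reads the value and removes the key in one step, so the
        # remaining kwargs can be forwarded without duplicate arguments.
        num_classes = kwargs.pop("num_classes")
        embedding_size = kwargs.pop("embedding_size")
        super().__init__(
            num_classes=num_classes * sub_centers,
            embedding_size=embedding_size,
            **kwargs,
        )
        self.sub_centers = sub_centers
        self.num_classes = num_classes


loss = SubCentered(num_classes=10, embedding_size=64, reducer="custom")
```

Here the parent sees 30 rows (10 classes times 3 sub-centers), the subclass keeps the original class count, and the `reducer` argument is no longer silently discarded.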
diff --git a/tests/datasets/test_cars196.py b/tests/datasets/test_cars196.py
index bd5629f8..96829fe4 100644
--- a/tests/datasets/test_cars196.py
+++ b/tests/datasets/test_cars196.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
from pytorch_metric_learning.datasets import Cars196
+
from .. import TEST_DATASETS
diff --git a/tests/datasets/test_cub.py b/tests/datasets/test_cub.py
index 60261965..763a4d12 100644
--- a/tests/datasets/test_cub.py
+++ b/tests/datasets/test_cub.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
from pytorch_metric_learning.datasets import CUB
+
from .. import TEST_DATASETS
diff --git a/tests/datasets/test_inaturalist2018.py b/tests/datasets/test_inaturalist2018.py
index ad26cb7a..6ceb0d23 100644
--- a/tests/datasets/test_inaturalist2018.py
+++ b/tests/datasets/test_inaturalist2018.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
from pytorch_metric_learning.datasets import INaturalist2018
+
from .. import TEST_DATASETS
diff --git a/tests/datasets/test_sop.py b/tests/datasets/test_sop.py
index 55398716..aa5cd2ee 100644
--- a/tests/datasets/test_sop.py
+++ b/tests/datasets/test_sop.py
@@ -6,6 +6,7 @@
from torch.utils.data import DataLoader
from pytorch_metric_learning.datasets import StanfordOnlineProducts
+
from .. import TEST_DATASETS
diff --git a/tests/losses/test_cross_batch_memory.py b/tests/losses/test_cross_batch_memory.py
index 5c3d47af..f1c2d8c3 100644
--- a/tests/losses/test_cross_batch_memory.py
+++ b/tests/losses/test_cross_batch_memory.py
@@ -238,7 +238,6 @@ def test_loss(self):
batch_size = 32
for inner_loss in [ContrastiveLoss(), MultiSimilarityLoss()]:
inner_miner = MultiSimilarityMiner(0.3)
- outer_miner = MultiSimilarityMiner(0.2)
self.loss = CrossBatchMemory(
loss=inner_loss,
embedding_size=self.embedding_size,
@@ -267,10 +266,6 @@ def test_loss(self):
labels = torch.randint(0, num_labels, (batch_size,)).to(TEST_DEVICE)
loss = self.loss(embeddings, labels)
loss_with_miner = self.loss_with_miner(embeddings, labels)
- oa1, op, oa2, on = outer_miner(embeddings, labels)
- loss_with_miner_and_input_indices = self.loss_with_miner2(
- embeddings, labels, (oa1, op, oa2, on)
- )
all_embeddings = torch.cat([all_embeddings, embeddings])
all_labels = torch.cat([all_labels, labels])
@@ -308,33 +303,6 @@ def test_loss(self):
torch.isclose(loss_with_miner, correct_loss_with_miner)
)
- # loss with inner and outer miner
- indices_tuple = inner_miner(
- embeddings, labels, all_embeddings, all_labels
- )
- a1, p, a2, n = lmu.remove_self_comparisons(
- indices_tuple,
- self.loss_with_miner2.curr_batch_idx,
- self.loss_with_miner2.memory_size,
- )
- a1 = torch.cat([oa1, a1])
- p = torch.cat([op, p])
- a2 = torch.cat([oa2, a2])
- n = torch.cat([on, n])
- correct_loss_with_miner_and_input_indice = inner_loss(
- embeddings,
- labels,
- (a1, p, a2, n),
- all_embeddings,
- all_labels,
- )
- self.assertTrue(
- torch.isclose(
- loss_with_miner_and_input_indices,
- correct_loss_with_miner_and_input_indice,
- )
- )
-
def test_queue(self):
for test_enqueue_mask in [False, True]:
for dtype in TEST_DTYPES:
diff --git a/tests/losses/test_smooth_ap_loss.py b/tests/losses/test_smooth_ap_loss.py
new file mode 100644
index 00000000..422e9423
--- /dev/null
+++ b/tests/losses/test_smooth_ap_loss.py
@@ -0,0 +1,191 @@
+import unittest
+
+import torch
+import torch.nn.functional as F
+
+from pytorch_metric_learning.losses import SmoothAPLoss
+
+from .. import TEST_DEVICE, TEST_DTYPES
+
+HYPERPARAMETERS = {
+ "temp": 0.01,
+ "batch_size": 60,
+ "num_id": 6,
+ "feat_dims": 256,
+}
+TEST_SEEDS = [42, 1234, 5642, 9999, 3459]
+
+
+# Original implementation of the SmoothAP loss taken from:
+# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py
+def sigmoid(tensor, temp=1.0):
+ """temperature controlled sigmoid
+
+ takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
+ """
+ exponent = -tensor / temp
+ # clamp the input tensor for stability
+ exponent = torch.clamp(exponent, min=-50, max=50)
+ y = 1.0 / (1.0 + torch.exp(exponent))
+ return y
+
+
+def compute_aff(x):
+ """computes the affinity matrix between an input vector and itself"""
+ return torch.mm(x, x.t())
+
+
+class SmoothAP(torch.nn.Module):
+ """PyTorch implementation of the Smooth-AP loss.
+
+ implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
+ the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
+ have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
+
+ e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
+
+ labels = ( A, A, A, B, B, B, C, C, C)
+
+ (the order of the classes however does not matter)
+
+ For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
+ mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
+ same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
+
+ Args:
+ anneal : float
+ the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
+ results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
+ batch_size : int
+ the batch size being used during training.
+ num_id : int
+ the number of different classes that are represented in the batch.
+ feat_dims : int
+ the dimension of the input feature embeddings
+
+ Shape:
+ - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
+ - Output: scalar
+
+ Examples::
+
+ >>> loss = SmoothAP(0.01, 60, 6, 256)
+ >>> input = torch.randn(60, 256, requires_grad=True).to("cuda:0")
+ >>> output = loss(input)
+ >>> output.backward()
+ """
+
+ def __init__(self, anneal, batch_size, num_id, feat_dims):
+ """
+ Parameters
+ ----------
+ anneal : float
+ the temperature of the sigmoid that is used to smooth the ranking function
+ batch_size : int
+ the batch size being used
+ num_id : int
+ the number of different classes that are represented in the batch
+ feat_dims : int
+ the dimension of the input feature embeddings
+ """
+ super(SmoothAP, self).__init__()
+
+ assert batch_size % num_id == 0
+
+ self.anneal = anneal
+ self.batch_size = batch_size
+ self.num_id = num_id
+ self.feat_dims = feat_dims
+
+ def forward(self, preds):
+ """Forward pass for all input predictions: preds - (batch_size x feat_dims)"""
+
+ # ------ differentiable ranking of all retrieval set ------
+ # compute the mask which ignores the relevance score of the query to itself
+ mask = 1.0 - torch.eye(self.batch_size)
+ mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
+ # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
+ sim_all = compute_aff(preds)
+ sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
+ # compute the difference matrix
+ sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
+ # pass through the sigmoid
+ sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.to(TEST_DEVICE)
+ # compute the rankings
+ sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
+
+ # ------ differentiable ranking of only positive set in retrieval set ------
+ # compute the mask which only gives non-zero weights to the positive set
+ xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
+ pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
+ pos_mask = (
+ pos_mask.unsqueeze(dim=0)
+ .unsqueeze(dim=0)
+ .repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
+ )
+
+ # compute the relevance scores
+ sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
+ sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(
+ 1, 1, int(self.batch_size / self.num_id), 1
+ )
+ # compute the difference matrix
+ sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
+ # pass through the sigmoid
+ sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.to(TEST_DEVICE)
+ # compute the rankings of the positive set
+ sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1
+
+ # sum the values of the Smooth-AP for all instances in the mini-batch
+ ap = torch.zeros(1).to(TEST_DEVICE)
+ group = int(self.batch_size / self.num_id)
+ for ind in range(self.num_id):
+ pos_divide = torch.sum(
+ sim_pos_rk[ind]
+ / (
+ sim_all_rk[
+ (ind * group) : ((ind + 1) * group),
+ (ind * group) : ((ind + 1) * group),
+ ]
+ )
+ )
+ ap = ap + ((pos_divide / group) / self.batch_size)
+
+ return 1 - ap
+
+
+class TestSmoothAPLoss(unittest.TestCase):
+ def test_smooth_ap_loss(self):
+ for dtype in TEST_DTYPES:
+ for seed in TEST_SEEDS:
+ torch.manual_seed(seed)
+ loss = SmoothAP(
+ HYPERPARAMETERS["temp"],
+ HYPERPARAMETERS["batch_size"],
+ HYPERPARAMETERS["num_id"],
+ HYPERPARAMETERS["feat_dims"],
+ )
+ rand_tensor = (
+ torch.randn(
+ HYPERPARAMETERS["batch_size"],
+ HYPERPARAMETERS["feat_dims"],
+ requires_grad=True,
+ )
+ .to(TEST_DEVICE)
+ .to(dtype)
+ )
+ # The original code uses a model that normalizes the output vector
+ input_ = F.normalize(rand_tensor, p=2.0, dim=-1)
+ output = loss(input_)
+
+ loss2 = SmoothAPLoss(temperature=HYPERPARAMETERS["temp"])
+ # The original code assumes the label is in this format
+ labels = []
+ for i in range(
+ HYPERPARAMETERS["batch_size"] // HYPERPARAMETERS["num_id"]
+ ):
+ labels.extend([i for _ in range(HYPERPARAMETERS["num_id"])])
+
+ labels = torch.tensor(labels)
+ output2 = loss2.forward(rand_tensor, labels)
+ self.assertTrue(torch.isclose(output, output2))
diff --git a/tests/losses/test_subcenter_arcface_loss.py b/tests/losses/test_subcenter_arcface_loss.py
index a4d324f4..d9501155 100644
--- a/tests/losses/test_subcenter_arcface_loss.py
+++ b/tests/losses/test_subcenter_arcface_loss.py
@@ -6,6 +6,7 @@
import torch.nn.functional as F
from pytorch_metric_learning.losses import ArcFaceLoss, SubCenterArcFaceLoss
+from pytorch_metric_learning.reducers import DoNothingReducer
from .. import TEST_DEVICE, TEST_DTYPES
@@ -142,3 +143,18 @@ def test_inference_subcenter_arcface(self):
)
== 0
)
+
+ def test_reducer_subcenter_arcface(self):
+
+ arcfaceloss = SubCenterArcFaceLoss(
+ num_classes=10,
+ sub_centers=3,
+ embedding_size=64,
+ reducer=DoNothingReducer(),
+ )
+
+ emb = torch.randn(4, 64)
+ result = arcfaceloss(emb, torch.arange(4))
+
+ self.assertTrue(isinstance(result, dict))
+ self.assertTrue(result["loss"]["losses"].shape[0] == 4)