diff --git a/CONTENTS.md b/CONTENTS.md index 24a48b2f..3822239c 100644 --- a/CONTENTS.md +++ b/CONTENTS.md @@ -46,6 +46,7 @@ | [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf) | [**RankedListLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#rankedlistloss) | [Ranked List Loss for Deep Metric Learning](https://arxiv.org/abs/1903.03238) | [**SignalToNoiseRatioContrastiveLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#signaltonoiseratiocontrastiveloss) | [Signal-to-Noise Ratio: A Robust Distance Metric for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Yuan_Signal-To-Noise_Ratio_A_Robust_Distance_Metric_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) +| [**SmoothAPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#smoothaploss) | [Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163) | [**SoftTripleLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#softtripleloss) | [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf) | [**SphereFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#spherefaceloss) | [SphereFace: Deep Hypersphere Embedding for Face Recognition](https://arxiv.org/pdf/1704.08063.pdf) | [**SubCenterArcFaceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#subcenterarcfaceloss) | [Sub-center ArcFace: Boosting Face Recognition by Large-scale Noisy Web Faces](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123560715.pdf) diff --git a/README.md b/README.md index a704340f..5c1f1c42 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,11 @@ ## News +**August 17**: 
v2.9.0 +- Added [SmoothAPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#smoothaploss). +- Improved SubCenterArcFaceLoss and GenericPairLoss. +- Thank you [ir2718](https://github.com/ir2718), [lucamarini22](https://github.com/lucamarini22), and [marcpaga](https://github.com/marcpaga). + **December 11**: v2.8.0 - Added the [Datasets](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets) module for easy downloading of common datasets: - [CUB200](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets/#cub-200-2011) @@ -26,10 +31,6 @@ - [Stanford Online Products](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets/#stanfordonlineproducts) - Thank you [ir2718](https://github.com/ir2718). -**November 2**: v2.7.0 -- Added [ThresholdConsistentMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#thresholdconsistentmarginloss). -- Thank you [ir2718](https://github.com/ir2718). - ## Documentation - [**View the documentation here**](https://kevinmusgrave.github.io/pytorch-metric-learning/) - [**View the installation instructions here**](https://github.com/KevinMusgrave/pytorch-metric-learning#installation) @@ -231,7 +232,7 @@ Thanks to the contributors who made pull requests! |[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss)
- [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
- [HistogramLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#histogramloss)
- [DynamicSoftMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#dynamicsoftmarginloss)
- [RankedListLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#rankedlistloss) | |[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets
- Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons | |[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss)
- [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss)
- Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/)
- BaseLossWrapper| -| [ir2718](https://github.com/ir2718) | - [ThresholdConsistentMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#thresholdconsistentmarginloss)
- The [Datasets](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets) module | +| [ir2718](https://github.com/ir2718) | - [ThresholdConsistentMarginLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#thresholdconsistentmarginloss)
- [SmoothAPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#smoothaploss)
- The [Datasets](https://kevinmusgrave.github.io/pytorch-metric-learning/datasets) module | |[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer)
- [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss)
- [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester)
- [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) | | [chingisooinar](https://github.com/chingisooinar) | [SubCenterArcFaceLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#subcenterarcfaceloss) | | [elias-ramzi](https://github.com/elias-ramzi) | [HierarchicalSampler](https://kevinmusgrave.github.io/pytorch-metric-learning/samplers/#hierarchicalsampler) | @@ -252,6 +253,8 @@ Thanks to the contributors who made pull requests! | [stompsjo](https://github.com/stompsjo) | Improved documentation for NTXentLoss. | | [Puzer](https://github.com/Puzer) | Bug fix for PNPLoss. | | [elisim](https://github.com/elisim) | Developer improvements to DistributedLossWrapper. | +| [lucamarini22](https://github.com/lucamarini22) | | +| [marcpaga](https://github.com/marcpaga) | | | [GaetanLepage](https://github.com/GaetanLepage) | | | [z1w](https://github.com/z1w) | | | [thinline72](https://github.com/thinline72) | | diff --git a/docs/imgs/smooth_ap_approx_equation.png b/docs/imgs/smooth_ap_approx_equation.png new file mode 100644 index 00000000..0b37c27a Binary files /dev/null and b/docs/imgs/smooth_ap_approx_equation.png differ diff --git a/docs/imgs/smooth_ap_loss_equation.png b/docs/imgs/smooth_ap_loss_equation.png new file mode 100644 index 00000000..25f061c8 Binary files /dev/null and b/docs/imgs/smooth_ap_loss_equation.png differ diff --git a/docs/imgs/smooth_ap_sigmoid_equation.png b/docs/imgs/smooth_ap_sigmoid_equation.png new file mode 100644 index 00000000..7153f2a3 Binary files /dev/null and b/docs/imgs/smooth_ap_sigmoid_equation.png differ diff --git a/docs/losses.md b/docs/losses.md index b8ced202..7ee48da4 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -1087,6 +1087,37 @@ losses.SignalToNoiseRatioContrastiveLoss(pos_margin=0, neg_margin=1, **kwargs): * **pos_loss**: The loss per positive pair in the batch. 
Reduction type is ```"pos_pair"```. * **neg_loss**: The loss per negative pair in the batch. Reduction type is ```"neg_pair"```. +## SmoothAPLoss +[Smooth-AP: Smoothing the Path Towards Large-Scale Image Retrieval](https://arxiv.org/abs/2007.12163){target=_blank} + +```python +losses.SmoothAPLoss( + temperature=0.01, + **kwargs +) +``` + +**Equations**: + +![smooth_ap_loss_equation1](imgs/smooth_ap_sigmoid_equation.png){: style="height:100px"} +![smooth_ap_loss_equation2](imgs/smooth_ap_approx_equation.png){: style="height:100px"} +![smooth_ap_loss_equation3](imgs/smooth_ap_loss_equation.png){: style="height:100px"} + + +**Parameters**: + +* **temperature**: The desired temperature for scaling the sigmoid function. This is denoted by $\tau$ in the first and second equations. + + +**Other info**: + +* The loss requires the same number of elements for each class in the batch labels. An example of valid labels is: `[1, 1, 2, 2, 3, 3]`. An example of invalid labels is `[1, 1, 1, 2, 2, 3, 3]` because the value `1` appears `3` times while the other values appear twice. This can be achieved by using [`samplers.MPerClassSampler`](samplers.md/#mperclasssampler) and setting the `batch_size` and `m` hyperparameters. + +**Default distance**: + + - [```CosineSimilarity()```](distances.md#cosinesimilarity) + - This is the only compatible distance. 
+ ## SoftTripleLoss [SoftTriple Loss: Deep Metric Learning Without Triplet Sampling](http://openaccess.thecvf.com/content_ICCV_2019/papers/Qian_SoftTriple_Loss_Deep_Metric_Learning_Without_Triplet_Sampling_ICCV_2019_paper.pdf){target=_blank} ```python diff --git a/src/pytorch_metric_learning/__init__.py b/src/pytorch_metric_learning/__init__.py index b4066b65..43ce13db 100644 --- a/src/pytorch_metric_learning/__init__.py +++ b/src/pytorch_metric_learning/__init__.py @@ -1 +1 @@ -__version__ = "2.8.1" +__version__ = "2.9.0" diff --git a/src/pytorch_metric_learning/datasets/__init__.py b/src/pytorch_metric_learning/datasets/__init__.py index 67050abe..ee5db708 100644 --- a/src/pytorch_metric_learning/datasets/__init__.py +++ b/src/pytorch_metric_learning/datasets/__init__.py @@ -2,4 +2,4 @@ from .cars196 import Cars196 from .cub import CUB from .inaturalist2018 import INaturalist2018 -from .sop import StanfordOnlineProducts \ No newline at end of file +from .sop import StanfordOnlineProducts diff --git a/src/pytorch_metric_learning/distances/dot_product_similarity.py b/src/pytorch_metric_learning/distances/dot_product_similarity.py index 2e0b4b01..74be22f5 100644 --- a/src/pytorch_metric_learning/distances/dot_product_similarity.py +++ b/src/pytorch_metric_learning/distances/dot_product_similarity.py @@ -9,7 +9,7 @@ def __init__(self, **kwargs): assert self.is_inverted def compute_mat(self, query_emb, ref_emb): - return torch.matmul(query_emb, ref_emb.t()) + return torch.matmul(query_emb, ref_emb.transpose(-1, -2)) def pairwise_distance(self, query_emb, ref_emb): return torch.sum(query_emb * ref_emb, dim=1) diff --git a/src/pytorch_metric_learning/losses/__init__.py b/src/pytorch_metric_learning/losses/__init__.py index 6bd679a7..096d98a6 100644 --- a/src/pytorch_metric_learning/losses/__init__.py +++ b/src/pytorch_metric_learning/losses/__init__.py @@ -30,6 +30,7 @@ from .ranked_list_loss import RankedListLoss from .self_supervised_loss import SelfSupervisedLoss 
from .signal_to_noise_ratio_losses import SignalToNoiseRatioContrastiveLoss +from .smooth_ap import SmoothAPLoss from .soft_triple_loss import SoftTripleLoss from .sphereface_loss import SphereFaceLoss from .subcenter_arcface_loss import SubCenterArcFaceLoss diff --git a/src/pytorch_metric_learning/losses/generic_pair_loss.py b/src/pytorch_metric_learning/losses/generic_pair_loss.py index 6996490f..eaac7507 100644 --- a/src/pytorch_metric_learning/losses/generic_pair_loss.py +++ b/src/pytorch_metric_learning/losses/generic_pair_loss.py @@ -28,6 +28,7 @@ def mat_based_loss(self, mat, indices_tuple): pos_mask, neg_mask = torch.zeros_like(mat), torch.zeros_like(mat) pos_mask[a1, p] = 1 neg_mask[a2, n] = 1 + self._assert_either_pos_or_neg(pos_mask, neg_mask) return self._compute_loss(mat, pos_mask, neg_mask) def pair_based_loss(self, mat, indices_tuple): @@ -38,3 +39,9 @@ def pair_based_loss(self, mat, indices_tuple): if len(a2) > 0: neg_pair = mat[a2, n] return self._compute_loss(pos_pair, neg_pair, indices_tuple) + + @staticmethod + def _assert_either_pos_or_neg(pos_mask, neg_mask): + assert not torch.any( + (pos_mask != 0) & (neg_mask != 0) + ), "Each pair should be either positive or negative" diff --git a/src/pytorch_metric_learning/losses/smooth_ap.py b/src/pytorch_metric_learning/losses/smooth_ap.py new file mode 100644 index 00000000..b0e441f3 --- /dev/null +++ b/src/pytorch_metric_learning/losses/smooth_ap.py @@ -0,0 +1,103 @@ +import torch +import torch.nn.functional as F + +from ..distances import CosineSimilarity +from ..utils import common_functions as c_f +from ..utils import loss_and_miner_utils as lmu +from .base_metric_loss_function import BaseMetricLossFunction + + +class SmoothAPLoss(BaseMetricLossFunction): + """ + Implementation of the SmoothAP loss: https://arxiv.org/abs/2007.12163 + """ + + def __init__(self, temperature=0.01, **kwargs): + super().__init__(**kwargs) + c_f.assert_distance_type(self, CosineSimilarity) + self.temperature = 
temperature + + def get_default_distance(self): + return CosineSimilarity() + + # Implementation is based on the original repository: + # https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py#L87 + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): + # The loss expects labels such that there is the same number of elements for each class + # The number of classes is not important, nor their order, but the number of elements must be the same, eg. + # + # The following label is valid: + # [ A,A,A, B,B,B, C,C,C ] + # The following label is NOT valid: + # [ B,B,B A,A,A,A, C,C,C ] + # + c_f.labels_required(labels) + c_f.ref_not_supported(embeddings, labels, ref_emb, ref_labels) + + counts = torch.bincount(labels) + nonzero_indices = torch.nonzero(counts, as_tuple=True)[0] + nonzero_counts = counts[nonzero_indices] + if nonzero_counts.unique().size(0) != 1: + raise ValueError( + "All classes must have the same number of elements in the labels.\n" + "The given labels have the following number of elements: {}.\n" + "You can achieve this using the samplers.MPerClassSampler class and setting the batch_size and m.".format( + nonzero_counts.cpu().tolist() + ) + ) + + batch_size = embeddings.size(0) + num_classes_batch = batch_size // torch.unique(labels).size(0) + + mask = 1.0 - torch.eye(batch_size) + mask = mask.unsqueeze(dim=0).repeat(batch_size, 1, 1) + + sims = self.distance(embeddings) + + sims_repeat = sims.unsqueeze(dim=1).repeat(1, batch_size, 1) + sims_diff = sims_repeat - sims_repeat.permute(0, 2, 1) + sims_sigm = F.sigmoid(sims_diff / self.temperature) * mask.to(sims_diff.device) + sims_ranks = torch.sum(sims_sigm, dim=-1) + 1 + + xs = embeddings.view( + num_classes_batch, batch_size // num_classes_batch, embeddings.size(-1) + ) + pos_mask = 1.0 - torch.eye(batch_size // num_classes_batch) + pos_mask = ( + pos_mask.unsqueeze(dim=0) + .unsqueeze(dim=0) + .repeat(num_classes_batch, batch_size // num_classes_batch, 1, 
1) + ) + + # Circumvent the shape check in forward method + xs_norm = self.distance.maybe_normalize(xs, dim=-1) + sims_pos = self.distance.compute_mat(xs_norm, xs_norm) + + sims_pos_repeat = sims_pos.unsqueeze(dim=2).repeat( + 1, 1, batch_size // num_classes_batch, 1 + ) + sims_pos_diff = sims_pos_repeat - sims_pos_repeat.permute(0, 1, 3, 2) + + sims_pos_sigm = F.sigmoid(sims_pos_diff / self.temperature) * pos_mask.to( + sims_diff.device + ) + sims_pos_ranks = torch.sum(sims_pos_sigm, dim=-1) + 1 + + g = batch_size // num_classes_batch + ap = torch.zeros(batch_size).to(embeddings.device) + for i in range(num_classes_batch): + for j in range(g): + pos_rank = sims_pos_ranks[i, j] + all_rank = sims_ranks[i * g + j, i * g : (i + 1) * g] + ap[i * g + j] = torch.sum(pos_rank / all_rank) / g + + miner_weights = lmu.convert_to_weights(indices_tuple, labels, dtype=ap.dtype) + loss = (1 - ap) * miner_weights + + return { + "ap_loss": { + "losses": loss, + "indices": c_f.torch_arange_from_size(loss), + "reduction_type": "element", + } + } diff --git a/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py b/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py index fc1f2210..56af635c 100644 --- a/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py +++ b/src/pytorch_metric_learning/losses/subcenter_arcface_loss.py @@ -1,4 +1,5 @@ import math +from copy import deepcopy import numpy as np import torch @@ -13,9 +14,16 @@ class SubCenterArcFaceLoss(ArcFaceLoss): """ def __init__(self, *args, margin=28.6, scale=64, sub_centers=3, **kwargs): - num_classes, embedding_size = kwargs["num_classes"], kwargs["embedding_size"] + num_classes = deepcopy(kwargs["num_classes"]) + embedding_size = deepcopy(kwargs["embedding_size"]) + del kwargs["num_classes"] + del kwargs["embedding_size"] super().__init__( - num_classes * sub_centers, embedding_size, margin=margin, scale=scale + num_classes=num_classes * sub_centers, + embedding_size=embedding_size, + margin=margin, + 
scale=scale, + **kwargs ) self.sub_centers = sub_centers self.num_classes = num_classes diff --git a/tests/datasets/test_cars196.py b/tests/datasets/test_cars196.py index bd5629f8..96829fe4 100644 --- a/tests/datasets/test_cars196.py +++ b/tests/datasets/test_cars196.py @@ -6,6 +6,7 @@ from torch.utils.data import DataLoader from pytorch_metric_learning.datasets import Cars196 + from .. import TEST_DATASETS diff --git a/tests/datasets/test_cub.py b/tests/datasets/test_cub.py index 60261965..763a4d12 100644 --- a/tests/datasets/test_cub.py +++ b/tests/datasets/test_cub.py @@ -6,6 +6,7 @@ from torch.utils.data import DataLoader from pytorch_metric_learning.datasets import CUB + from .. import TEST_DATASETS diff --git a/tests/datasets/test_inaturalist2018.py b/tests/datasets/test_inaturalist2018.py index ad26cb7a..6ceb0d23 100644 --- a/tests/datasets/test_inaturalist2018.py +++ b/tests/datasets/test_inaturalist2018.py @@ -6,6 +6,7 @@ from torch.utils.data import DataLoader from pytorch_metric_learning.datasets import INaturalist2018 + from .. import TEST_DATASETS diff --git a/tests/datasets/test_sop.py b/tests/datasets/test_sop.py index 55398716..aa5cd2ee 100644 --- a/tests/datasets/test_sop.py +++ b/tests/datasets/test_sop.py @@ -6,6 +6,7 @@ from torch.utils.data import DataLoader from pytorch_metric_learning.datasets import StanfordOnlineProducts + from .. 
import TEST_DATASETS diff --git a/tests/losses/test_cross_batch_memory.py b/tests/losses/test_cross_batch_memory.py index 5c3d47af..f1c2d8c3 100644 --- a/tests/losses/test_cross_batch_memory.py +++ b/tests/losses/test_cross_batch_memory.py @@ -238,7 +238,6 @@ def test_loss(self): batch_size = 32 for inner_loss in [ContrastiveLoss(), MultiSimilarityLoss()]: inner_miner = MultiSimilarityMiner(0.3) - outer_miner = MultiSimilarityMiner(0.2) self.loss = CrossBatchMemory( loss=inner_loss, embedding_size=self.embedding_size, @@ -267,10 +266,6 @@ def test_loss(self): labels = torch.randint(0, num_labels, (batch_size,)).to(TEST_DEVICE) loss = self.loss(embeddings, labels) loss_with_miner = self.loss_with_miner(embeddings, labels) - oa1, op, oa2, on = outer_miner(embeddings, labels) - loss_with_miner_and_input_indices = self.loss_with_miner2( - embeddings, labels, (oa1, op, oa2, on) - ) all_embeddings = torch.cat([all_embeddings, embeddings]) all_labels = torch.cat([all_labels, labels]) @@ -308,33 +303,6 @@ def test_loss(self): torch.isclose(loss_with_miner, correct_loss_with_miner) ) - # loss with inner and outer miner - indices_tuple = inner_miner( - embeddings, labels, all_embeddings, all_labels - ) - a1, p, a2, n = lmu.remove_self_comparisons( - indices_tuple, - self.loss_with_miner2.curr_batch_idx, - self.loss_with_miner2.memory_size, - ) - a1 = torch.cat([oa1, a1]) - p = torch.cat([op, p]) - a2 = torch.cat([oa2, a2]) - n = torch.cat([on, n]) - correct_loss_with_miner_and_input_indice = inner_loss( - embeddings, - labels, - (a1, p, a2, n), - all_embeddings, - all_labels, - ) - self.assertTrue( - torch.isclose( - loss_with_miner_and_input_indices, - correct_loss_with_miner_and_input_indice, - ) - ) - def test_queue(self): for test_enqueue_mask in [False, True]: for dtype in TEST_DTYPES: diff --git a/tests/losses/test_smooth_ap_loss.py b/tests/losses/test_smooth_ap_loss.py new file mode 100644 index 00000000..422e9423 --- /dev/null +++ 
b/tests/losses/test_smooth_ap_loss.py @@ -0,0 +1,191 @@ +import unittest + +import torch +import torch.nn.functional as F + +from pytorch_metric_learning.losses import SmoothAPLoss + +from .. import TEST_DEVICE, TEST_DTYPES + +HYPERPARAMETERS = { + "temp": 0.01, + "batch_size": 60, + "num_id": 6, + "feat_dims": 256, +} +TEST_SEEDS = [42, 1234, 5642, 9999, 3459] + + +# Original implementation of the SmoothAP loss taken from: +# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py +def sigmoid(tensor, temp=1.0): + """temperature controlled sigmoid + + takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp + """ + exponent = -tensor / temp + # clamp the input tensor for stability + exponent = torch.clamp(exponent, min=-50, max=50) + y = 1.0 / (1.0 + torch.exp(exponent)) + return y + + +def compute_aff(x): + """computes the affinity matrix between an input vector and itself""" + return torch.mm(x, x.t()) + + +class SmoothAP(torch.nn.Module): + """PyTorch implementation of the Smooth-AP loss. + + implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns + the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must + have the same number of instances represented in the mini-batch and must be ordered sequentially by class. + + e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like: + + labels = ( A, A, A, B, B, B, C, C, C) + + (the order of the classes however does not matter) + + For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the + mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the + same class. The loss returns the average Smooth-AP across all instances in the mini-batch. 
+ + Args: + anneal : float + the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature + results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function. + batch_size : int + the batch size being used during training. + num_id : int + the number of different classes that are represented in the batch. + feat_dims : int + the dimension of the input feature embeddings + + Shape: + - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor) + - Output: scalar + + Examples:: + + >>> loss = SmoothAP(0.01, 60, 6, 256) + >>> input = torch.randn(60, 256, requires_grad=True).to("cuda:0") + >>> output = loss(input) + >>> output.backward() + """ + + def __init__(self, anneal, batch_size, num_id, feat_dims): + """ + Parameters + ---------- + anneal : float + the temperature of the sigmoid that is used to smooth the ranking function + batch_size : int + the batch size being used + num_id : int + the number of different classes that are represented in the batch + feat_dims : int + the dimension of the input feature embeddings + """ + super(SmoothAP, self).__init__() + + assert batch_size % num_id == 0 + + self.anneal = anneal + self.batch_size = batch_size + self.num_id = num_id + self.feat_dims = feat_dims + + def forward(self, preds): + """Forward pass for all input predictions: preds - (batch_size x feat_dims)""" + + # ------ differentiable ranking of all retrieval set ------ + # compute the mask which ignores the relevance score of the query to itself + mask = 1.0 - torch.eye(self.batch_size) + mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1) + # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors + sim_all = compute_aff(preds) + sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1) + # compute the difference matrix + sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1) + # pass through the sigmoid 
+ sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask.to(TEST_DEVICE) + # compute the rankings + sim_all_rk = torch.sum(sim_sg, dim=-1) + 1 + + # ------ differentiable ranking of only positive set in retrieval set ------ + # compute the mask which only gives non-zero weights to the positive set + xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims) + pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id)) + pos_mask = ( + pos_mask.unsqueeze(dim=0) + .unsqueeze(dim=0) + .repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1) + ) + + # compute the relevance scores + sim_pos = torch.bmm(xs, xs.permute(0, 2, 1)) + sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat( + 1, 1, int(self.batch_size / self.num_id), 1 + ) + # compute the difference matrix + sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2) + # pass through the sigmoid + sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask.to(TEST_DEVICE) + # compute the rankings of the positive set + sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1 + + # sum the values of the Smooth-AP for all instances in the mini-batch + ap = torch.zeros(1).to(TEST_DEVICE) + group = int(self.batch_size / self.num_id) + for ind in range(self.num_id): + pos_divide = torch.sum( + sim_pos_rk[ind] + / ( + sim_all_rk[ + (ind * group) : ((ind + 1) * group), + (ind * group) : ((ind + 1) * group), + ] + ) + ) + ap = ap + ((pos_divide / group) / self.batch_size) + + return 1 - ap + + +class TestSmoothAPLoss(unittest.TestCase): + def test_smooth_ap_loss(self): + for dtype in TEST_DTYPES: + for seed in TEST_SEEDS: + torch.manual_seed(seed) + loss = SmoothAP( + HYPERPARAMETERS["temp"], + HYPERPARAMETERS["batch_size"], + HYPERPARAMETERS["num_id"], + HYPERPARAMETERS["feat_dims"], + ) + rand_tensor = ( + torch.randn( + HYPERPARAMETERS["batch_size"], + HYPERPARAMETERS["feat_dims"], + requires_grad=True, + ) + .to(TEST_DEVICE) + .to(dtype) + ) + # The original code uses a model that normalizes 
the output vector + input_ = F.normalize(rand_tensor, p=2.0, dim=-1) + output = loss(input_) + + loss2 = SmoothAPLoss(temperature=HYPERPARAMETERS["temp"]) + # The original code assumes the label is in this format + labels = [] + for i in range( + HYPERPARAMETERS["batch_size"] // HYPERPARAMETERS["num_id"] + ): + labels.extend([i for _ in range(HYPERPARAMETERS["num_id"])]) + + labels = torch.tensor(labels) + output2 = loss2.forward(rand_tensor, labels) + self.assertTrue(torch.isclose(output, output2)) diff --git a/tests/losses/test_subcenter_arcface_loss.py b/tests/losses/test_subcenter_arcface_loss.py index a4d324f4..d9501155 100644 --- a/tests/losses/test_subcenter_arcface_loss.py +++ b/tests/losses/test_subcenter_arcface_loss.py @@ -6,6 +6,7 @@ import torch.nn.functional as F from pytorch_metric_learning.losses import ArcFaceLoss, SubCenterArcFaceLoss +from pytorch_metric_learning.reducers import DoNothingReducer from .. import TEST_DEVICE, TEST_DTYPES @@ -142,3 +143,18 @@ def test_inference_subcenter_arcface(self): ) == 0 ) + + def test_reducer_subcenter_arcface(self): + + arcfaceloss = SubCenterArcFaceLoss( + num_classes=10, + sub_centers=3, + embedding_size=64, + reducer=DoNothingReducer(), + ) + + emb = torch.randn(4, 64) + result = arcfaceloss(emb, torch.arange(4)) + + self.assertTrue(isinstance(result, dict)) + self.assertTrue(result["loss"]["losses"].shape[0] == 4)
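
Reviewer note on the SmoothAPLoss changes in this diff: the two ideas the new docs and `smooth_ap.py` rely on are (a) every class in the batch must have the same number of elements, and (b) ranks are approximated by summing temperature-scaled sigmoids of similarity differences. A minimal plain-Python sketch of both follows. The helper names `class_counts_are_balanced` and `smooth_ranks` are hypothetical and not part of pytorch-metric-learning; the clamp to `[-50, 50]` mirrors the reference `sigmoid` helper in `tests/losses/test_smooth_ap_loss.py`, and the loop is an illustration of the idea rather than the library's tensor code.

```python
import math
from collections import Counter


def class_counts_are_balanced(labels):
    """Return True when every class present in `labels` occurs equally often.

    Hypothetical helper for illustration only; not a library function.
    """
    counts = Counter(labels)
    # Exactly one distinct count means all classes are equally represented.
    return len(set(counts.values())) == 1


def smooth_ranks(sims, temperature=0.01):
    """Approximate the rank of each similarity score within `sims`.

    rank_i = 1 + sum over j != i of sigmoid((s_j - s_i) / temperature).
    A small temperature makes the sigmoid close to a step function,
    so the result approaches the true rank (1 = most similar).
    """
    ranks = []
    for i, s_i in enumerate(sims):
        rank = 1.0
        for j, s_j in enumerate(sims):
            if i == j:
                continue  # an item is never compared with itself
            # Clamp the scaled difference for numerical stability.
            z = max(-50.0, min(50.0, (s_j - s_i) / temperature))
            rank += 1.0 / (1.0 + math.exp(-z))
        ranks.append(rank)
    return ranks
```

With well-separated scores and the default temperature, `smooth_ranks` returns values very close to the integer ranks; raising the temperature softens them, which is the property that makes the quantity differentiable in the tensor version.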