From 899b7a7caa5fe050866f9e2f41e8691a09ac29d8 Mon Sep 17 00:00:00 2001 From: Domenico Muscillo Date: Thu, 6 Jul 2023 18:52:29 +0200 Subject: [PATCH 1/2] First simplification: reducers --- .gitignore | 2 + .../losses/margin_loss.py | 2 +- .../reducers/avg_non_zero_reducer.py | 1 + .../reducers/base_reducer.py | 4 +- .../reducers/class_weighted_reducer.py | 26 ++--- .../reducers/divisor_reducer.py | 27 ++--- .../reducers/do_nothing_reducer.py | 2 +- .../reducers/mean_reducer.py | 24 ++-- .../reducers/multiple_reducers.py | 26 +++-- .../reducers/per_anchor_reducer.py | 28 +++-- .../reducers/sum_reducer.py | 13 ++- .../reducers/threshold_reducer.py | 38 +++---- .../utils/logging_presets.py | 2 +- tests/reducers/test_class_weighted_reducer.py | 104 ++++++++++-------- tests/reducers/test_divisor_reducer.py | 3 +- tests/reducers/test_multiple_reducers.py | 3 +- tests/reducers/test_sum_reducer.py | 59 ++++++++++ 17 files changed, 213 insertions(+), 151 deletions(-) create mode 100644 tests/reducers/test_sum_reducer.py diff --git a/.gitignore b/.gitignore index ef79da47..16287b9a 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ dist/ site/ venv/ .ipynb_checkpoints +**/.vscode +**/temp*_for_pytorch_metric_learning_test examples/notebooks/dataset examples/notebooks/CIFAR10_Dataset examples/notebooks/CIFAR100_Dataset diff --git a/src/pytorch_metric_learning/losses/margin_loss.py b/src/pytorch_metric_learning/losses/margin_loss.py index af834846..7109b4e5 100644 --- a/src/pytorch_metric_learning/losses/margin_loss.py +++ b/src/pytorch_metric_learning/losses/margin_loss.py @@ -36,7 +36,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): if len(anchor_idx) == 0: return self.zero_losses() - beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx]] + beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx].to("cpu")] # When labels are on gpu gives error beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype) mat = self.distance(embeddings, ref_emb) diff --git a/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py b/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py index 916daa3b..f42ab20b 100644 --- a/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py +++ b/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py @@ -2,5 +2,6 @@ class AvgNonZeroReducer(ThresholdReducer): + """Equivalent to ThresholdReducer with `low=0`""" def __init__(self, **kwargs): super().__init__(low=0, **kwargs) diff --git a/src/pytorch_metric_learning/reducers/base_reducer.py b/src/pytorch_metric_learning/reducers/base_reducer.py index 91d2da25..4213eaab 100644 --- a/src/pytorch_metric_learning/reducers/base_reducer.py +++ b/src/pytorch_metric_learning/reducers/base_reducer.py @@ -15,7 +15,7 @@ def forward(self, loss_dict, embeddings, labels): loss_name = list(loss_dict.keys())[0] loss_info = loss_dict[loss_name] losses, loss_indices, reduction_type, kwargs = self.unpack_loss_info(loss_info) - loss_val = self.reduce_the_loss( + loss_val = self.reduce_loss( # Similar to compute_loss losses, loss_indices, reduction_type, kwargs, embeddings, labels ) return loss_val @@ -28,7 +28,7 @@ def unpack_loss_info(self, loss_info): {}, ) - def reduce_the_loss( + def reduce_loss( # Similar to compute_loss self, losses, loss_indices, reduction_type, kwargs, embeddings, labels ): self.set_losses_size_stat(losses) diff --git a/src/pytorch_metric_learning/reducers/class_weighted_reducer.py 
b/src/pytorch_metric_learning/reducers/class_weighted_reducer.py index ffa4c9d4..39fa14bf 100644 --- a/src/pytorch_metric_learning/reducers/class_weighted_reducer.py +++ b/src/pytorch_metric_learning/reducers/class_weighted_reducer.py @@ -1,28 +1,16 @@ -import torch - from ..utils import common_functions as c_f -from .base_reducer import BaseReducer +from .threshold_reducer import ThresholdReducer -class ClassWeightedReducer(BaseReducer): +class ClassWeightedReducer(ThresholdReducer): + """It weights the losses with user-specified weights and then takes the average. + + Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters.""" def __init__(self, weights, **kwargs): super().__init__(**kwargs) self.weights = weights def element_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, loss_indices, labels) - - def pos_pair_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, loss_indices[0], labels) - - # based on anchor label - def neg_pair_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, loss_indices[0], labels) - - # based on anchor label - def triplet_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, loss_indices[0], labels) - - def element_reduction_helper(self, losses, indices, labels): self.weights = c_f.to_device(self.weights, losses, dtype=losses.dtype) - return torch.mean(losses * self.weights[labels[indices]]) + losses = losses * self.weights[labels[loss_indices]] + return super().element_reduction(losses, loss_indices, embeddings, labels) diff --git a/src/pytorch_metric_learning/reducers/divisor_reducer.py b/src/pytorch_metric_learning/reducers/divisor_reducer.py index 784edd07..f0b0d099 100644 --- a/src/pytorch_metric_learning/reducers/divisor_reducer.py +++ b/src/pytorch_metric_learning/reducers/divisor_reducer.py @@ -7,28 +7,23 @@ class DivisorReducer(BaseReducer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.divisor = 1 self.add_to_recordable_attributes(name="divisor", is_stat=True) def unpack_loss_info(self, loss_info): - losses, loss_indices, reduction_type, kwargs = super().unpack_loss_info( - loss_info - ) - if reduction_type != "already_reduced": - kwargs = {"divisor": loss_info["divisor"]} - self.divisor = kwargs["divisor"] - return losses, loss_indices, reduction_type, kwargs - - def sum_and_divide(self, losses, embeddings, divisor): - if divisor != 0: - output = torch.sum(losses) / divisor - if torch.isnan(output) and losses.dtype == torch.float16: - output = torch.sum(c_f.to_dtype(losses, dtype=torch.float32)) / divisor - output = c_f.to_dtype(output, dtype=torch.float16) + if loss_info["reduction_type"] != "already_reduced": + self.divisor = loss_info["divisor"] + return super().unpack_loss_info(loss_info) + + def sum_and_divide(self, losses, embeddings): + if self.divisor != 0: + output = torch.sum(losses.float()) / self.divisor + output = c_f.to_dtype(output, tensor=losses) return output return self.zero_loss(embeddings) - def element_reduction(self, losses, loss_indices, embeddings, labels, divisor=1): - return self.sum_and_divide(losses, embeddings, divisor) + def element_reduction(self, losses, loss_indices, embeddings, labels): + return self.sum_and_divide(losses, embeddings) def pos_pair_reduction(self, *args, **kwargs): return self.element_reduction(*args, **kwargs) diff 
--git a/src/pytorch_metric_learning/reducers/do_nothing_reducer.py b/src/pytorch_metric_learning/reducers/do_nothing_reducer.py index 412f50dc..de8ce795 100644 --- a/src/pytorch_metric_learning/reducers/do_nothing_reducer.py +++ b/src/pytorch_metric_learning/reducers/do_nothing_reducer.py @@ -2,5 +2,5 @@ class DoNothingReducer(BaseReducer): - def forward(self, loss_dict, embeddings, labels): + def forward(self, loss_dict, *_): return loss_dict diff --git a/src/pytorch_metric_learning/reducers/mean_reducer.py b/src/pytorch_metric_learning/reducers/mean_reducer.py index e1d1dec1..7aaae32d 100644 --- a/src/pytorch_metric_learning/reducers/mean_reducer.py +++ b/src/pytorch_metric_learning/reducers/mean_reducer.py @@ -1,17 +1,13 @@ -import torch +from .threshold_reducer import ThresholdReducer +import numpy as np -from .base_reducer import BaseReducer +class MeanReducer(ThresholdReducer): + """Equivalent to ThresholdReducer with default parameters. + + Any element is accepted""" + def __init__(self, **kwargs): + kwargs["low"] = -np.inf + kwargs["high"] = np.inf + super().__init__(**kwargs) -class MeanReducer(BaseReducer): - def element_reduction(self, losses, *_): - return torch.mean(losses) - - def pos_pair_reduction(self, losses, *args): - return self.element_reduction(losses, *args) - - def neg_pair_reduction(self, losses, *args): - return self.element_reduction(losses, *args) - - def triplet_reduction(self, losses, *args): - return self.element_reduction(losses, *args) diff --git a/src/pytorch_metric_learning/reducers/multiple_reducers.py b/src/pytorch_metric_learning/reducers/multiple_reducers.py index b2d986c4..1f1ffbcf 100644 --- a/src/pytorch_metric_learning/reducers/multiple_reducers.py +++ b/src/pytorch_metric_learning/reducers/multiple_reducers.py @@ -2,30 +2,32 @@ from .base_reducer import BaseReducer from .mean_reducer import MeanReducer +from collections import defaultdict + +class DefaultModuleDict(torch.nn.ModuleDict): + def __init__(self, module_factory, modules): + torch.nn.ModuleDict.__init__(self, modules) + self._modules = defaultdict(module_factory, self._modules) class MultipleReducers(BaseReducer): - def __init__(self, reducers, default_reducer=None, **kwargs): + def __init__(self, reducers, default_reducer: BaseReducer = None, **kwargs): super().__init__(**kwargs) - self.reducers = torch.nn.ModuleDict(reducers) - self.default_reducer = ( - MeanReducer() if default_reducer is None else default_reducer - ) + reducer_type = MeanReducer if default_reducer is None else type(default_reducer) + self.reducers = DefaultModuleDict(module_factory=lambda : reducer_type(), + modules=reducers) def forward(self, loss_dict, embeddings, labels): self.reset_stats() sub_losses = torch.zeros( len(loss_dict), dtype=embeddings.dtype, device=embeddings.device ) - loss_count = 0 - for loss_name, loss_info in loss_dict.items(): + + for loss_count, (loss_name, loss_info) in enumerate(loss_dict.items()): input_dict = {loss_name: loss_info} - if loss_name in self.reducers: - loss_val = self.reducers[loss_name](input_dict, embeddings, labels) - else: - loss_val = self.default_reducer(input_dict, embeddings, labels) + loss_val = self.reducers[loss_name](input_dict, embeddings, labels) sub_losses[loss_count] = loss_val - loss_count += 1 + return self.sub_loss_reduction(sub_losses, embeddings, labels) def sub_loss_reduction(self, sub_losses, embeddings=None, labels=None): diff --git a/src/pytorch_metric_learning/reducers/per_anchor_reducer.py 
b/src/pytorch_metric_learning/reducers/per_anchor_reducer.py index 504350f3..18e2e115 100644 --- a/src/pytorch_metric_learning/reducers/per_anchor_reducer.py +++ b/src/pytorch_metric_learning/reducers/per_anchor_reducer.py @@ -30,19 +30,23 @@ def element_reduction(self, losses, loss_indices, embeddings, labels): def tuple_reduction_helper(self, losses, loss_indices, embeddings, labels): batch_size = embeddings.shape[0] - device, dtype = losses.device, losses.dtype - new_array = torch.zeros(batch_size, batch_size, device=device, dtype=dtype) - pos_inf = c_f.pos_inf(dtype) - new_array += pos_inf - anchors, others = loss_indices - new_array[anchors, others] = losses - pos_inf_mask = new_array == pos_inf - num_inf = torch.sum(pos_inf_mask, dim=1) - new_array[pos_inf_mask] = 0 - num_per_row = batch_size - num_inf - output = self.aggregation_func(new_array, num_per_row) + # Prepare tensors for results + anchors = c_f.to_device(anchors, tensor=losses) + others = c_f.to_device(others, tensor=losses) + output = c_f.to_device(torch.zeros(batch_size, batch_size), tensor=losses, dtype=losses.dtype) + num_per_row = c_f.to_device(torch.zeros(batch_size), tensor=losses, dtype=torch.long) # Remember to fuse these into a single call to to_device once to_device accepts list inputs + + # Insert loss values at the corresponding anchor-embedding positions + output[anchors, others] = losses + + # Calculate the count of 'others' for each unique anchor + # Equivalent to: 'num_per_row[anchors[i]] += 1' for every i + num_per_row = num_per_row.scatter_add_(0, anchors, torch.ones_like(anchors, device=anchors.device)) + + # Aggregate results + output = self.aggregation_func(output, num_per_row) loss_dict = { "loss": { @@ -59,5 +63,5 @@ def pos_pair_reduction(self, *args, **kwargs): def neg_pair_reduction(self, *args, **kwargs): return self.tuple_reduction_helper(*args, **kwargs) - def triplet_reduction(self, *args, **kwargs): + def triplet_reduction(self, *_): # Explicitly indicate that the arguments are ignored raise NotImplementedError("Triplet reduction not supported") diff --git a/src/pytorch_metric_learning/reducers/sum_reducer.py b/src/pytorch_metric_learning/reducers/sum_reducer.py index dec7e2a6..a9d0b9c2 100644 --- a/src/pytorch_metric_learning/reducers/sum_reducer.py +++ b/src/pytorch_metric_learning/reducers/sum_reducer.py @@ -1,8 +1,15 @@ import torch -from pytorch_metric_learning.reducers import MeanReducer +from .threshold_reducer import ThresholdReducer -class SumReducer(MeanReducer): +class SumReducer(ThresholdReducer): + """It reduces the losses by summing up all the values. 
+ + Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters.""" + def __init__(self, **kwargs): + kwargs["collect_stats"] = True + super().__init__(**kwargs) + def element_reduction(self, losses, *_): - return torch.sum(losses) + return super().element_reduction(losses, *_) * self.num_past_filter diff --git a/src/pytorch_metric_learning/reducers/threshold_reducer.py b/src/pytorch_metric_learning/reducers/threshold_reducer.py index 245f7cd7..a57422b1 100644 --- a/src/pytorch_metric_learning/reducers/threshold_reducer.py +++ b/src/pytorch_metric_learning/reducers/threshold_reducer.py @@ -1,16 +1,13 @@ import torch - +import numpy as np from .base_reducer import BaseReducer class ThresholdReducer(BaseReducer): - def __init__(self, low=None, high=None, **kwargs): + def __init__(self, low=-np.inf, high=np.inf, **kwargs): super().__init__(**kwargs) - assert (low is not None) or ( - high is not None - ), "At least one of low or high must be specified" - self.low = low - self.high = high + self.low = low if low is not None else -np.inf # Since there is no None default value it could be better to exclude testing for low=None in test_threshold_reducer + self.high = high if high is not None else np.inf # Since there is no None default value it could be better to exclude testing for high=None in test_threshold_reducer self.add_to_recordable_attributes(list_of_names=["low", "high"], is_stat=False) self.add_to_recordable_attributes( list_of_names=["num_past_filter", "num_above_low", "num_below_high"], @@ -21,22 +18,19 @@ def element_reduction(self, losses, loss_indices, embeddings, labels): return self.element_reduction_helper(losses, embeddings) def pos_pair_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, embeddings) + return self.element_reduction(losses, loss_indices[0], embeddings, labels) def neg_pair_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, embeddings) + return self.element_reduction(losses, loss_indices[0], embeddings, labels) def triplet_reduction(self, losses, loss_indices, embeddings, labels): - return self.element_reduction_helper(losses, embeddings) + return self.element_reduction(losses, loss_indices[0], embeddings, labels) def element_reduction_helper(self, losses, embeddings): - low_condition, high_condition = None, None - if self.low is not None: - low_condition = losses > self.low - losses = losses[low_condition] - if self.high is not None: - high_condition = losses < self.high - losses = losses[high_condition] + low_condition = losses > self.low + high_condition = losses < self.high + losses = losses[low_condition & high_condition] + num_past_filter = len(losses) if num_past_filter >= 1: loss = torch.mean(losses) @@ -45,11 +39,11 @@ def element_reduction_helper(self, losses, embeddings): self.set_stats(low_condition, high_condition, num_past_filter) return loss + @torch.no_grad() def set_stats(self, low_condition, high_condition, num_past_filter): if self.collect_stats: self.num_past_filter = num_past_filter - with torch.no_grad(): - if self.low is not None: - self.num_above_low = torch.sum(low_condition).item() - if self.high is not None: - self.num_above_high = torch.sum(high_condition).item() + if np.isfinite(self.low): # Why record this only if it was not None? + self.num_above_low = torch.sum(low_condition).item() + if np.isfinite(self.high): # Why record this only if it was not None? 
+ self.num_above_high = torch.sum(high_condition).item() diff --git a/src/pytorch_metric_learning/utils/logging_presets.py b/src/pytorch_metric_learning/utils/logging_presets.py index 4ab92ab1..5d4f2a56 100644 --- a/src/pytorch_metric_learning/utils/logging_presets.py +++ b/src/pytorch_metric_learning/utils/logging_presets.py @@ -379,7 +379,7 @@ def optimizer_custom_attr_func(self, optimizer): class EmptyContainer: - def end_of_epoch_hook(self, *args): + def end_of_epoch_hook(self, *_, **__): # Gives "TypeError: EmptyContainer.end_of_epoch_hook() got an unexpected keyword argument 'test_interval'" return None end_of_iteration_hook = None diff --git a/tests/reducers/test_class_weighted_reducer.py b/tests/reducers/test_class_weighted_reducer.py index 9a6d728a..73c3e91a 100644 --- a/tests/reducers/test_class_weighted_reducer.py +++ b/tests/reducers/test_class_weighted_reducer.py @@ -4,56 +4,68 @@ from pytorch_metric_learning.reducers import ClassWeightedReducer -from .. import TEST_DEVICE, TEST_DTYPES +from .. import TEST_DEVICE, TEST_DTYPES, WITH_COLLECT_STATS class TestClassWeightedReducer(unittest.TestCase): - def test_class_weighted_reducer(self): - torch.manual_seed(99114) + def test_class_weighted_reducer_with_threshold(self): + torch.manual_seed(99115) class_weights = torch.tensor([1, 0.9, 1, 0.1, 0, 0, 0, 0, 0, 0]) - for dtype in TEST_DTYPES: - reducer = ClassWeightedReducer(class_weights) + for low_threshold, high_threshold in [(None, None), (0.1, None), (None, 0.2), (0.1, 0.2)]: + reducer = ClassWeightedReducer(class_weights, low=low_threshold, high=high_threshold) batch_size = 100 - num_classes = 10 embedding_size = 64 - embeddings = ( - torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE) - ) - labels = torch.randint(0, num_classes, (batch_size,)) - pair_indices = ( - torch.randint(0, batch_size, (batch_size,)), - torch.randint(0, batch_size, (batch_size,)), - ) - triplet_indices = pair_indices + ( - torch.randint(0, batch_size, (batch_size,)), - ) - losses = torch.randn(batch_size).type(dtype).to(TEST_DEVICE) + for dtype in TEST_DTYPES: + embeddings = ( + torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE) + ) + labels = torch.randint(0, 10, (batch_size,)) + pair_indices = ( + torch.randint(0, batch_size, (batch_size,)), + torch.randint(0, batch_size, (batch_size,)), + ) + triplet_indices = pair_indices + ( + torch.randint(0, batch_size, (batch_size,)), + ) + losses = torch.randn(batch_size).type(dtype).to(TEST_DEVICE) + zero_losses = torch.zeros(batch_size).type(dtype).to(TEST_DEVICE) - for indices, reduction_type in [ - (torch.arange(batch_size), "element"), - (pair_indices, "pos_pair"), - (pair_indices, "neg_pair"), - (triplet_indices, "triplet"), - ]: - loss_dict = { - "loss": { - "losses": losses, - "indices": indices, - "reduction_type": reduction_type, - } - } - output = reducer(loss_dict, embeddings, labels) - correct_output = 0 - for i in range(len(losses)): - if reduction_type == "element": - batch_idx = indices[i] - else: - batch_idx = indices[0][i] - class_label = labels[batch_idx] - correct_output += ( - losses[i] - * class_weights.type(dtype).to(TEST_DEVICE)[class_label] - ) - correct_output /= len(losses) - rtol = 1e-2 if dtype == torch.float16 else 1e-5 - self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) + for indices, reduction_type in [ + (torch.arange(batch_size), "element"), + (pair_indices, "pos_pair"), + (pair_indices, "neg_pair"), + (triplet_indices, "triplet"), + ]: + for L in [losses, zero_losses]: + 
loss_dict = { + "loss": { + "losses": L, + "indices": indices, + "reduction_type": reduction_type, + } + } + output = reducer(loss_dict, embeddings, labels) + if low_threshold is not None: + L = L[L > low_threshold] + if high_threshold is not None: + L = L[L < high_threshold] + if len(L) > 0: + correct_output = 0 + for i in range(len(L)): + if reduction_type == "element": + batch_idx = indices[i] + else: + batch_idx = indices[0][i] + class_label = labels[batch_idx] + correct_output += ( + L[i] + * class_weights.type(dtype).to(TEST_DEVICE)[class_label] + ) + correct_output /= len(L) + else: + correct_output = 0 + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) + + if WITH_COLLECT_STATS: + self.assertTrue(reducer.num_past_filter == len(L)) \ No newline at end of file diff --git a/tests/reducers/test_divisor_reducer.py b/tests/reducers/test_divisor_reducer.py index 7131562c..35222497 100644 --- a/tests/reducers/test_divisor_reducer.py +++ b/tests/reducers/test_divisor_reducer.py @@ -53,7 +53,8 @@ def test_divisor_reducer(self): correct_output = torch.sum(L) * 0 else: correct_output = torch.sum(L) / (32 + 15) - self.assertTrue(output == correct_output) + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) loss_dict = { "loss": { diff --git a/tests/reducers/test_multiple_reducers.py b/tests/reducers/test_multiple_reducers.py index cb982c6f..bc76a1c6 100644 --- a/tests/reducers/test_multiple_reducers.py +++ b/tests/reducers/test_multiple_reducers.py @@ -56,4 +56,5 @@ def test_multiple_reducers(self): correct_output = (torch.mean(lossesA[lossesA > 0])) + ( torch.sum(lossesB) / (32 + 15) ) - self.assertTrue(output == correct_output) + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) diff --git a/tests/reducers/test_sum_reducer.py b/tests/reducers/test_sum_reducer.py new file mode 100644 index 00000000..159989fa --- /dev/null +++ b/tests/reducers/test_sum_reducer.py @@ -0,0 +1,59 @@ +import unittest + +import torch + +from pytorch_metric_learning.reducers import SumReducer + +from .. 
import TEST_DEVICE, TEST_DTYPES, WITH_COLLECT_STATS + + +class TestSumReducer(unittest.TestCase): + def test_sum_reducer_with_thresholds(self): + torch.manual_seed(99115) + for low_threshold, high_threshold in [(None, None), (0.1, None), (None, 0.2), (0.1, 0.2)]: + reducer = SumReducer(low=low_threshold, high=high_threshold) + batch_size = 100 + embedding_size = 64 + for dtype in TEST_DTYPES: + embeddings = ( + torch.randn(batch_size, embedding_size).type(dtype).to(TEST_DEVICE) + ) + labels = torch.randint(0, 10, (batch_size,)) + pair_indices = ( + torch.randint(0, batch_size, (batch_size,)), + torch.randint(0, batch_size, (batch_size,)), + ) + triplet_indices = pair_indices + ( + torch.randint(0, batch_size, (batch_size,)), + ) + losses = torch.randn(batch_size).type(dtype).to(TEST_DEVICE) + zero_losses = torch.zeros(batch_size).type(dtype).to(TEST_DEVICE) + + for indices, reduction_type in [ + (torch.arange(batch_size), "element"), + (pair_indices, "pos_pair"), + (pair_indices, "neg_pair"), + (triplet_indices, "triplet"), + ]: + for L in [losses, zero_losses]: + loss_dict = { + "loss": { + "losses": L, + "indices": indices, + "reduction_type": reduction_type, + } + } + output = reducer(loss_dict, embeddings, labels) + if low_threshold is not None: + L = L[L > low_threshold] + if high_threshold is not None: + L = L[L < high_threshold] + if len(L) > 0: + correct_output = torch.sum(L, dtype=dtype) + else: + correct_output = torch.zeros(1, dtype=dtype, device=TEST_DEVICE) + rtol = 1e-2 if dtype == torch.float16 else 1e-5 + self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) + + if WITH_COLLECT_STATS: + self.assertTrue(reducer.num_past_filter == len(L)) \ No newline at end of file From 056e396c40e507c62249c77fa0348aaea1fb69b7 Mon Sep 17 00:00:00 2001 From: Domenico Muscillo Date: Sun, 9 Jul 2023 17:45:53 +0200 Subject: [PATCH 2/2] Refactored mixins.py and formatted code --- docs/losses.md | 30 ++--- .../losses/base_metric_loss_function.py | 30 +++-- .../losses/margin_loss.py | 6 +- src/pytorch_metric_learning/losses/mixins.py | 103 +++++++++--------- .../losses/vicreg_loss.py | 2 +- .../reducers/avg_non_zero_reducer.py | 1 + .../reducers/base_reducer.py | 4 +- .../reducers/class_weighted_reducer.py | 6 +- .../reducers/divisor_reducer.py | 2 +- .../reducers/mean_reducer.py | 7 +- .../reducers/multiple_reducers.py | 9 +- .../reducers/per_anchor_reducer.py | 14 ++- .../reducers/sum_reducer.py | 8 +- .../reducers/threshold_reducer.py | 17 ++- .../utils/logging_presets.py | 4 +- .../utils/module_with_records.py | 4 +- tests/losses/test_npairs_loss.py | 4 +- .../test_signal_to_noise_ratio_losses.py | 21 ++-- tests/losses/test_soft_triple_loss.py | 8 +- tests/reducers/test_class_weighted_reducer.py | 21 +++- tests/reducers/test_sum_reducer.py | 17 ++- .../test_center_invariant_regularizer.py | 4 +- .../test_regular_face_regularizer.py | 4 +- 23 files changed, 184 insertions(+), 142 deletions(-) diff --git a/docs/losses.md b/docs/losses.md index 2af1f905..212d9290 100644 --- a/docs/losses.md +++ b/docs/losses.md @@ -111,7 +111,7 @@ losses.ArcFaceLoss(num_classes, embedding_size, margin=28.6, scale=64, **kwargs) **Other info**: -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. 
+* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.ArcFaceLoss(...).to(torch.device('cuda')) @@ -141,8 +141,8 @@ All loss functions extend this class and therefore inherit its ```__init__``` pa losses.BaseMetricLossFunction(collect_stats = False, reducer = None, distance = None, - embedding_regularizer = None, - embedding_reg_weight = 1) + regularizer = None, + reg_weight = 1) ``` **Parameters**: @@ -150,8 +150,8 @@ losses.BaseMetricLossFunction(collect_stats = False, * **collect_stats**: If True, will collect various statistics that may be useful to analyze during experiments. If False, these computations will be skipped. Want to make ```True``` the default? Set the global [COLLECT_STATS](common_functions.md#collect_stats) flag. * **reducer**: A [reducer](reducers.md) object. If None, then the default reducer will be used. * **distance**: A [distance](distances.md) object. If None, then the default distance will be used. -* **embedding_regularizer**: A [regularizer](regularizers.md) object that will be applied to embeddings. If None, then no embedding regularization will be used. -* **embedding_reg_weight**: If an embedding regularizer is used, then its loss will be multiplied by this amount before being added to the total loss. +* **regularizer**: A [regularizer](regularizers.md) object that will be applied to embeddings. If None, then no embedding regularization will be used. +* **reg_weight**: If an embedding regularizer is used, then its loss will be multiplied by this amount before being added to the total loss. **Default distance**: @@ -273,7 +273,7 @@ losses.CosFaceLoss(num_classes, embedding_size, margin=0.35, scale=64, **kwargs) **Other info**: -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. +* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.CosFaceLoss(...).to(torch.device('cuda')) @@ -491,7 +491,7 @@ where **Other info**: -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. +* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.LargeMarginSoftmaxLoss(...).to(torch.device('cuda')) @@ -737,7 +737,7 @@ losses.NormalizedSoftmaxLoss(num_classes, embedding_size, temperature=0.05, **kw **Other info** -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. 
+* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.NormalizedSoftmaxLoss(...).to(torch.device('cuda')) @@ -870,7 +870,7 @@ losses.ProxyAnchorLoss(num_classes, embedding_size, margin = 0.1, alpha = 32, ** **Other info** -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. +* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.ProxyAnchorLoss(...).to(torch.device('cuda')) @@ -907,7 +907,7 @@ losses.ProxyNCALoss(num_classes, embedding_size, softmax_scale=1, **kwargs) **Other info** -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. +* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.ProxyNCALoss(...).to(torch.device('cuda')) @@ -1027,7 +1027,7 @@ where **Other info** -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. +* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.SoftTripleLoss(...).to(torch.device('cuda')) @@ -1068,7 +1068,7 @@ See [LargeMarginSoftmaxLoss](losses.md#largemarginsoftmaxloss) **Other info** -* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments. +* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments. * This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example: ```python loss_func = losses.SphereFaceLoss(...).to(torch.device('cuda')) @@ -1230,14 +1230,14 @@ complete_loss = losses.MultipleLosses([main_loss, var_loss], weights=[1, 0.5]) ## WeightRegularizerMixin Losses can extend this class in addition to BaseMetricLossFunction. You should extend this class if your loss function contains a learnable weight matrix. 
```python -losses.WeightRegularizerMixin(weight_init_func=None, weight_regularizer=None, weight_reg_weight=1, **kwargs) +losses.WeightRegularizerMixin(weight_init_func=None, regularizer=None, reg_weight=1, **kwargs) ``` **Parameters**: * **weight_init_func**: An [TorchInitWrapper](common_functions.md#torchinitwrapper) object, which will be used to initialize the weights of the loss function. -* **weight_regularizer**: The [regularizer](regularizers.md) to apply to the loss's learned weights. -* **weight_reg_weight**: The amount the regularization loss will be multiplied by. +* **regularizer**: The [regularizer](regularizers.md) to apply to the loss's learned weights. +* **reg_weight**: The amount the regularization loss will be multiplied by. Extended by: diff --git a/src/pytorch_metric_learning/losses/base_metric_loss_function.py b/src/pytorch_metric_learning/losses/base_metric_loss_function.py index cc09fe77..012a8557 100644 --- a/src/pytorch_metric_learning/losses/base_metric_loss_function.py +++ b/src/pytorch_metric_learning/losses/base_metric_loss_function.py @@ -1,13 +1,19 @@ import inspect +import re from ..utils import common_functions as c_f from ..utils.module_with_records_and_reducer import ModuleWithRecordsReducerAndDistance +from . import mixins from .mixins import EmbeddingRegularizerMixin -class BaseMetricLossFunction( - EmbeddingRegularizerMixin, ModuleWithRecordsReducerAndDistance -): +class BaseMetricLossFunction(ModuleWithRecordsReducerAndDistance): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.emb_loss_regularizer = EmbeddingRegularizerMixin( + **kwargs + ) # Avoid multiple inheritance errors. In this way if a loss function inherits from a RegularizerMixin subclass it does not affect the mro + def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): """ This has to be implemented and is what actually computes the loss. 
@@ -34,7 +40,9 @@ def forward( loss_dict = self.compute_loss( embeddings, labels, indices_tuple, ref_emb, ref_labels ) - self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings) + self.emb_loss_regularizer.add_embedding_regularization_to_loss_dict( + loss_dict, embeddings + ) return self.reducer(loss_dict, embeddings, labels) def zero_loss(self): @@ -50,12 +58,10 @@ def sub_loss_names(self): return self._sub_loss_names() + self.all_regularization_loss_names() def all_regularization_loss_names(self): - reg_names = [] + reg_loss_names = [] for base_class in inspect.getmro(self.__class__): - base_class_name = base_class.__name__ - mixin_keyword = "RegularizerMixin" - if base_class_name.endswith(mixin_keyword): - descriptor = base_class_name.replace(mixin_keyword, "").lower() - if getattr(self, "{}_regularizer".format(descriptor)): - reg_names.extend(base_class.regularization_loss_names(self)) - return reg_names + if base_class.__module__ == mixins.__name__: + m = re.search(r"(\w+)RegularizerMixin", base_class.__name__) + if m is not None: + reg_loss_names.append(m.group(1).lower()) + return reg_loss_names diff --git a/src/pytorch_metric_learning/losses/margin_loss.py b/src/pytorch_metric_learning/losses/margin_loss.py index 7109b4e5..02af9f00 100644 --- a/src/pytorch_metric_learning/losses/margin_loss.py +++ b/src/pytorch_metric_learning/losses/margin_loss.py @@ -36,7 +36,11 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels): if len(anchor_idx) == 0: return self.zero_losses() - beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx].to("cpu")] # When labels are on gpu gives error + beta = ( + self.beta + if len(self.beta) == 1 + else self.beta[labels[anchor_idx].to("cpu")] + ) # When labels are on gpu gives error beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype) mat = self.distance(embeddings, ref_emb) diff --git a/src/pytorch_metric_learning/losses/mixins.py b/src/pytorch_metric_learning/losses/mixins.py index 374f98f7..d607da7c 100644 --- a/src/pytorch_metric_learning/losses/mixins.py +++ b/src/pytorch_metric_learning/losses/mixins.py @@ -1,72 +1,67 @@ +from typing import Dict + import torch from ..utils import common_functions as c_f +SUPPORTED_REGULARIZATION_TYPES = ["custom", "weight", "embedding"] -class WeightMixin: - def __init__(self, weight_init_func=None, **kwargs): - super().__init__(**kwargs) - self.weight_init_func = weight_init_func - if self.weight_init_func is None: - self.weight_init_func = self.get_default_weight_init_func() - def get_default_weight_init_func(self): - return c_f.TorchInitWrapper(torch.nn.init.normal_) +class RegularizerMixin: + """Base class for regularization losses. 
+ regularizer: function-like object or `nn.Module` that transforms input data into single number or single-element `torch.Tensor` + """ - -class WeightRegularizerMixin(WeightMixin): - def __init__(self, weight_regularizer=None, weight_reg_weight=1, **kwargs): - self.weight_regularizer = ( - weight_regularizer is not None - ) # hack needed to know whether reg will be in sub-loss names - super().__init__(**kwargs) - self.weight_regularizer = weight_regularizer - self.weight_reg_weight = weight_reg_weight - if self.weight_regularizer is not None: + def __init__(self, regularizer=None, reg_weight=1, type="custom", **kwargs): + self.check_type(type) + self.regularizer = regularizer if regularizer is not None else (lambda data: 0) + self.reg_weight = reg_weight + if regularizer is not None: self.add_to_recordable_attributes( - list_of_names=["weight_reg_weight"], is_stat=False + list_of_names=[f"{type}_reg_weight"], is_stat=False ) - def weight_regularization_loss(self, weights): - if self.weight_regularizer is None: - loss = 0 - else: - loss = self.weight_regularizer(weights) * self.weight_reg_weight - return {"losses": loss, "indices": None, "reduction_type": "already_reduced"} + def regularization_loss(self, data): + loss = self.regularizer(data) * self.reg_weight + return loss + + def add_regularization_to_loss_dict(self, loss_dict: Dict[str, Dict], data): + loss_dict[self.reg_loss_type] = { + "losses": self.regularization_loss(data), + "indices": None, + "reduction_type": "already_reduced", + } + + def check_type(self, type: str): + if type not in SUPPORTED_REGULARIZATION_TYPES: + raise ValueError( + f"Type provided not supported. Supported types are {', '.join(SUPPORTED_REGULARIZATION_TYPES)}, given type is {type}." + ) + self.reg_loss_type = f"{type}_reg_loss" - def add_weight_regularization_to_loss_dict(self, loss_dict, weights): - if self.weight_regularizer is not None: - loss_dict["weight_reg_loss"] = self.weight_regularization_loss(weights) - def regularization_loss_names(self): - return ["weight_reg_loss"] +def get_default_weight_init_func(): + return c_f.TorchInitWrapper(torch.nn.init.normal_) -class EmbeddingRegularizerMixin: - def __init__(self, embedding_regularizer=None, embedding_reg_weight=1, **kwargs): - self.embedding_regularizer = ( - embedding_regularizer is not None - ) # hack needed to know whether reg will be in sub-loss names +class WeightRegularizerMixin(RegularizerMixin): + def __init__(self, weight_init_func=None, **kwargs): + kwargs["type"] = "weight" super().__init__(**kwargs) - self.embedding_regularizer = embedding_regularizer - self.embedding_reg_weight = embedding_reg_weight - if self.embedding_regularizer is not None: - self.add_to_recordable_attributes( - list_of_names=["embedding_reg_weight"], is_stat=False - ) + self.weight_init_func = ( + weight_init_func + if weight_init_func is not None + else get_default_weight_init_func() + ) - def embedding_regularization_loss(self, embeddings): - if self.embedding_regularizer is None: - loss = 0 - else: - loss = self.embedding_regularizer(embeddings) * self.embedding_reg_weight - return {"losses": loss, "indices": None, "reduction_type": "already_reduced"} + def add_weight_regularization_to_loss_dict(self, loss_dict, weights): + self.add_regularization_to_loss_dict(loss_dict, weights) - def add_embedding_regularization_to_loss_dict(self, loss_dict, embeddings): - if self.embedding_regularizer is not None: - loss_dict["embedding_reg_loss"] = self.embedding_regularization_loss( - embeddings - ) - def 
regularization_loss_names(self): - return ["embedding_reg_loss"] +class EmbeddingRegularizerMixin(RegularizerMixin): + def __init__(self, **kwargs): + kwargs["type"] = "embedding" + super().__init__(**kwargs) + + def add_embedding_regularization_to_loss_dict(self, loss_dict, embeddings): + self.add_regularization_to_loss_dict(loss_dict, embeddings) diff --git a/src/pytorch_metric_learning/losses/vicreg_loss.py b/src/pytorch_metric_learning/losses/vicreg_loss.py index ab0b8e80..e2bb6c1b 100644 --- a/src/pytorch_metric_learning/losses/vicreg_loss.py +++ b/src/pytorch_metric_learning/losses/vicreg_loss.py @@ -11,7 +11,7 @@ def __init__( ): if "distance" in kwargs: raise ValueError("VICRegLoss cannot use a distance function") - if "embedding_regularizer" in kwargs: + if "regularizer" in kwargs: raise ValueError("VICRegLoss cannot use a regularizer") super().__init__(**kwargs) """ diff --git a/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py b/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py index f42ab20b..fa770e0c 100644 --- a/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py +++ b/src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py @@ -3,5 +3,6 @@ class AvgNonZeroReducer(ThresholdReducer): """Equivalent to ThresholdReducer with `low=0`""" + def __init__(self, **kwargs): super().__init__(low=0, **kwargs) diff --git a/src/pytorch_metric_learning/reducers/base_reducer.py b/src/pytorch_metric_learning/reducers/base_reducer.py index 4213eaab..0d872f19 100644 --- a/src/pytorch_metric_learning/reducers/base_reducer.py +++ b/src/pytorch_metric_learning/reducers/base_reducer.py @@ -15,7 +15,7 @@ def forward(self, loss_dict, embeddings, labels): loss_name = list(loss_dict.keys())[0] loss_info = loss_dict[loss_name] losses, loss_indices, reduction_type, kwargs = self.unpack_loss_info(loss_info) - loss_val = self.reduce_loss( # Similar to compute_loss + loss_val = self.reduce_loss( # Similar to compute_loss losses, loss_indices, reduction_type, kwargs, embeddings, labels ) return loss_val @@ -28,7 +28,7 @@ def unpack_loss_info(self, loss_info): {}, ) - def reduce_loss( # Similar to compute_loss + def reduce_loss( # Similar to compute_loss self, losses, loss_indices, reduction_type, kwargs, embeddings, labels ): self.set_losses_size_stat(losses) diff --git a/src/pytorch_metric_learning/reducers/class_weighted_reducer.py b/src/pytorch_metric_learning/reducers/class_weighted_reducer.py index 39fa14bf..00054c9b 100644 --- a/src/pytorch_metric_learning/reducers/class_weighted_reducer.py +++ b/src/pytorch_metric_learning/reducers/class_weighted_reducer.py @@ -4,8 +4,10 @@ class ClassWeightedReducer(ThresholdReducer): """It weights the losses with user-specified weights and then takes the average. - - Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters.""" + + Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters. 
+ """ + def __init__(self, weights, **kwargs): super().__init__(**kwargs) self.weights = weights diff --git a/src/pytorch_metric_learning/reducers/divisor_reducer.py b/src/pytorch_metric_learning/reducers/divisor_reducer.py index f0b0d099..882be8f5 100644 --- a/src/pytorch_metric_learning/reducers/divisor_reducer.py +++ b/src/pytorch_metric_learning/reducers/divisor_reducer.py @@ -14,7 +14,7 @@ def unpack_loss_info(self, loss_info): if loss_info["reduction_type"] != "already_reduced": self.divisor = loss_info["divisor"] return super().unpack_loss_info(loss_info) - + def sum_and_divide(self, losses, embeddings): if self.divisor != 0: output = torch.sum(losses.float()) / self.divisor diff --git a/src/pytorch_metric_learning/reducers/mean_reducer.py b/src/pytorch_metric_learning/reducers/mean_reducer.py index 7aaae32d..c00e1bd3 100644 --- a/src/pytorch_metric_learning/reducers/mean_reducer.py +++ b/src/pytorch_metric_learning/reducers/mean_reducer.py @@ -1,13 +1,14 @@ -from .threshold_reducer import ThresholdReducer import numpy as np +from .threshold_reducer import ThresholdReducer + class MeanReducer(ThresholdReducer): """Equivalent to ThresholdReducer with default parameters. - + Any element is accepted""" + def __init__(self, **kwargs): kwargs["low"] = -np.inf kwargs["high"] = np.inf super().__init__(**kwargs) - diff --git a/src/pytorch_metric_learning/reducers/multiple_reducers.py b/src/pytorch_metric_learning/reducers/multiple_reducers.py index 1f1ffbcf..bf197da3 100644 --- a/src/pytorch_metric_learning/reducers/multiple_reducers.py +++ b/src/pytorch_metric_learning/reducers/multiple_reducers.py @@ -1,8 +1,10 @@ +from collections import defaultdict + import torch from .base_reducer import BaseReducer from .mean_reducer import MeanReducer -from collections import defaultdict + class DefaultModuleDict(torch.nn.ModuleDict): def __init__(self, module_factory, modules): @@ -14,8 +16,9 @@ class MultipleReducers(BaseReducer): def __init__(self, reducers, default_reducer: BaseReducer = None, **kwargs): super().__init__(**kwargs) reducer_type = MeanReducer if default_reducer is None else type(default_reducer) - self.reducers = DefaultModuleDict(module_factory=lambda : reducer_type(), - modules=reducers) + self.reducers = DefaultModuleDict( + module_factory=lambda: reducer_type(), modules=reducers + ) def forward(self, loss_dict, embeddings, labels): self.reset_stats() diff --git a/src/pytorch_metric_learning/reducers/per_anchor_reducer.py b/src/pytorch_metric_learning/reducers/per_anchor_reducer.py index 18e2e115..38752947 100644 --- a/src/pytorch_metric_learning/reducers/per_anchor_reducer.py +++ b/src/pytorch_metric_learning/reducers/per_anchor_reducer.py @@ -35,15 +35,21 @@ def tuple_reduction_helper(self, losses, loss_indices, embeddings, labels): # Prepare tensors for results anchors = c_f.to_device(anchors, tensor=losses) others = c_f.to_device(others, tensor=losses) - output = c_f.to_device(torch.zeros(batch_size, batch_size), tensor=losses, dtype=losses.dtype) - num_per_row = c_f.to_device(torch.zeros(batch_size), tensor=losses, dtype=torch.long) # Remember to fuse these into a single call to to_device once to_device accepts list inputs + output = c_f.to_device( + torch.zeros(batch_size, batch_size), tensor=losses, dtype=losses.dtype + ) + num_per_row = c_f.to_device( + torch.zeros(batch_size), tensor=losses, dtype=torch.long + ) # Remember to fuse these into a single call to to_device once to_device accepts list inputs # Insert loss values at the corresponding anchor-embedding positions output[anchors, 
others] = losses # Calculate the count of 'others' for each unique anchor # Equivalent to: 'num_per_row[anchors[i]] += 1' for every i - num_per_row = num_per_row.scatter_add_(0, anchors, torch.ones_like(anchors, device=anchors.device)) + num_per_row = num_per_row.scatter_add_( + 0, anchors, torch.ones_like(anchors, device=anchors.device) + ) # Aggregate results output = self.aggregation_func(output, num_per_row) @@ -63,5 +69,5 @@ def pos_pair_reduction(self, *args, **kwargs): def neg_pair_reduction(self, *args, **kwargs): return self.tuple_reduction_helper(*args, **kwargs) - def triplet_reduction(self, *_): # Explicitly indicate that the arguments are ignored + def triplet_reduction(self, *_): # Explicitly indicate that the arguments are ignored raise NotImplementedError("Triplet reduction not supported") diff --git a/src/pytorch_metric_learning/reducers/sum_reducer.py b/src/pytorch_metric_learning/reducers/sum_reducer.py index a9d0b9c2..28cef578 100644 --- a/src/pytorch_metric_learning/reducers/sum_reducer.py +++ b/src/pytorch_metric_learning/reducers/sum_reducer.py @@ -1,12 +1,12 @@ -import torch - from .threshold_reducer import ThresholdReducer class SumReducer(ThresholdReducer): """It reduces the losses by summing up all the values. - - Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters.""" + + Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters. + """ + def __init__(self, **kwargs): kwargs["collect_stats"] = True super().__init__(**kwargs) diff --git a/src/pytorch_metric_learning/reducers/threshold_reducer.py b/src/pytorch_metric_learning/reducers/threshold_reducer.py index a57422b1..d435e13e 100644 --- a/src/pytorch_metric_learning/reducers/threshold_reducer.py +++ b/src/pytorch_metric_learning/reducers/threshold_reducer.py @@ -1,13 +1,18 @@ -import torch import numpy as np +import torch + from .base_reducer import BaseReducer class ThresholdReducer(BaseReducer): def __init__(self, low=-np.inf, high=np.inf, **kwargs): super().__init__(**kwargs) - self.low = low if low is not None else -np.inf # Since there is no None default value it could be better to exclude testing for low=None in test_threshold_reducer - self.high = high if high is not None else np.inf # Since there is no None default value it could be better to exclude testing for high=None in test_threshold_reducer + self.low = ( + low if low is not None else -np.inf + ) # Since there is no None default value it could be better to exclude testing for low=None in test_threshold_reducer + self.high = ( + high if high is not None else np.inf + ) # Since there is no None default value it could be better to exclude testing for high=None in test_threshold_reducer self.add_to_recordable_attributes(list_of_names=["low", "high"], is_stat=False) self.add_to_recordable_attributes( list_of_names=["num_past_filter", "num_above_low", "num_below_high"], @@ -30,7 +35,7 @@ def element_reduction_helper(self, losses, embeddings): low_condition = losses > self.low high_condition = losses < self.high losses = losses[low_condition & high_condition] - + num_past_filter = len(losses) if num_past_filter >= 1: loss = torch.mean(losses) @@ -43,7 +48,7 @@ def set_stats(self, low_condition, high_condition, num_past_filter): if self.collect_stats: self.num_past_filter = num_past_filter - if np.isfinite(self.low): # Why record this only if it was not None? 
+ if np.isfinite(self.low): # Why record this only if it was not None? self.num_above_low = torch.sum(low_condition).item() - if np.isfinite(self.high): # Why record this only if it was not None? + if np.isfinite(self.high): # Why record this only if it was not None? self.num_above_high = torch.sum(high_condition).item() diff --git a/src/pytorch_metric_learning/utils/logging_presets.py b/src/pytorch_metric_learning/utils/logging_presets.py index 5d4f2a56..766c35ee 100644 --- a/src/pytorch_metric_learning/utils/logging_presets.py +++ b/src/pytorch_metric_learning/utils/logging_presets.py @@ -379,7 +379,9 @@ def optimizer_custom_attr_func(self, optimizer): class EmptyContainer: - def end_of_epoch_hook(self, *_, **__): # Gives "TypeError: EmptyContainer.end_of_epoch_hook() got an unexpected keyword argument 'test_interval'" + def end_of_epoch_hook( + self, *_, **__ + ): # Gives "TypeError: EmptyContainer.end_of_epoch_hook() got an unexpected keyword argument 'test_interval'" return None end_of_iteration_hook = None diff --git a/src/pytorch_metric_learning/utils/module_with_records.py b/src/pytorch_metric_learning/utils/module_with_records.py index 9fa039f6..368fc4ff 100644 --- a/src/pytorch_metric_learning/utils/module_with_records.py +++ b/src/pytorch_metric_learning/utils/module_with_records.py @@ -4,7 +4,9 @@ class ModuleWithRecords(torch.nn.Module): - def __init__(self, collect_stats=None): + def __init__( + self, collect_stats=None, **_ + ): # Hack needed to ignore other arguments passed from subclasses super().__init__() self.collect_stats = ( c_f.COLLECT_STATS if collect_stats is None else collect_stats diff --git a/tests/losses/test_npairs_loss.py b/tests/losses/test_npairs_loss.py index 568209cc..48d78816 100644 --- a/tests/losses/test_npairs_loss.py +++ b/tests/losses/test_npairs_loss.py @@ -12,7 +12,7 @@ class TestNPairsLoss(unittest.TestCase): def test_npairs_loss(self): loss_funcA = NPairsLoss() - loss_funcB = NPairsLoss(embedding_regularizer=LpRegularizer(power=2)) + loss_funcB = NPairsLoss(regularizer=LpRegularizer(power=2)) embedding_norm = 2.3 for dtype in TEST_DTYPES: @@ -76,7 +76,7 @@ def test_with_no_valid_pairs(self): def test_backward(self): loss_funcA = NPairsLoss() - loss_funcB = NPairsLoss(embedding_regularizer=LpRegularizer()) + loss_funcB = NPairsLoss(regularizer=LpRegularizer()) for dtype in TEST_DTYPES: for loss_func in [loss_funcA, loss_funcB]: diff --git a/tests/losses/test_signal_to_noise_ratio_losses.py b/tests/losses/test_signal_to_noise_ratio_losses.py index 806a6c9f..441e07a4 100644 --- a/tests/losses/test_signal_to_noise_ratio_losses.py +++ b/tests/losses/test_signal_to_noise_ratio_losses.py @@ -11,12 +11,12 @@ class TestSNRContrastiveLoss(unittest.TestCase): def test_snr_contrastive_loss(self): - pos_margin, neg_margin, embedding_reg_weight = 0, 0.1, 0.1 + pos_margin, neg_margin, reg_weight = 0, 0.1, 0.1 loss_func = SignalToNoiseRatioContrastiveLoss( pos_margin=pos_margin, neg_margin=neg_margin, - embedding_regularizer=ZeroMeanRegularizer(), - embedding_reg_weight=embedding_reg_weight, + regularizer=ZeroMeanRegularizer(), + reg_weight=reg_weight, ) for dtype in TEST_DTYPES: @@ -82,19 +82,17 @@ def test_snr_contrastive_loss(self): reg_loss = torch.mean(torch.abs(torch.sum(embeddings, dim=1))) - correct_total = ( - correct_pos_loss + correct_neg_loss + embedding_reg_weight * reg_loss - ) + correct_total = correct_pos_loss + correct_neg_loss + reg_weight * reg_loss rtol = 1e-2 if dtype == torch.float16 else 1e-5 self.assertTrue(torch.isclose(loss, 
correct_total, rtol=rtol)) def test_with_no_valid_pairs(self): - embedding_reg_weight = 0.1 + reg_weight = 0.1 loss_func = SignalToNoiseRatioContrastiveLoss( pos_margin=0, neg_margin=0.5, - embedding_regularizer=ZeroMeanRegularizer(), - embedding_reg_weight=embedding_reg_weight, + regularizer=ZeroMeanRegularizer(), + reg_weight=reg_weight, ) for dtype in TEST_DTYPES: embedding_angles = [0] @@ -106,10 +104,7 @@ def test_with_no_valid_pairs(self): TEST_DEVICE ) # 2D embeddings labels = torch.LongTensor([0]) - reg_loss = ( - torch.mean(torch.abs(torch.sum(embeddings, dim=1))) - * embedding_reg_weight - ) + reg_loss = torch.mean(torch.abs(torch.sum(embeddings, dim=1))) * reg_weight loss = loss_func(embeddings, labels) loss.backward() self.assertEqual(loss, reg_loss) diff --git a/tests/losses/test_soft_triple_loss.py b/tests/losses/test_soft_triple_loss.py index 42ad8e93..dd93e92d 100644 --- a/tests/losses/test_soft_triple_loss.py +++ b/tests/losses/test_soft_triple_loss.py @@ -78,11 +78,11 @@ def test_soft_triple_loss(self): gamma = 1 if dtype == torch.float16 else 0.1 for centers_per_class in range(1, 12): if centers_per_class > 1: - weight_regularizer = SparseCentersRegularizer( + regularizer = SparseCentersRegularizer( num_classes, centers_per_class ) else: - weight_regularizer = None + regularizer = None loss_func = SoftTripleLoss( num_classes, embedding_size, @@ -90,8 +90,8 @@ def test_soft_triple_loss(self): la=la, gamma=gamma, margin=margin, - weight_regularizer=weight_regularizer, - weight_reg_weight=reg_weight, + regularizer=regularizer, + reg_weight=reg_weight, ).to(TEST_DEVICE) original_loss_func = OriginalImplementationSoftTriple( la, diff --git a/tests/reducers/test_class_weighted_reducer.py b/tests/reducers/test_class_weighted_reducer.py index 73c3e91a..eb4b59c2 100644 --- a/tests/reducers/test_class_weighted_reducer.py +++ b/tests/reducers/test_class_weighted_reducer.py @@ -11,8 +11,15 @@ class TestClassWeightedReducer(unittest.TestCase): def test_class_weighted_reducer_with_threshold(self): torch.manual_seed(99115) class_weights = torch.tensor([1, 0.9, 1, 0.1, 0, 0, 0, 0, 0, 0]) - for low_threshold, high_threshold in [(None, None), (0.1, None), (None, 0.2), (0.1, 0.2)]: - reducer = ClassWeightedReducer(class_weights, low=low_threshold, high=high_threshold) + for low_threshold, high_threshold in [ + (None, None), + (0.1, None), + (None, 0.2), + (0.1, 0.2), + ]: + reducer = ClassWeightedReducer( + class_weights, low=low_threshold, high=high_threshold + ) batch_size = 100 embedding_size = 64 for dtype in TEST_DTYPES: @@ -59,13 +66,17 @@ def test_class_weighted_reducer_with_threshold(self): class_label = labels[batch_idx] correct_output += ( L[i] - * class_weights.type(dtype).to(TEST_DEVICE)[class_label] + * class_weights.type(dtype).to(TEST_DEVICE)[ + class_label + ] ) correct_output /= len(L) else: correct_output = 0 rtol = 1e-2 if dtype == torch.float16 else 1e-5 - self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) + self.assertTrue( + torch.isclose(output, correct_output, rtol=rtol) + ) if WITH_COLLECT_STATS: - self.assertTrue(reducer.num_past_filter == len(L)) \ No newline at end of file + self.assertTrue(reducer.num_past_filter == len(L)) diff --git a/tests/reducers/test_sum_reducer.py b/tests/reducers/test_sum_reducer.py index 159989fa..6a97183a 100644 --- a/tests/reducers/test_sum_reducer.py +++ b/tests/reducers/test_sum_reducer.py @@ -10,7 +10,12 @@ class TestSumReducer(unittest.TestCase): def test_sum_reducer_with_thresholds(self): torch.manual_seed(99115) 
- for low_threshold, high_threshold in [(None, None), (0.1, None), (None, 0.2), (0.1, 0.2)]: + for low_threshold, high_threshold in [ + (None, None), + (0.1, None), + (None, 0.2), + (0.1, 0.2), + ]: reducer = SumReducer(low=low_threshold, high=high_threshold) batch_size = 100 embedding_size = 64 @@ -51,9 +56,13 @@ def test_sum_reducer_with_thresholds(self): if len(L) > 0: correct_output = torch.sum(L, dtype=dtype) else: - correct_output = torch.zeros(1, dtype=dtype, device=TEST_DEVICE) + correct_output = torch.zeros( + 1, dtype=dtype, device=TEST_DEVICE + ) rtol = 1e-2 if dtype == torch.float16 else 1e-5 - self.assertTrue(torch.isclose(output, correct_output, rtol=rtol)) + self.assertTrue( + torch.isclose(output, correct_output, rtol=rtol) + ) if WITH_COLLECT_STATS: - self.assertTrue(reducer.num_past_filter == len(L)) \ No newline at end of file + self.assertTrue(reducer.num_past_filter == len(L)) diff --git a/tests/regularizers/test_center_invariant_regularizer.py b/tests/regularizers/test_center_invariant_regularizer.py index 053e1224..2443b8c2 100644 --- a/tests/regularizers/test_center_invariant_regularizer.py +++ b/tests/regularizers/test_center_invariant_regularizer.py @@ -19,8 +19,8 @@ def test_center_invariant_regularizer(self): temperature=temperature, num_classes=num_classes, embedding_size=embedding_size, - weight_regularizer=CenterInvariantRegularizer(), - weight_reg_weight=reg_weight, + regularizer=CenterInvariantRegularizer(), + reg_weight=reg_weight, ).to(TEST_DEVICE) embeddings = torch.nn.functional.normalize( diff --git a/tests/regularizers/test_regular_face_regularizer.py b/tests/regularizers/test_regular_face_regularizer.py index caa34fb4..d9a9134b 100644 --- a/tests/regularizers/test_regular_face_regularizer.py +++ b/tests/regularizers/test_regular_face_regularizer.py @@ -20,8 +20,8 @@ def test_regular_face_regularizer(self): temperature=temperature, num_classes=num_classes, embedding_size=embedding_size, - weight_regularizer=RegularFaceRegularizer(), - weight_reg_weight=reg_weight, + regularizer=RegularFaceRegularizer(), + reg_weight=reg_weight, ).to(TEST_DEVICE) embeddings = torch.nn.functional.normalize(