Commit 899a230

Merge pull request #703 from KevinMusgrave/dev
v2.6.0
2 parents: adfb78c + ef1bd06 · commit 899a230

4 files changed (+22, -7 lines)

.github/workflows/base_test_workflow.yml

Lines changed: 3 additions & 3 deletions
@@ -18,8 +18,8 @@ jobs:
             pytorch-version: 1.6
             torchvision-version: 0.7
           - python-version: 3.9
-            pytorch-version: 2.1
-            torchvision-version: 0.16
+            pytorch-version: 2.3
+            torchvision-version: 0.18

       steps:
         - uses: actions/checkout@v2
@@ -30,7 +30,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install .[with-hooks-cpu]
-          pip install torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
+          pip install "numpy<2.0" torch==${{ matrix.pytorch-version }} torchvision==${{ matrix.torchvision-version }} --force-reinstall
           pip install --upgrade protobuf==3.20.1
           pip install six
           pip install packaging
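
The matrix change moves the Python 3.9 job from PyTorch 2.1 / torchvision 0.16 to PyTorch 2.3 / torchvision 0.18, and the install step now pins NumPy below 2.0 before force-reinstalling torch. A minimal local sanity check mirroring those constraints might look like the sketch below; it is not part of the repository, and only uses the packaging library that the workflow already installs.

# Hypothetical local check mirroring the CI constraints above; not part of the repo.
from packaging.version import Version

import numpy
import torch
import torchvision

# The workflow pins "numpy<2.0" because wheels built against NumPy 1.x can break under 2.x.
assert Version(numpy.__version__) < Version("2.0"), "CI expects numpy < 2.0"

# Newest matrix entry pairs PyTorch 2.3 with torchvision 0.18.
print("torch:", torch.__version__, "torchvision:", torchvision.__version__)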

setup.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@
     ],
     python_requires=">=3.0",
     install_requires=[
-        "numpy",
+        "numpy < 2.0",
         "scikit-learn",
         "tqdm",
         "torch >= 1.6.0",
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "2.5.0"
+__version__ = "2.6.0"
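
The version string bumped here is the one the package exposes at runtime (assuming a standard install), so the release can be verified directly after upgrading:

# Confirm the installed package reflects this release.
import pytorch_metric_learning

print(pytorch_metric_learning.__version__)  # "2.6.0" after this commit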

src/pytorch_metric_learning/utils/distributed.py

Lines changed: 17 additions & 2 deletions
@@ -1,3 +1,5 @@
+import warnings
+
 import torch

 from ..losses import BaseMetricLossFunction, CrossBatchMemory
@@ -93,15 +95,28 @@ def __init__(self, loss, efficient=False):

     def forward(
         self,
-        emb,
+        embeddings,
         labels=None,
         indices_tuple=None,
         ref_emb=None,
         ref_labels=None,
         enqueue_mask=None,
     ):
+        if not is_distributed():
+            warnings.warn(
+                "DistributedLossWrapper is being used in a non-distributed setting. Returning the loss as is."
+            )
+            return self.loss(embeddings, labels, indices_tuple, ref_emb, ref_labels)
+
         world_size = torch.distributed.get_world_size()
-        common_args = [emb, labels, indices_tuple, ref_emb, ref_labels, world_size]
+        common_args = [
+            embeddings,
+            labels,
+            indices_tuple,
+            ref_emb,
+            ref_labels,
+            world_size,
+        ]
         if isinstance(self.loss, CrossBatchMemory):
             return self.forward_cross_batch(*common_args, enqueue_mask)
         return self.forward_regular_loss(*common_args)
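
This change renames the first forward argument from emb to embeddings and adds a graceful fallback: when torch.distributed is not initialized, the wrapper now warns and delegates to the wrapped loss instead of calling torch.distributed.get_world_size(). A minimal usage sketch follows; the tensors are placeholders and ContrastiveLoss is just one example of a wrappable loss.

# Minimal sketch of DistributedLossWrapper usage; tensors below are placeholders.
import torch

from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

loss_fn = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())

embeddings = torch.randn(32, 128)      # batch of 32 embeddings, 128-dim
labels = torch.randint(0, 10, (32,))   # integer class labels

# In a torch.distributed run this gathers embeddings/labels across processes.
# As of this commit, outside a distributed context it warns and returns the
# wrapped loss computed on the local batch instead of raising.
loss = loss_fn(embeddings, labels)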
