Commit b5cf5a0

KylevdLangemheen authored and yutong-xiang-97 committed
Start on supcon loss and tests (#1554)
1 parent 7ff2c05 commit b5cf5a0

File tree

3 files changed (+108, -17 lines)


lightly/loss/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
 from lightly.loss.negative_cosine_similarity import NegativeCosineSimilarity
 from lightly.loss.ntx_ent_loss import NTXentLoss
 from lightly.loss.pmsn_loss import PMSNCustomLoss, PMSNLoss
+from lightly.loss.supcon_loss import SupConLoss
 from lightly.loss.swav_loss import SwaVLoss
 from lightly.loss.sym_neg_cos_sim_loss import SymNegCosineSimilarityLoss
 from lightly.loss.tico_loss import TiCoLoss
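With this re-export in place, the new loss is importable from the package namespace rather than the submodule. A minimal sketch, assuming the commit is installed:

from lightly.loss import SupConLoss  # re-exported by lightly/loss/__init__.py

loss_fn = SupConLoss()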

lightly/loss/supcon_loss.py

Lines changed: 84 additions & 17 deletions
@@ -48,6 +48,9 @@ class ContrastMode(Enum):
     ONLY_NEGATIVES = 3


+VALID_CONTRAST_MODES = set(item.name for item in ContrastMode)
+
+
 class SupConLoss(nn.Module):
     """Implementation of the Supervised Contrastive Loss.

@@ -68,21 +71,24 @@ class SupConLoss(nn.Module):
     Raises:
         ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
         ValueError: If gather_distributed is True but torch.distributed is not available.
-        NotImplementedError: If contrast_mode is outside the accepted ContrastMode values.
+        ValueError: If contrast_mode is outside the accepted ContrastMode values.

     Examples:
-        >>> # initialize loss function without memory bank
-        >>> loss_fn = NTXentLoss(memory_bank_size=0)
+        >>> # initialize loss function
+        >>> loss_fn = SupConLoss()
         >>>
-        >>> # generate two random transforms of images
+        >>> # generate two or more views of images
         >>> t0 = transforms(images)
         >>> t1 = transforms(images)
         >>>
-        >>> # feed through SimCLR or MoCo model
+        >>> # feed through SimCLR model
         >>> out0, out1 = model(t0), model(t1)
         >>>
+        >>> # Stack views along 2nd dimensions
+        >>> features = torch.stack([out0, out1], dim=1)
+        >>>
         >>> # calculate loss
-        >>> loss = loss_fn(out0, out1)
+        >>> loss = loss_fn(features, labels)

     """

@@ -92,18 +98,21 @@ def __init__(
         contrast_mode: ContrastMode = ContrastMode.ALL,
         gather_distributed: bool = False,
     ):
-        """Initializes the NTXentLoss module with the specified parameters.
+        """Initializes the SupConLoss module with the specified parameters.

         Args:
             temperature:
                 Scale logits by the inverse of the temperature.
+            contrast_mode:
+                Whether to use all positives, one positive, or none. All negatives are
+                used in all cases.
             gather_distributed:
                 If True, negatives from all GPUs are gathered before the loss calculation.

         Raises:
             ValueError: If temperature is less than 1e-8 to prevent divide by zero.
             ValueError: If gather_distributed is True but torch.distributed is not available.
-            NotImplementedError: If contrast_mode is outside the accepted ContrastMode values.
+            ValueError: If contrast_mode is outside the accepted ContrastMode values.
         """
         super().__init__()
         self.temperature = temperature
@@ -124,6 +133,11 @@ def __init__(
                 "distributed support."
             )

+        if contrast_mode.name not in VALID_CONTRAST_MODES:
+            raise ValueError(
+                f"contrast_mode is {contrast_mode} but must be one of ContrastMode.{VALID_CONTRAST_MODES}"
+            )
+
     def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
         """Forward pass through Supervised Contrastive Loss.

@@ -140,14 +154,34 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
         Raises:
             ValueError: If features does not have at least 3 dimensions.
             ValueError: If number of labels does not match batch_size.
+            ValueError: If labels is not one-hot encoded.

         Returns:
             Supervised Contrastive Loss value.
         """

+        if len(features.shape) < 3:
+            raise ValueError(
+                f"Features must have at least 3 dimensions, got {len(features.shape)}."
+            )
+
         device = features.device
         batch_size, num_views = features.shape[:2]

+        if labels is not None and labels.size(0) != batch_size:
+            raise ValueError(
+                f"When setting labels, labels must match batch_size {batch_size}, got {labels.size(0)}."
+            )
+
+        if labels is not None:
+            if not self._is_one_hot(labels):
+                raise ValueError(
+                    "labels must be a 2D matrix representing the one-hot encoded classes."
+                )
+
+        # Flatten the features in case they are still images or other
+        features = features.flatten(2)
+
         # Normalize the features to length 1
         features = F.normalize(features, dim=2)

@@ -178,31 +212,43 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
         else:
             mask = (labels @ global_labels.T).to(device)

-        # Get features in shape [num_views * n, c]
+        # Get features in shape [num_views * batch_size, c]
         all_global_features = global_features.permute(1, 0, 2).reshape(
             -1, global_features.size(-1)
         )

         if self.contrast_mode == ContrastMode.ONE_POSITIVE:
+            # We take only the first view as anchor
             anchor_features = features[:, 0]
             num_anchor_views = 1
         else:
+            # We take all views as anchors in the same shape as the global features
             anchor_features = features.permute(1, 0, 2).reshape(-1, features.size(-1))
             num_anchor_views = num_views

         # Obtain the logits between anchor features and features across all processes
         # Logits will be shaped [local_batch_size * num_anchor_views, global_batch_size * num_views]
         # We then temperature scale it and subtract the max to improve numerical stability
+        # In the einsum, n is local_batch_size * num_anchor_views, m is global_batch_size * num_views,
+        # and c is the flattened feature length
+        # Note: features are ordered by view first, i.e. first all samples of view 0, then all samples
+        # of view 1, and so on.
         logits = torch.einsum("nc,mc->nm", anchor_features, all_global_features)
         logits /= self.temperature
         logits -= logits.max(dim=1, keepdim=True)[0].detach()
         exp_logits = torch.exp(logits)

+        # Get the positive and negative masks for numerator & denominator
         positives_mask, negatives_mask = self._create_tiled_masks(
-            mask, diag_mask, num_views, num_anchor_views, self.positives_cap
+            mask.long(),
+            diag_mask.long(),
+            num_views,
+            num_anchor_views,
+            self.positives_cap,
         )
         num_positives_per_row = positives_mask.sum(dim=1)

+        # Calculate denominator based on contrast_mode
         if self.contrast_mode == ContrastMode.ONE_POSITIVE:
             denominator = exp_logits + (exp_logits * negatives_mask).sum(
                 dim=1, keepdim=True
@@ -216,13 +262,14 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
         # num_positives_per_row can be zero iff 1 view is used. Here we use a safe
        # dividing method seting those values to zero to prevent division by zero errors.

-        # Only implements SupCon_{out}
+        # Only implements SupCon_{out}.
         log_probs = (logits - torch.log(denominator)) * positives_mask
         log_probs = log_probs.sum(dim=1)
         log_probs = divide_no_nan(log_probs, num_positives_per_row)

         loss = -log_probs

+        # Adjust for num_positives_per_row being zero when using exactly 1 view
         if num_views != 1:
             loss = loss.mean(dim=0)
         else:
@@ -232,21 +279,27 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
         return loss

     def _create_tiled_masks(
-        self, untiled_mask, diagonal_mask, num_views, num_anchor_views, positives_cap
+        self,
+        untiled_mask: Tensor,
+        diagonal_mask: Tensor,
+        num_views: int,
+        num_anchor_views: int,
+        positives_cap: int,
     ) -> Tuple[Tensor, Tensor]:
         # Get total batch size across all processes
-        print(untiled_mask.shape)
         global_batch_size = untiled_mask.size(1)

         # Find index of the anchor for each sample
-        labels = torch.argmax(diagonal_mask.long(), dim=1)
+        labels = torch.argmax(diagonal_mask, dim=1)

         # Generate tiled labels across views
         tiled_labels = []
         for i in range(num_anchor_views):
             tiled_labels.append(labels + global_batch_size * i)
-        tiled_labels = torch.cat(tiled_labels, 0)
-        tiled_diagonal_mask = F.one_hot(tiled_labels, global_batch_size * num_views)
+        tiled_labels_tensor = torch.cat(tiled_labels, 0)
+        tiled_diagonal_mask = F.one_hot(
+            tiled_labels_tensor, global_batch_size * num_views
+        )

         # Mask to zero the diagonal at the end
         all_but_diagonal_mask = 1 - tiled_diagonal_mask
@@ -257,7 +310,7 @@ def _create_tiled_masks(
         )

         # The negatives is simply the bitflipped positives
-        negatives_mask = 1 - uncapped_positives_mask
+        negatives_mask = 1.0 - uncapped_positives_mask

         # For when positives_cap is implemented
         if positives_cap > -1:
@@ -269,3 +322,17 @@ def _create_tiled_masks(
             positives_mask *= all_but_diagonal_mask

         return positives_mask, negatives_mask
+
+    def _is_one_hot(self, tensor: Tensor) -> bool:
+        # Tensor is not a 2D matrix
+        if tensor.ndim != 2:
+            return False
+
+        # Check values are only 0 or 1
+        is_binary = ((tensor == 0) | (tensor == 1)).all()
+
+        # Check each row sums to 1
+        row_sums = tensor.sum(dim=1)
+        has_single_one = (row_sums == 1).all()
+
+        return bool(is_binary.item() and has_single_one.item())
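For context, a minimal usage sketch that follows the updated docstring example. The tensors are random stand-ins for model outputs, and the sizes (batch_size, num_classes, embed_dim) are illustrative; the one-hot labels satisfy the check added in _is_one_hot.

import torch
import torch.nn.functional as F

from lightly.loss import SupConLoss

batch_size, num_classes, embed_dim = 8, 4, 128  # illustrative sizes

# Random tensors stand in for the embeddings of two augmented views (out0, out1).
out0 = torch.randn(batch_size, embed_dim)
out1 = torch.randn(batch_size, embed_dim)

# Stack the views along the 2nd dimension: [batch_size, num_views, embed_dim].
features = torch.stack([out0, out1], dim=1)

# One-hot encoded class labels, as required by the new _is_one_hot check in forward().
labels = F.one_hot(torch.randint(num_classes, (batch_size,)), num_classes).float()

loss_fn = SupConLoss()
supervised_loss = loss_fn(features, labels)

# Without labels, only the other views of the same image count as positives,
# i.e. the self-supervised case exercised in the new tests.
unsupervised_loss = loss_fn(features)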

tests/loss/test_supcon_loss.py

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import pytest
+import torch
+from torch import Tensor
+
+from lightly.loss import NTXentLoss, SupConLoss
+
+
+class TestSupConLoss:
+    def test_simple_input(self) -> None:
+        my_input = torch.rand([3, 2, 4])
+        my_label = Tensor([[1, 0], [0, 1], [0, 1]])
+        my_loss = SupConLoss()
+        my_loss(my_input, my_label)
+
+    def test_unsup_equal_to_simclr(self) -> None:
+        supcon = SupConLoss(temperature=0.5, gather_distributed=False)
+        ntxent = NTXentLoss(
+            temperature=0.5, memory_bank_size=0, gather_distributed=False
+        )
+        features = torch.rand((8, 2, 10))
+        supcon_loss = supcon(features)
+        ntxent_loss = ntxent(features[:, 0, :], features[:, 1, :])
+        assert (supcon_loss - ntxent_loss).pow(2).item() == pytest.approx(0.0)
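The second test checks that SupConLoss without labels reduces to the NT-Xent (SimCLR) objective for two views. Assuming a standard checkout with pytest installed, the new tests would be run with something like:

pytest tests/loss/test_supcon_loss.py -v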
