
Commit 7ff2c05

KylevdLangemheen authored and yutong-xiang-97 committed
Add first setup supcon loss
1 parent ee30cd4 commit 7ff2c05

File tree

1 file changed: +271 -0 lines changed


lightly/loss/supcon_loss.py

Lines changed: 271 additions & 0 deletions
@@ -0,0 +1,271 @@
""" Contrastive Loss Functions """

# Copyright (c) 2020. Lightly AG and its affiliates.
# All Rights Reserved

from enum import Enum
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from torch import distributed as torch_dist
from torch import nn

from lightly.utils import dist


def divide_no_nan(numerator: Tensor, denominator: Tensor) -> Tensor:
    """Performs tensor division, setting result to zero where denominator is zero.

    Args:
        numerator:
            Numerator tensor.
        denominator:
            Denominator tensor with possible zeroes.

    Returns:
        Result with zeros where denominator is zero.
    """
    result = torch.zeros_like(numerator)
    nonzero_mask = denominator != 0
    result[nonzero_mask] = numerator[nonzero_mask] / denominator[nonzero_mask]
    return result


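# Example of divide_no_nan (illustrative only, not part of the commit):
#   divide_no_nan(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 4.0]))
#   returns tensor([0.0000, 0.5000]); the entry with a zero denominator stays zero.
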
class ContrastMode(Enum):
    """Contrast Mode Enum for SupCon Loss.

    Offers the three contrast modes as an enum for the SupCon loss. The three modes are:

    - ContrastMode.ALL: Uses all positives and negatives.
    - ContrastMode.ONE_POSITIVE: Uses only one positive, and all negatives.
    - ContrastMode.ONLY_NEGATIVES: Uses no positives, only negatives.
    """

    ALL = 1
    ONE_POSITIVE = 2
    ONLY_NEGATIVES = 3


class SupConLoss(nn.Module):
    """Implementation of the Supervised Contrastive Loss.

    This implementation follows the SupCon[0] paper.

    - [0] SupCon, 2020, https://arxiv.org/abs/2004.11362

    Attributes:
        temperature:
            Scale logits by the inverse of the temperature.
        contrast_mode:
            Whether to use all positives, one positive, or none. All negatives are
            used in all cases.
        gather_distributed:
            If True then negatives from all GPUs are gathered before the
            loss calculation.

    Raises:
        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
        ValueError: If gather_distributed is True but torch.distributed is not available.
        NotImplementedError: If contrast_mode is outside the accepted ContrastMode values.

    Examples:
        >>> # initialize the loss function
        >>> loss_fn = SupConLoss()
        >>>
        >>> # generate two random transforms of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
        >>> # feed through a SimCLR or MoCo model
        >>> out0, out1 = model(t0), model(t1)
        >>>
        >>> # stack the views and calculate the loss
        >>> features = torch.stack([out0, out1], dim=1)
        >>> loss = loss_fn(features)

    """

    def __init__(
        self,
        temperature: float = 0.5,
        contrast_mode: ContrastMode = ContrastMode.ALL,
        gather_distributed: bool = False,
    ):
        """Initializes the SupConLoss module with the specified parameters.

        Args:
            temperature:
                Scale logits by the inverse of the temperature.
            contrast_mode:
                Whether to use all positives, one positive, or only negatives.
            gather_distributed:
                If True, negatives from all GPUs are gathered before the loss calculation.

        Raises:
            ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
            ValueError: If gather_distributed is True but torch.distributed is not available.
            NotImplementedError: If contrast_mode is outside the accepted ContrastMode values.
        """
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.positives_cap = -1  # Unused at the moment
        self.gather_distributed = gather_distributed
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
        self.eps = 1e-8

        if abs(self.temperature) < self.eps:
            raise ValueError(
                "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
            )
        if gather_distributed and not torch_dist.is_available():
            raise ValueError(
                "gather_distributed is True but torch.distributed is not available. "
                "Please set gather_distributed=False or install a torch version with "
                "distributed support."
            )

    def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        """Forward pass through the Supervised Contrastive Loss.

        Computes the loss based on the contrast_mode setting.

        Args:
            features:
                Tensor of at least 3 dimensions, corresponding to
                (batch_size, num_views, ...).
            labels:
                One-hot labels for each sample. Must match shape
                (batch_size, num_classes).

        Raises:
            ValueError: If features does not have at least 3 dimensions.
            ValueError: If the number of labels does not match batch_size.

        Returns:
            Supervised Contrastive Loss value.
        """

        device = features.device
        batch_size, num_views = features.shape[:2]

        # Normalize the features to length 1
        features = F.normalize(features, dim=2)

        # A memory bank could be used here, but labelled samples are not yet supported.

        # Use cosine similarity (dot product) as all vectors are normalized to unit length.

        # Use other samples from different classes in the batch as negatives
        # and create a diagonal mask that only selects similarities between
        # views of the same image / same class.
        if self.gather_distributed and dist.world_size() > 1:
            # Gather hidden representations and optional labels from other processes
            global_features = torch.cat(dist.gather(features), 0)
            diag_mask = dist.eye_rank(batch_size, device=device)
            if labels is not None:
                global_labels = torch.cat(dist.gather(labels), 0)
        else:
            # Single process
            global_features = features
            diag_mask = torch.eye(batch_size, device=device, dtype=torch.bool)
            if labels is not None:
                global_labels = labels

        # Use the diagonal mask if labels is None, else compute the mask based on labels
        if labels is None:
            # No labels, typical self-supervised contrastive learning as in SimCLR
            mask = diag_mask
        else:
            mask = (labels @ global_labels.T).to(device)

        # Get features in shape [num_views * n, c]
        all_global_features = global_features.permute(1, 0, 2).reshape(
            -1, global_features.size(-1)
        )

        if self.contrast_mode == ContrastMode.ONE_POSITIVE:
            anchor_features = features[:, 0]
            num_anchor_views = 1
        else:
            anchor_features = features.permute(1, 0, 2).reshape(-1, features.size(-1))
            num_anchor_views = num_views

        # Obtain the logits between anchor features and features across all processes.
        # Logits will be shaped [local_batch_size * num_anchor_views, global_batch_size * num_views].
        # We then temperature-scale them and subtract the row-wise max to improve numerical stability.
        logits = torch.einsum("nc,mc->nm", anchor_features, all_global_features)
        logits /= self.temperature
        logits -= logits.max(dim=1, keepdim=True)[0].detach()
        exp_logits = torch.exp(logits)

        positives_mask, negatives_mask = self._create_tiled_masks(
            mask, diag_mask, num_views, num_anchor_views, self.positives_cap
        )
        num_positives_per_row = positives_mask.sum(dim=1)

        if self.contrast_mode == ContrastMode.ONE_POSITIVE:
            denominator = exp_logits + (exp_logits * negatives_mask).sum(
                dim=1, keepdim=True
            )
        elif self.contrast_mode == ContrastMode.ALL:
            denominator = (exp_logits * negatives_mask).sum(dim=1, keepdim=True)
            denominator += (exp_logits * positives_mask).sum(dim=1, keepdim=True)
        else:  # ContrastMode.ONLY_NEGATIVES
            denominator = (exp_logits * negatives_mask).sum(dim=1, keepdim=True)

        # num_positives_per_row can be zero iff 1 view is used. Here we use a safe
        # division that sets those values to zero to prevent division by zero errors.

        # Only implements SupCon_{out}
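        # That is, the "L_out" form of the supervised contrastive loss from the
        # SupCon paper (https://arxiv.org/abs/2004.11362), in which the 1/|P(i)|
        # normalization sits outside the log (shown for ContrastMode.ALL; the other
        # modes only change the denominator):
        #   L_out = sum_i (-1 / |P(i)|) * sum_{p in P(i)}
        #           log( exp(z_i . z_p / t) / sum_{a in A(i)} exp(z_i . z_a / t) )
        # with P(i) the positives of anchor i, A(i) all other samples, and t the
        # temperature. log_probs collects the per-positive log terms, divide_no_nan
        # averages them per anchor, and the negation turns this into a loss.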
        log_probs = (logits - torch.log(denominator)) * positives_mask
        log_probs = log_probs.sum(dim=1)
        log_probs = divide_no_nan(log_probs, num_positives_per_row)

        loss = -log_probs

        if num_views != 1:
            loss = loss.mean(dim=0)
        else:
            num_valid_views_per_sample = num_positives_per_row.unsqueeze(0)
            loss = divide_no_nan(loss, num_valid_views_per_sample).squeeze()

        return loss

    def _create_tiled_masks(
        self,
        untiled_mask: Tensor,
        diagonal_mask: Tensor,
        num_views: int,
        num_anchor_views: int,
        positives_cap: int,
    ) -> Tuple[Tensor, Tensor]:
        # Get the total batch size across all processes
        global_batch_size = untiled_mask.size(1)

        # Find the index of the anchor for each sample
        labels = torch.argmax(diagonal_mask.long(), dim=1)

        # Generate tiled labels across views
        tiled_labels = []
        for i in range(num_anchor_views):
            tiled_labels.append(labels + global_batch_size * i)
        tiled_labels = torch.cat(tiled_labels, 0)
        tiled_diagonal_mask = F.one_hot(tiled_labels, global_batch_size * num_views)

        # Mask to zero out the diagonal at the end
        all_but_diagonal_mask = 1 - tiled_diagonal_mask

        # Work in floating point so that the mask arithmetic below (1 - mask) also
        # works when a boolean identity mask is passed in.
        untiled_mask = untiled_mask.float()

        # All tiled positives
        uncapped_positives_mask = torch.tile(
            untiled_mask, [num_anchor_views, num_views]
        )

        # The negatives are simply the bit-flipped positives
        negatives_mask = 1 - uncapped_positives_mask

        # For when positives_cap is implemented
        if positives_cap > -1:
            raise NotImplementedError("Capping positives is not yet implemented.")
        else:
            positives_mask = uncapped_positives_mask

        # Zero out the self-contrast
        positives_mask *= all_but_diagonal_mask

        return positives_mask, negatives_mask
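
A minimal usage sketch of the new loss (illustrative only, not from the commit; the import path lightly.loss.supcon_loss, the tensor sizes, and the random stand-in features are assumptions):

import torch
import torch.nn.functional as F

from lightly.loss.supcon_loss import ContrastMode, SupConLoss

loss_fn = SupConLoss(temperature=0.1, contrast_mode=ContrastMode.ALL)

batch_size, num_classes, dim = 8, 4, 128

# Two augmented views per image, stacked into (batch_size, num_views, dim).
view0 = torch.randn(batch_size, dim)
view1 = torch.randn(batch_size, dim)
features = torch.stack([view0, view1], dim=1)

# One-hot labels of shape (batch_size, num_classes); omit them for the
# purely self-supervised (SimCLR-style) case.
labels = F.one_hot(torch.randint(num_classes, (batch_size,)), num_classes).float()

loss = loss_fn(features, labels)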
