@@ -48,6 +48,9 @@ class ContrastMode(Enum):
    ONLY_NEGATIVES = 3


+VALID_CONTRAST_MODES = set(item.name for item in ContrastMode)
+
+
class SupConLoss(nn.Module):
    """Implementation of the Supervised Contrastive Loss.

@@ -68,21 +71,24 @@ class SupConLoss(nn.Module):
    Raises:
        ValueError: If abs(temperature) < 1e-8 to prevent divide by zero.
        ValueError: If gather_distributed is True but torch.distributed is not available.
-        NotImplementedError: If contrast_mode is outside the accepted ContrastMode values.
+        ValueError: If contrast_mode is outside the accepted ContrastMode values.

    Examples:
-        >>> # initialize loss function without memory bank
-        >>> loss_fn = NTXentLoss(memory_bank_size=0)
+        >>> # initialize loss function
+        >>> loss_fn = SupConLoss()
        >>>
-        >>> # generate two random transforms of images
+        >>> # generate two or more views of images
        >>> t0 = transforms(images)
        >>> t1 = transforms(images)
        >>>
-        >>> # feed through SimCLR or MoCo model
+        >>> # feed through SimCLR model
        >>> out0, out1 = model(t0), model(t1)
        >>>
+        >>> # stack the views along the 2nd dimension
+        >>> features = torch.stack([out0, out1], dim=1)
+        >>>
        >>> # calculate loss
-        >>> loss = loss_fn(out0, out1)
+        >>> loss = loss_fn(features, labels)

    """

@@ -92,18 +98,21 @@ def __init__(
        contrast_mode: ContrastMode = ContrastMode.ALL,
        gather_distributed: bool = False,
    ):
-        """Initializes the NTXentLoss module with the specified parameters.
+        """Initializes the SupConLoss module with the specified parameters.

        Args:
            temperature:
                Scale logits by the inverse of the temperature.
+            contrast_mode:
+                Whether to use all positives, one positive, or none. All negatives are
+                used in all cases.
            gather_distributed:
                If True, negatives from all GPUs are gathered before the loss calculation.

        Raises:
            ValueError: If temperature is less than 1e-8 to prevent divide by zero.
            ValueError: If gather_distributed is True but torch.distributed is not available.
-            NotImplementedError: If contrast_mode is outside the accepted ContrastMode values.
+            ValueError: If contrast_mode is outside the accepted ContrastMode values.
        """
        super().__init__()
        self.temperature = temperature
@@ -124,6 +133,11 @@ def __init__(
                "distributed support."
            )

+        if contrast_mode.name not in VALID_CONTRAST_MODES:
+            raise ValueError(
+                f"contrast_mode is {contrast_mode} but must be one of ContrastMode.{VALID_CONTRAST_MODES}"
+            )
+
    def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        """Forward pass through Supervised Contrastive Loss.

@@ -140,14 +154,34 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        Raises:
            ValueError: If features does not have at least 3 dimensions.
            ValueError: If number of labels does not match batch_size.
+            ValueError: If labels is not one-hot encoded.

        Returns:
            Supervised Contrastive Loss value.
        """

+        if len(features.shape) < 3:
+            raise ValueError(
+                f"Features must have at least 3 dimensions, got {len(features.shape)}."
+            )
+
        device = features.device
        batch_size, num_views = features.shape[:2]

+        if labels is not None and labels.size(0) != batch_size:
+            raise ValueError(
+                f"When setting labels, labels must match batch_size {batch_size}, got {labels.size(0)}."
+            )
+
+        if labels is not None:
+            if not self._is_one_hot(labels):
+                raise ValueError(
+                    "labels must be a 2D matrix representing the one-hot encoded classes."
+                )
+
+        # Flatten the features in case they are still images or other multi-dimensional tensors
+        features = features.flatten(2)
+
        # Normalize the features to length 1
        features = F.normalize(features, dim=2)

@@ -178,31 +212,43 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        else:
            mask = (labels @ global_labels.T).to(device)

-        # Get features in shape [num_views * n, c]
+        # Get features in shape [num_views * batch_size, c]
        all_global_features = global_features.permute(1, 0, 2).reshape(
            -1, global_features.size(-1)
        )

        if self.contrast_mode == ContrastMode.ONE_POSITIVE:
+            # We take only the first view as anchor
            anchor_features = features[:, 0]
            num_anchor_views = 1
        else:
+            # We take all views as anchors in the same shape as the global features
            anchor_features = features.permute(1, 0, 2).reshape(-1, features.size(-1))
            num_anchor_views = num_views

        # Obtain the logits between anchor features and features across all processes
        # Logits will be shaped [local_batch_size * num_anchor_views, global_batch_size * num_views]
        # We then temperature scale it and subtract the max to improve numerical stability
+        # In the einsum, n is local_batch_size * num_anchor_views, m is global_batch_size * num_views,
+        # and c is the flattened feature length
+        # Note: features are ordered by view first, i.e. first all samples of view 0, then all samples
+        # of view 1, and so on.
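+        # Illustrative example (assumed numbers): with batch_size=4, num_views=2 and a single
+        # process, anchor_features is [8, c] when all views are anchors (or [4, c] with
+        # ONE_POSITIVE) and all_global_features is [8, c], so logits is [8, 8] (or [4, 8]).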
        logits = torch.einsum("nc,mc->nm", anchor_features, all_global_features)
        logits /= self.temperature
        logits -= logits.max(dim=1, keepdim=True)[0].detach()
        exp_logits = torch.exp(logits)

+        # Get the positive and negative masks for numerator & denominator
        positives_mask, negatives_mask = self._create_tiled_masks(
-            mask, diag_mask, num_views, num_anchor_views, self.positives_cap
+            mask.long(),
+            diag_mask.long(),
+            num_views,
+            num_anchor_views,
+            self.positives_cap,
        )
        num_positives_per_row = positives_mask.sum(dim=1)

+        # Calculate denominator based on contrast_mode
        if self.contrast_mode == ContrastMode.ONE_POSITIVE:
            denominator = exp_logits + (exp_logits * negatives_mask).sum(
                dim=1, keepdim=True
@@ -216,13 +262,14 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        # num_positives_per_row can be zero iff 1 view is used. Here we use a safe
        # dividing method setting those values to zero to prevent division by zero errors.

-        # Only implements SupCon_{out}
+        # Only implements SupCon_{out}.
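+        # SupCon_{out} refers to eq. (2) of Khosla et al. 2020 (assumed reference):
+        #   L_out = sum_i -1/|P(i)| * sum_{p in P(i)} log( exp(z_i · z_p / T) / sum_{a in A(i)} exp(z_i · z_a / T) )
+        # where P(i) are the positives of anchor i and A(i) are all other samples.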
        log_probs = (logits - torch.log(denominator)) * positives_mask
        log_probs = log_probs.sum(dim=1)
        log_probs = divide_no_nan(log_probs, num_positives_per_row)

        loss = -log_probs

+        # Adjust for num_positives_per_row being zero when using exactly 1 view
        if num_views != 1:
            loss = loss.mean(dim=0)
        else:
@@ -232,21 +279,27 @@ def forward(self, features: Tensor, labels: Optional[Tensor] = None) -> Tensor:
        return loss

    def _create_tiled_masks(
-        self, untiled_mask, diagonal_mask, num_views, num_anchor_views, positives_cap
+        self,
+        untiled_mask: Tensor,
+        diagonal_mask: Tensor,
+        num_views: int,
+        num_anchor_views: int,
+        positives_cap: int,
    ) -> Tuple[Tensor, Tensor]:
        # Get total batch size across all processes
-        print(untiled_mask.shape)
        global_batch_size = untiled_mask.size(1)

        # Find index of the anchor for each sample
-        labels = torch.argmax(diagonal_mask.long(), dim=1)
+        labels = torch.argmax(diagonal_mask, dim=1)

        # Generate tiled labels across views
        tiled_labels = []
        for i in range(num_anchor_views):
            tiled_labels.append(labels + global_batch_size * i)
-        tiled_labels = torch.cat(tiled_labels, 0)
-        tiled_diagonal_mask = F.one_hot(tiled_labels, global_batch_size * num_views)
+        tiled_labels_tensor = torch.cat(tiled_labels, 0)
+        tiled_diagonal_mask = F.one_hot(
+            tiled_labels_tensor, global_batch_size * num_views
+        )

        # Mask to zero the diagonal at the end
        all_but_diagonal_mask = 1 - tiled_diagonal_mask
@@ -257,7 +310,7 @@ def _create_tiled_masks(
        )

        # The negatives mask is simply the bit-flipped positives mask
-        negatives_mask = 1 - uncapped_positives_mask
+        negatives_mask = 1.0 - uncapped_positives_mask

        # For when positives_cap is implemented
        if positives_cap > -1:
@@ -269,3 +322,17 @@ def _create_tiled_masks(
        positives_mask *= all_but_diagonal_mask

        return positives_mask, negatives_mask
+
+    def _is_one_hot(self, tensor: Tensor) -> bool:
+        # Tensor is not a 2D matrix
+        if tensor.ndim != 2:
+            return False
+
+        # Check values are only 0 or 1
+        is_binary = ((tensor == 0) | (tensor == 1)).all()
+
+        # Check each row sums to 1
+        row_sums = tensor.sum(dim=1)
+        has_single_one = (row_sums == 1).all()
+
+        return bool(is_binary.item() and has_single_one.item())
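A minimal usage sketch of the new loss, assuming the class is exported as `lightly.loss.SupConLoss`; the shapes, `temperature` value, and random embeddings below are purely illustrative:

```python
import torch
import torch.nn.functional as F

from lightly.loss import SupConLoss  # assumed import path

loss_fn = SupConLoss(temperature=0.1)

batch_size, feature_dim, num_classes = 8, 128, 10

# Two views of the same batch, e.g. projections from a SimCLR-style model
out0 = torch.randn(batch_size, feature_dim)
out1 = torch.randn(batch_size, feature_dim)

# Stack the views along dimension 1: [batch_size, num_views, feature_dim]
features = torch.stack([out0, out1], dim=1)

# Labels must be a one-hot encoded 2D matrix, as enforced by _is_one_hot
class_ids = torch.randint(0, num_classes, (batch_size,))
labels = F.one_hot(class_ids, num_classes).float()

loss = loss_fn(features, labels)
```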