@@ -41,19 +41,41 @@ def __call__(self, data: List[Dict]) -> XYData:
4141
4242
4343class RaggedCollator (Collator ):
44- """Collator for handling ragged data samples."""
44+ """
45+ Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None).
46+
47+ This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes,
48+ such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some
49+ of the data samples might be partially labeled, which is useful for certain loss functions that allow training
50+ with incomplete or fuzzy data (e.g., fuzzy loss).
51+
52+ During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate
53+ between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled
54+ data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for
55+ metrics computation such as F1-score or MSE, especially in cases where some data points lack labels.
56+
57+ Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829
58+ """
4559
4660 def __call__ (self , data : List [Union [Dict , Tuple ]]) -> XYData :
47- """Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into
48- a batch.
61+ """
62+ Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch.
63+
64+ Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices
65+ of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for
66+ unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method
67+ ensures alignment between features and labels.
4968
5069 Args:
51- data (List[Union[Dict, Tuple]]): List of ragged data samples.
70+ data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple
71+ with 'features', 'labels', and 'ident'.
5272
5373 Returns:
54- XYData: Batched data with appropriate padding and masks.
74+ XYData: A batch of padded sequences and labels, including masks for valid positions and indices of
75+ non-null labels for metric computation.
5576 """
5677 model_kwargs : Dict = dict ()
78+ # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs.
5779 loss_kwargs : Dict = dict ()
5880
5981 if isinstance (data [0 ], tuple ):
@@ -64,18 +86,23 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
6486 * ((d ["features" ], d ["labels" ], d .get ("ident" )) for d in data )
6587 )
6688 if any (x is not None for x in y ):
89+ # If any label is not None: (None, None, `1`, None)
6790 if any (x is None for x in y ):
91+ # If any label is None: (`None`, `None`, 1, `None`)
6892 non_null_labels = [i for i , r in enumerate (y ) if r is not None ]
6993 y = self .process_label_rows (
7094 tuple (ye for i , ye in enumerate (y ) if i in non_null_labels )
7195 )
7296 loss_kwargs ["non_null_labels" ] = non_null_labels
7397 else :
98+ # If all labels are not None: (`0`, `2`, `1`, `3`)
7499 y = self .process_label_rows (y )
75100 else :
101+ # If all labels are None : (`None`, `None`, `None`, `None`)
76102 y = None
77103 loss_kwargs ["non_null_labels" ] = []
78104
105+ # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions
79106 lens = torch .tensor (list (map (len , x )))
80107 model_kwargs ["mask" ] = torch .arange (max (lens ))[None , :] < lens [:, None ]
81108 model_kwargs ["lens" ] = lens
@@ -89,7 +116,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
89116 )
90117
91118 def process_label_rows (self , labels : Tuple ) -> torch .Tensor :
92- """Process label rows by padding sequences.
119+ """
120+ Process label rows by padding sequences to ensure uniform shape across the batch.
121+
122+ This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor.
123+ It ensures that `None` values in the labels are handled by substituting them with a default value(e.g.,`False`).
93124
94125 Args:
95126 labels (Tuple): Tuple of label rows.
0 commit comments