Merged
Commits
32 commits
876d946
basic data processing for go-uniprot dataset
aditya0by0 Jul 15, 2024
b2d13e9
Merge branch 'dev' into protein_prediction
aditya0by0 Jul 21, 2024
4844380
prepare_data : sequence added to graph creation process
aditya0by0 Jul 21, 2024
795c017
prepare_data: filter out any rows without any True value
aditya0by0 Jul 21, 2024
4f06b62
setup data phase : preprocessing
aditya0by0 Jul 25, 2024
1367975
add reader for protein data
aditya0by0 Jul 26, 2024
f202579
config : GO 50
aditya0by0 Jul 26, 2024
a07c020
Update setup.py
aditya0by0 Jul 26, 2024
07e5114
fix - local permission error for swiss data
aditya0by0 Jul 26, 2024
b334929
go_uniprot : docstrings + variable namings
aditya0by0 Jul 28, 2024
5cdc9b8
chebi.py : additional/more specific docstrings
aditya0by0 Jul 31, 2024
0ee241a
base class for datasets following new dynamics splits feature
aditya0by0 Aug 2, 2024
d182a22
update _ChEBIDataExtractor as per newly inherited _DynamicDataset bas…
aditya0by0 Aug 2, 2024
25a9594
update _GOUniprotDataExtractor to inherit _DynamicDataset
aditya0by0 Aug 2, 2024
4ac6bc2
Merge branch 'dev' into protein_prediction
aditya0by0 Aug 9, 2024
5a4860d
add load_processed_data to base
aditya0by0 Aug 10, 2024
53daf97
go data: changes
aditya0by0 Aug 13, 2024
499fafc
update _graph_to_raw_dataset method
aditya0by0 Aug 14, 2024
19c47c1
fix tokenizing process in reader class for protein
aditya0by0 Aug 14, 2024
ecb276a
protein tokens - 20 natural amino acid tokens
aditya0by0 Aug 14, 2024
5f9ff93
minor updates
aditya0by0 Aug 14, 2024
b916994
filter out swiss protein as per given criterias in paper
aditya0by0 Aug 14, 2024
079269b
fixes: go_branch filtering, protein sequence
aditya0by0 Aug 15, 2024
638598a
update logic to select go classes based on proteins dataset
aditya0by0 Aug 15, 2024
9200b73
fix: dataframe column addition performance warning
aditya0by0 Aug 16, 2024
f9c10f7
consistent prefix "GOUniProt" for all classes
aditya0by0 Aug 25, 2024
f39916b
update go configs for new class names
aditya0by0 Aug 25, 2024
4db76ce
extra documentation for ragged coll as per the comment
aditya0by0 Sep 9, 2024
06ab981
minor changes
aditya0by0 Sep 9, 2024
62a3f45
parameter for maximum length (default: 1002)
aditya0by0 Sep 21, 2024
6f463de
remove label number for GO_UniProt classes
aditya0by0 Sep 21, 2024
108d9ca
trigrams / n-grams combining several amino acids into one token
aditya0by0 Sep 21, 2024
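The last commit introduces n-grams that combine several amino acids into one token. As a rough illustration of the idea only (a hypothetical sketch with made-up helper names, not the repository's actual reader code), non-overlapping trigrams could be formed like this:

```python
def ngram_tokens(sequence: str, n: int = 3) -> list:
    """Split a protein sequence into non-overlapping n-gram tokens.

    Hypothetical illustration; the actual implementation may use a
    different windowing strategy (e.g., overlapping n-grams).
    """
    return [sequence[i:i + n] for i in range(0, len(sequence), n)]

print(ngram_tokens("MSIGAT"))  # -> ['MSI', 'GAT']
```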
20 changes: 20 additions & 0 deletions chebai/preprocessing/bin/protein_token/tokens.txt
@@ -0,0 +1,20 @@
M
S
I
G
A
T
R
L
Q
N
D
K
Y
P
C
F
W
E
V
H
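The 20 entries above are the one-letter codes of the natural amino acids that make up the new protein token vocabulary. A minimal sketch of mapping a protein sequence onto this vocabulary (hypothetical helper names; the actual reader class in `chebai/preprocessing` may differ, e.g. by reserving indices for padding or special tokens):

```python
# The vocabulary from tokens.txt, in file order.
TOKENS = list("MSIGATRLQNDKYPCFWEVH")
TOKEN_TO_INDEX = {aa: i for i, aa in enumerate(TOKENS)}

def encode(sequence: str) -> list:
    """Map each amino-acid letter to its index in the vocabulary."""
    return [TOKEN_TO_INDEX[aa] for aa in sequence]

print(encode("MKV"))  # -> [0, 11, 18]
```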
43 changes: 37 additions & 6 deletions chebai/preprocessing/collate.py
@@ -41,19 +41,41 @@ def __call__(self, data: List[Dict]) -> XYData:


class RaggedCollator(Collator):
"""Collator for handling ragged data samples."""
"""
Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None).

This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes,
such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some
of the data samples might be partially labeled, which is useful for certain loss functions that allow training
with incomplete or fuzzy data (e.g., fuzzy loss).

During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate
between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled
data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for
metrics computation such as F1-score or MSE, especially in cases where some data points lack labels.

Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829
"""

def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
"""Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into
a batch.
"""
Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch.

Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices
of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for
unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method
ensures alignment between features and labels.

Args:
data (List[Union[Dict, Tuple]]): List of ragged data samples.
data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple
with 'features', 'labels', and 'ident'.

Returns:
XYData: Batched data with appropriate padding and masks.
XYData: A batch of padded sequences and labels, including masks for valid positions and indices of
non-null labels for metric computation.
"""
model_kwargs: Dict = dict()
# Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs.
loss_kwargs: Dict = dict()

if isinstance(data[0], tuple):
@@ -64,18 +86,23 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
*((d["features"], d["labels"], d.get("ident")) for d in data)
)
if any(x is not None for x in y):
# If any label is not None: (None, None, `1`, None)
if any(x is None for x in y):
# If any label is None: (`None`, `None`, 1, `None`)
non_null_labels = [i for i, r in enumerate(y) if r is not None]
y = self.process_label_rows(
tuple(ye for i, ye in enumerate(y) if i in non_null_labels)
)
loss_kwargs["non_null_labels"] = non_null_labels
else:
# If all labels are not None: (`0`, `2`, `1`, `3`)
y = self.process_label_rows(y)
else:
# If all labels are None: (`None`, `None`, `None`, `None`)
y = None
loss_kwargs["non_null_labels"] = []

# Calculate the length of each sequence and create a binary mask for valid (non-padded) positions
lens = torch.tensor(list(map(len, x)))
model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None]
model_kwargs["lens"] = lens
@@ -89,7 +116,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData:
)

def process_label_rows(self, labels: Tuple) -> torch.Tensor:
"""Process label rows by padding sequences.
"""
Process label rows by padding sequences to ensure uniform shape across the batch.

This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor.
It ensures that `None` values in the labels are handled by substituting them with a default value (e.g., `False`).

Args:
labels (Tuple): Tuple of label rows.
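The collation behaviour described in the diff (padding to a uniform length, a validity mask, and `non_null_labels` for partially labelled batches) can be mirrored in plain Python. This is a simplified sketch of the mechanics only: the real `RaggedCollator` works with torch tensors and `pad_sequence`, so names and return types here are illustrative.

```python
def collate_ragged(features, labels, pad_value=0):
    """Torch-free mimic of the RaggedCollator mechanics (sketch only)."""
    lens = [len(f) for f in features]
    max_len = max(lens)
    # Pad every feature row to the batch maximum (cf. pad_sequence).
    padded = [list(f) + [pad_value] * (max_len - len(f)) for f in features]
    # True for real positions, False for padding (the "mask" model kwarg).
    mask = [[pos < n for pos in range(max_len)] for n in lens]
    # Indices of samples that actually carry a label ("non_null_labels").
    non_null = [i for i, lab in enumerate(labels) if lab is not None]
    kept_labels = [labels[i] for i in non_null]
    return padded, mask, kept_labels, non_null

padded, mask, kept, idx = collate_ragged([[1, 2, 3], [4]], [[True], None])
# padded == [[1, 2, 3], [4, 0, 0]]; idx == [0]
```

During evaluation, `idx` is what lets metrics such as F1 be computed only over the samples that were actually labelled.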