diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai/preprocessing/bin/protein_token/tokens.txt new file mode 100644 index 00000000..72ad1b6d --- /dev/null +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -0,0 +1,20 @@ +M +S +I +G +A +T +R +L +Q +N +D +K +Y +P +C +F +W +E +V +H diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py index 4e5e9e16..ecbcb876 100644 --- a/chebai/preprocessing/collate.py +++ b/chebai/preprocessing/collate.py @@ -41,19 +41,41 @@ def __call__(self, data: List[Dict]) -> XYData: class RaggedCollator(Collator): - """Collator for handling ragged data samples.""" + """ + Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None). + + This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes, + such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some + of the data samples might be partially labeled, which is useful for certain loss functions that allow training + with incomplete or fuzzy data (e.g., fuzzy loss). + + During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate + between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled + data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for + metrics computation such as F1-score or MSE, especially in cases where some data points lack labels. + + Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 + """ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: - """Collate ragged data samples (i.e., samples of unequal size such as string representations of molecules) into - a batch. + """ + Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch. + + Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices + of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for + unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method + ensures alignment between features and labels. Args: - data (List[Union[Dict, Tuple]]): List of ragged data samples. + data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple + with 'features', 'labels', and 'ident'. Returns: - XYData: Batched data with appropriate padding and masks. + XYData: A batch of padded sequences and labels, including masks for valid positions and indices of + non-null labels for metric computation. """ model_kwargs: Dict = dict() + # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs. 
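As an illustration of the padding-and-mask behaviour described in the docstring above, here is a minimal, self-contained sketch on toy inputs; `toy_collate` and the toy values are hypothetical and not part of the chebai API.

import torch
from torch.nn.utils.rnn import pad_sequence

def toy_collate(features, labels):
    # keep only samples that actually carry labels; remember their positions
    non_null = [i for i, y in enumerate(labels) if y is not None]
    lens = torch.tensor([len(f) for f in features])
    x = pad_sequence([torch.tensor(f) for f in features], batch_first=True)
    mask = torch.arange(int(lens.max()))[None, :] < lens[:, None]
    y = (
        pad_sequence([torch.tensor(labels[i]) for i in non_null], batch_first=True)
        if non_null
        else None
    )
    return x, y, mask, non_null

x, y, mask, idx = toy_collate(features=[[5, 3, 7], [2, 4]], labels=[[True, False], None])
print(mask)  # tensor([[ True,  True,  True], [ True,  True, False]])
print(idx)   # [0] -> only the first sample contributes to loss and metrics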
loss_kwargs: Dict = dict() if isinstance(data[0], tuple): @@ -64,18 +86,23 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: *((d["features"], d["labels"], d.get("ident")) for d in data) ) if any(x is not None for x in y): + # If any label is not None: (None, None, `1`, None) if any(x is None for x in y): + # If any label is None: (`None`, `None`, 1, `None`) non_null_labels = [i for i, r in enumerate(y) if r is not None] y = self.process_label_rows( tuple(ye for i, ye in enumerate(y) if i in non_null_labels) ) loss_kwargs["non_null_labels"] = non_null_labels else: + # If all labels are not None: (`0`, `2`, `1`, `3`) y = self.process_label_rows(y) else: + # If all labels are None : (`None`, `None`, `None`, `None`) y = None loss_kwargs["non_null_labels"] = [] + # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions lens = torch.tensor(list(map(len, x))) model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None] model_kwargs["lens"] = lens @@ -89,7 +116,11 @@ def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: ) def process_label_rows(self, labels: Tuple) -> torch.Tensor: - """Process label rows by padding sequences. + """ + Process label rows by padding sequences to ensure uniform shape across the batch. + + This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor. + It ensures that `None` values in the labels are handled by substituting them with a default value(e.g.,`False`). Args: labels (Tuple): Tuple of label rows. diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 30aa6551..02877ad3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1,10 +1,17 @@ import os import random -from typing import Any, Dict, Generator, List, Optional, Union +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Tuple, Union import lightning as pl +import networkx as nx +import pandas as pd import torch import tqdm +from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, +) from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.rank_zero import rank_zero_info from torch.utils.data import DataLoader @@ -148,6 +155,8 @@ def _name(self) -> str: def _filter_labels(self, row: dict) -> dict: """ Filter labels based on `label_filter`. + This method selects specific labels from the `labels` list within the row dictionary + according to the index or indices provided by the `label_filter` attribute of the class. Args: row (dict): A dictionary containing the row data. @@ -583,3 +592,554 @@ def limits(self): Returns None, assuming no limits on data slicing. """ return None + + +class _DynamicDataset(XYBaseDataModule, ABC): + """ + A class for extracting and processing data from the given dataset. + + The processed and transformed data is stored in `data.pkl` and `data.pt` format as a whole respectively, + rather than as separate train, validation, and test splits, with dynamic splitting of data.pt occurring at runtime. + The `_DynamicDataset` class manages data splits by either generating them during execution or retrieving them from + a CSV file. + If no split file path is provided, `_generate_dynamic_splits` creates the training, validation, and test splits + from the encoded/transformed data, storing them in `_dynamic_df_train`, `_dynamic_df_val`, and `_dynamic_df_test`. 
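The splits-file contract described here (a CSV with 'id' and 'split' columns) can be checked with a short standalone snippet; the file name below is only an example.

import os
import pandas as pd

splits_path = "splits.csv"  # example location; the real path comes from `splits_file_path`
if os.path.isfile(splits_path) and os.path.getsize(splits_path) > 0:
    header = pd.read_csv(splits_path, nrows=1)
    assert {"id", "split"}.issubset(header.columns), "needs 'id' and 'split' columns"
# expected content, one row per data instance (identifiers are made up):
#   id,split
#   12345,train
#   67890,validation
#   24601,test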
+ When a split file path is provided, `_retrieve_splits_from_csv` loads splits from the CSV file, which must + include 'id' and 'split' columns. + The `dynamic_split_dfs` property ensures that the necessary splits are loaded as required. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments passed to XYBaseDataModule. + + Attributes: + dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + splits_file_path (Optional[str]): Path to the CSV file containing split assignments. + """ + + # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------ + _ID_IDX: int = None + _DATA_REPRESENTATION_IDX: int = None + _LABELS_START_IDX: int = None + + def __init__( + self, + **kwargs, + ): + super(_DynamicDataset, self).__init__(**kwargs) + self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 + # Class variables to store the dynamics splits + self._dynamic_df_train = None + self._dynamic_df_test = None + self._dynamic_df_val = None + # Path of csv file which contains a list of ids & their assignment to a dataset (either train, + # validation or test). + self.splits_file_path = self._validate_splits_file_path( + kwargs.get("splits_file_path", None) + ) + + @staticmethod + def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + """ + Validates the file in provided splits file path. + + Args: + splits_file_path (Optional[str]): Path to the splits CSV file. + + Returns: + Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. + + Raises: + FileNotFoundError: If the splits file does not exist. + ValueError: If splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + """ + if splits_file_path is None: + return None + + if not os.path.isfile(splits_file_path): + raise FileNotFoundError(f"File {splits_file_path} does not exist") + + file_size = os.path.getsize(splits_file_path) + if file_size == 0: + raise ValueError(f"File {splits_file_path} is empty") + + # Check if the file has a CSV extension + if not splits_file_path.lower().endswith(".csv"): + raise ValueError(f"File {splits_file_path} is not a CSV file") + + # Read the first row of CSV file into a DataFrame + splits_df = pd.read_csv(splits_file_path, nrows=1) + + # Check if 'id' and 'split' columns are in the DataFrame + required_columns = {"id", "split"} + if not required_columns.issubset(splits_df.columns): + raise ValueError( + f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." + ) + + return splits_file_path + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares the data for the dataset. + + This method checks for the presence of raw data in the specified directory. + If the raw data is missing, it fetches the ontology and creates a dataframe and saves it to a data.pkl file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index `self._ID_IDX`: ID of data instance + - Column at index `self._DATA_REPRESENTATION_IDX`: Sequence representation of the protein + - Column from index `self._LABELS_START_IDX` onwards: Labels + + Args: + *args: Variable length argument list. 
+ **kwargs: Arbitrary keyword arguments. + + Returns: + None + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_dir_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + print("Missing processed data file (`data.pkl` file)") + os.makedirs(self.processed_dir_main, exist_ok=True) + data_path = self._download_required_data() + g = self._extract_class_hierarchy(data_path) + data_df = self._graph_to_raw_dataset(g) + self.save_processed(data_df, processed_name) + + @abstractmethod + def _download_required_data(self) -> str: + """ + Downloads the required raw data. + + Returns: + str: Path to the downloaded data. + """ + pass + + @abstractmethod + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from the data. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from + the term documents. + + Args: + data_path (str): Path to the data. + + Returns: + nx.DiGraph: The class hierarchy graph. + """ + pass + + @abstractmethod + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + """ + Converts the graph to a raw dataset. + Uses the graph created by `_extract_class_hierarchy` method to extract the + raw data in Dataframe format with additional columns corresponding to each multi-label class. + + Args: + graph (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset. + """ + pass + + @abstractmethod + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + """ + Selects classes from the dataset based on a specified criteria. + + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + List: A sorted list of node IDs that meet the specified criteria. + """ + pass + + def save_processed(self, data: pd.DataFrame, filename: str) -> None: + """ + Save the processed dataset to a pickle file. + + Args: + data (pd.DataFrame): The processed dataset to be saved. + filename (str): The filename for the pickle file. + """ + pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) + + # ------------------------------ Phase: Setup data ----------------------------------- + def setup_processed(self) -> None: + """ + Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format + compatible for input to the model. + The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + + Returns: + None + """ + os.makedirs(self.processed_dir, exist_ok=True) + print("Missing transformed data (`data.pt` file). Transforming data.... ") + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.processed_dir_main_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, self.processed_file_names_dict["data"]), + ) + + @staticmethod + def _get_data_size(input_file_path: str) -> int: + """ + Get the size of the data from a pickled file. + + Args: + input_file_path (str): The path to the file. + + Returns: + int: The size of the data. 
+ """ + with open(input_file_path, "rb") as f: + return len(pd.read_pickle(f)) + + @abstractmethod + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from given pickled file and yields individual dictionaries for each row. + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Generator[Dict[str, Any], None, None]: Generator yielding dictionaries. + + """ + pass + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + @property + def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: + """ + Property to retrieve dynamic train, validation, and test splits. + + This property checks if dynamic data splits (`_dynamic_df_train`, `_dynamic_df_val`, `_dynamic_df_test`) + are already loaded. If any of them is None, it either generates them dynamically or retrieves them + from data file with help of pre-existing split csv file (`splits_file_path`) containing splits assignments. + + Returns: + dict: A dictionary containing the dynamic train, validation, and test DataFrames. + Keys are 'train', 'validation', and 'test'. + """ + if any( + split is None + for split in [ + self._dynamic_df_test, + self._dynamic_df_val, + self._dynamic_df_train, + ] + ): + if self.splits_file_path is None: + # Generate splits based on given seed, create csv file to records the splits + self._generate_dynamic_splits() + else: + # If user has provided splits file path, use it to get the splits from the data + self._retrieve_splits_from_csv() + return { + "train": self._dynamic_df_train, + "validation": self._dynamic_df_val, + "test": self._dynamic_df_test, + } + + def _generate_dynamic_splits(self) -> None: + """ + Generate data splits during runtime and save them in class variables. + + This method loads encoded data and generates train, validation, and test splits based on the loaded data. + """ + print("\nGenerate dynamic splits...") + df_train, df_val, df_test = self._get_data_splits() + + # Generate splits.csv file to store ids of each corresponding split + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": df_train["ident"], "split": "train"}), + pd.DataFrame({"id": df_val["ident"], "split": "validation"}), + pd.DataFrame({"id": df_test["ident"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + combined_split_assignment.to_csv( + os.path.join(self.processed_dir_main, "splits.csv"), index=False + ) + + # Store the splits in class variables + self._dynamic_df_train = df_train + self._dynamic_df_val = df_val + self._dynamic_df_test = df_test + + @abstractmethod + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Retrieve the train, validation, and test data splits for the dataset. + + This method returns data splits according to specific criteria implemented + in the subclasses. + + Returns: + tuple: A tuple containing DataFrames for train, validation, and test splits. + """ + pass + + def get_test_split( + self, df: pd.DataFrame, seed: Optional[int] = None + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Split the input DataFrame into training and testing sets based on multilabel stratified sampling. 
+ + This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels + in the training and testing sets is approximately the same. The split is based on the "labels" column + in the DataFrame. + + Args: + df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column + named "labels" with the multilabel data. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. + + Raises: + ValueError: If the DataFrame does not contain a column named "labels". + """ + print("Get test data split") + + labels_list = df["labels"].tolist() + + test_size = 1 - self.train_split - (1 - self.train_split) ** 2 + msss = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, test_indices = next(msss.split(labels_list, labels_list)) + + df_train = df.iloc[train_indices] + df_test = df.iloc[test_indices] + return df_train, df_test + + def get_train_val_splits_given_test( + self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None + ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: + """ + Split the dataset into train and validation sets, given a test set. + Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap + + Args: + df (pd.DataFrame): The original dataset. + test_df (pd.DataFrame): The test dataset. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and + validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train + and validation DataFrames. The keys are the names of the train and validation sets, and the values + are the corresponding DataFrames. + """ + print(f"Split dataset into train / val with given test set") + + test_ids = test_df["ident"].tolist() + df_trainval = df[~df["ident"].isin(test_ids)] + labels_list_trainval = df_trainval["labels"].tolist() + + if self.use_inner_cross_validation: + folds = {} + kfold = MultilabelStratifiedKFold( + n_splits=self.inner_k_folds, random_state=seed + ) + for fold, (train_ids, val_ids) in enumerate( + kfold.split( + labels_list_trainval, + labels_list_trainval, + ) + ): + df_validation = df_trainval.iloc[val_ids] + df_train = df_trainval.iloc[train_ids] + folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train + folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( + df_validation + ) + + return folds + + # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) + test_size = ((1 - self.train_split) ** 2) / self.train_split + msss = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, validation_indices = next( + msss.split(labels_list_trainval, labels_list_trainval) + ) + + df_validation = df_trainval.iloc[validation_indices] + df_train = df_trainval.iloc[train_indices] + return df_train, df_validation + + def _retrieve_splits_from_csv(self) -> None: + """ + Retrieve previously saved data splits from splits.csv file or from provided file path. + + This method loads the splits.csv file located at `self.splits_file_path`. 
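A compact sketch of the id-based reconstruction described here, on in-memory toy data (identifiers and features are invented):

import pandas as pd

records = pd.DataFrame(
    [
        {"ident": 1, "features": [5, 3], "labels": [True, False]},
        {"ident": 2, "features": [2], "labels": [False, True]},
        {"ident": 3, "features": [7, 7, 7], "labels": [True, True]},
    ]
)  # stands in for pd.DataFrame(torch.load(.../data.pt))
splits = pd.DataFrame({"id": [1, 2, 3], "split": ["train", "validation", "test"]})

train = records[records["ident"].isin(splits.loc[splits["split"] == "train", "id"])]
val = records[records["ident"].isin(splits.loc[splits["split"] == "validation", "id"])]
test = records[records["ident"].isin(splits.loc[splits["split"] == "test", "id"])]
print(len(train), len(val), len(test))  # 1 1 1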
+ It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from + splits.csv to reconstruct the train, validation, and test splits. + """ + print(f"\nLoading splits from {self.splits_file_path}...") + splits_df = pd.read_csv(self.splits_file_path) + + filename = self.processed_file_names_dict["data"] + data = torch.load(os.path.join(self.processed_dir, filename)) + df_data = pd.DataFrame(data) + + train_ids = splits_df[splits_df["split"] == "train"]["id"] + validation_ids = splits_df[splits_df["split"] == "validation"]["id"] + test_ids = splits_df[splits_df["split"] == "test"]["id"] + + self._dynamic_df_train = df_data[df_data["ident"].isin(train_ids)] + self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)] + self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)] + + def load_processed_data( + self, kind: Optional[str] = None, filename: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Loads processed data from a specified dataset type or file. + + This method retrieves processed data based on the dataset type (`kind`) such as "train", + "val", or "test", or directly from a provided filename. When `kind` is specified, the method + leverages the `dynamic_split_dfs` property to dynamically generate or retrieve the corresponding + data splits if they are not already loaded. If both `kind` and `filename` are provided, `filename` + takes precedence. + + Args: + kind (str, optional): The type of dataset to load ("train", "val", or "test"). + If `filename` is provided, this argument is ignored. Defaults to None. + filename (str, optional): The name of the file to load the dataset from. + If provided, this takes precedence over `kind`. Defaults to None. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, where each dictionary contains + the processed data for an individual data point. + + Raises: + ValueError: If both `kind` and `filename` are None, as one of them is required to load the dataset. + KeyError: If the specified `kind` does not exist in the `dynamic_split_dfs` property or + `processed_file_names_dict`, when expected. + FileNotFoundError: If the file corresponding to the provided `filename` does not exist. + """ + if kind is None and filename is None: + raise ValueError( + "Either kind or filename is required to load the correct dataset, both are None" + ) + + # If both kind and filename are given, use filename + if kind is not None and filename is None: + try: + if self.use_inner_cross_validation and kind != "test": + filename = self.processed_file_names_dict[ + f"fold_{self.fold_index}_{kind}" + ] + else: + data_df = self.dynamic_split_dfs[kind] + return data_df.to_dict(orient="records") + except KeyError: + kind = f"{kind}" + + # If filename is provided + try: + return torch.load(os.path.join(self.processed_dir, filename)) + except FileNotFoundError: + raise FileNotFoundError(f"File {filename} doesn't exist") + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def base_dir(self) -> str: + """ + Returns the base directory path for storing data. + + Returns: + str: The path to the base directory. + """ + pass + + @property + def processed_dir_main(self) -> str: + """ + Returns the main directory path where processed data is stored. + + Returns: + str: The path to the main processed data directory, based on the base directory and the instance's name. 
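For orientation, the paths these directory properties resolve to look roughly like the following; the version number and reader identifier are example values, not fixed names.

import os

base_dir = os.path.join("data", "chebi_v231")    # example ChEBI version
name, identifier = "ChEBI50", ("smiles_token",)  # example dataset name / reader identifier
processed_dir_main = os.path.join(base_dir, name, "processed")
processed_dir = os.path.join(processed_dir_main, *identifier)
print(processed_dir_main)  # data/chebi_v231/ChEBI50/processed                -> data.pkl, splits.csv
print(processed_dir)       # data/chebi_v231/ChEBI50/processed/smiles_token   -> data.pt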
+ """ + return os.path.join( + self.base_dir, + self._name, + "processed", + ) + + @property + def processed_dir(self) -> str: + """ + Returns the specific directory path for processed data, including identifiers. + + Returns: + str: The path to the processed data directory, including additional identifiers. + """ + return os.path.join( + self.processed_dir_main, + *self.identifier, + ) + + @property + def processed_dir_main_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed data file names, processed by `prepare_data` method. + + Returns: + dict: A dictionary mapping dataset types to their respective processed file names. + For example, {"data": "data.pkl"}. + """ + return {"data": "data.pkl"} + + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed and transformed data file names to their final formats, which are + processed by `setup` method. + + Returns: + dict: A dictionary mapping dataset types to their respective final file names. + For example, {"data": "data.pt"}. + """ + return {"data": "data.pt"} + + @property + def processed_file_names(self) -> List[str]: + """ + Returns a list of file names for processed data. + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_file_names_dict.values()) diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 5876577f..1c0cb2f9 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -13,20 +13,16 @@ import pickle from abc import ABC from collections import OrderedDict -from typing import Any, Dict, Generator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Optional, Tuple import fastobo import networkx as nx import pandas as pd import requests import torch -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) from chebai.preprocessing import reader as dr -from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.datasets.base import XYBaseDataModule, _DynamicDataset # exclude some entities from the dataset because the violate disjointness axioms CHEBI_BLACKLIST = [ @@ -109,7 +105,7 @@ class JCITokenData(JCIBase): READER = dr.ChemDataReader -class _ChEBIDataExtractor(XYBaseDataModule, ABC): +class _ChEBIDataExtractor(_DynamicDataset, ABC): """ A class for extracting and processing data from the ChEBI dataset. @@ -126,12 +122,18 @@ class _ChEBIDataExtractor(XYBaseDataModule, ABC): single_class (Optional[int]): The ID of the single class to predict. chebi_version_train (Optional[int]): The version of ChEBI to use for training and validation. dynamic_data_split_seed (int): The seed for random data splitting, default is 42. - dynamic_df_train (Optional[pd.DataFrame]): DataFrame to store the training data split. - dynamic_df_test (Optional[pd.DataFrame]): DataFrame to store the test data split. - dynamic_df_val (Optional[pd.DataFrame]): DataFrame to store the validation data split. splits_file_path (Optional[str]): Path to csv file containing split assignments. 
""" + # ---- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset` method) ------ + # "id" at row index 0 + # "name" at row index 1 + # "SMILES" at row index 2 + # labels starting from row index 3 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 2 + _LABELS_START_IDX: int = 3 + def __init__( self, chebi_version_train: Optional[int] = None, @@ -144,11 +146,6 @@ def __init__( # use different version of chebi for training and validation (if not None) # (still uses self.chebi_version for test set) self.chebi_version_train = chebi_version_train - self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 - # Class variables to store the dynamics splits - self.dynamic_df_train = None - self.dynamic_df_test = None - self.dynamic_df_val = None if self.chebi_version_train is not None: # Instantiate another same class with "chebi_version" as "chebi_version_train", if train_version is given @@ -159,89 +156,116 @@ def __init__( single_class=self.single_class, **_init_kwargs, ) - # Path of csv file which contains a list of chebi ids & their assignment to a dataset (either train, validation or test). - self.splits_file_path = self._validate_splits_file_path( - kwargs.get("splits_file_path", None) - ) - @staticmethod - def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: """ - Validates the provided splits file path. + Prepares the data for the Chebi dataset. + + This method checks for the presence of raw data in the specified directory. + If the raw data is missing, it fetches the ontology and creates a dataframe & saves it as data.pkl pickle file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index `self._ID_IDX`: ID of chebi data instance + - Column at index `self._DATA_REPRESENTATION_IDX`: SMILES representation of the chemical + - Column from index `self._LABELS_START_IDX` onwards: Labels + + It will pre-process the data related to `chebi_version_train`, if specified. Args: - splits_file_path (Optional[str]): Path to the splits CSV file. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. Returns: - Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. - - Raises: - FileNotFoundError: If the splits file does not exist. - ValueError: If the splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + None """ - if splits_file_path is None: - return None + super().prepare_data(args, kwargs) - if not os.path.isfile(splits_file_path): - raise FileNotFoundError(f"File {splits_file_path} does not exist") + if self.chebi_version_train is not None: + if not os.path.isfile( + os.path.join( + self._chebi_version_train_obj.processed_dir_main, + self._chebi_version_train_obj.processed_dir_main_file_names_dict[ + "data" + ], + ) + ): + print( + f"Missing processed data related to train version: {self.chebi_version_train}" + ) + print("Calling the prepare_data method related to it") + # Generate the "chebi_version_train" data if it doesn't exist + self._chebi_version_train_obj.prepare_data(*args, **kwargs) + + def _download_required_data(self) -> str: + """ + Downloads the required raw data related to chebi. 
- file_size = os.path.getsize(splits_file_path) - if file_size == 0: - raise ValueError(f"File {splits_file_path} is empty") + Returns: + str: Path to the downloaded data. + """ + return self._load_chebi(self.chebi_version) - # Check if the file has a CSV extension - if not splits_file_path.lower().endswith(".csv"): - raise ValueError(f"File {splits_file_path} is not a CSV file") + def _load_chebi(self, version: int) -> str: + """ + Load the ChEBI ontology file. - # Read the first row of CSV file into a DataFrame - splits_df = pd.read_csv(splits_file_path, nrows=1) + Args: + version (int): The version of the ChEBI ontology to load. - # Check if 'id' and 'split' columns are in the DataFrame - required_columns = {"id", "split"} - if not required_columns.issubset(splits_df.columns): - raise ValueError( - f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." + Returns: + str: The file path of the loaded ChEBI ontology. + """ + chebi_name = ( + f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo" + ) + chebi_path = os.path.join(self.raw_dir, chebi_name) + if not os.path.isfile(chebi_path): + print( + f"Missing raw chebi data related to version: v_{version}, Downloading..." ) + url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo" + r = requests.get(url, allow_redirects=True) + open(chebi_path, "wb").write(r.content) + return chebi_path - return splits_file_path - - def extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: """ Extracts the class hierarchy from the ChEBI ontology. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from + the chebi term documents from `.obo` file. Args: - chebi_path (str): The path to the ChEBI ontology. + data_path (str): The path to the ChEBI ontology. Returns: nx.DiGraph: The class hierarchy. """ - with open(chebi_path, encoding="utf-8") as chebi: + with open(data_path, encoding="utf-8") as chebi: chebi = "\n".join(l for l in chebi if not l.startswith("xref:")) + elements = [ term_callback(clause) for clause in fastobo.loads(chebi) if clause and ":" in str(clause.id) ] + g = nx.DiGraph() for n in elements: g.add_node(n["id"], **n) g.add_edges_from([(p, q["id"]) for q in elements for p in q["parents"]]) + print("Compute transitive closure") return nx.transitive_closure_dag(g) - def select_classes(self, g, split_name, *args, **kwargs): - raise NotImplementedError - - def graph_to_raw_dataset( - self, g: nx.DiGraph, split_name: Optional[str] = None - ) -> pd.DataFrame: + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: """ - Preparation step before creating splits, uses graph created by extract_class_hierarchy(), - split_name is only relevant, if a separate train_version is set. + Converts the graph to a raw dataset. + Uses the graph created by `_extract_class_hierarchy` method to extract the + raw data in Dataframe format with additional columns corresponding to each multi-label class. Args: g (nx.DiGraph): The class hierarchy graph. - split_name (Optional[str], optional): Name of the split. Defaults to None. Returns: pd.DataFrame: The raw dataset created from the graph. 
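A small sketch of why the label columns built below work: after the transitive closure computed in `_extract_class_hierarchy`, `g.predecessors(node)` yields all ancestors of a node, not just its direct parents (class names are invented).

import networkx as nx

g = nx.DiGraph([("molecule", "alcohol"), ("alcohol", "ethanol")])
tc = nx.transitive_closure_dag(g)
print(set(g.predecessors("ethanol")))   # {'alcohol'}             direct parent only
print(set(tc.predecessors("ethanol")))  # {'alcohol', 'molecule'}  all superclasses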
@@ -258,10 +282,14 @@ def graph_to_raw_dataset( if smiles ) ) - data = OrderedDict(id=molecules) - data["name"] = [names.get(node) for node in molecules] - data["SMILES"] = smiles_list - for n in self.select_classes(g, split_name): + data = OrderedDict(id=molecules) # `id` column at index 0 + data["name"] = [ + names.get(node) for node in molecules + ] # `name` column at index 1 + data["SMILES"] = smiles_list # `SMILES` (data representation) column at index 2 + + # Labels columns from index 3 onwards + for n in self.select_classes(g): data[n] = [ ((n in g.predecessors(node)) or (n == node)) for node in molecules ] @@ -269,527 +297,104 @@ def graph_to_raw_dataset( data = pd.DataFrame(data) data = data[~data["SMILES"].isnull()] data = data[[name not in CHEBI_BLACKLIST for name, _ in data.iterrows()]] - data = data[data.iloc[:, 3:].any(axis=1)] + # This filters the DataFrame to include only the rows where at least one value in the row from 4th column + # onwards is True/non-zero. + data = data[data.iloc[:, self._LABELS_START_IDX :].any(axis=1)] return data - def save_raw(self, data: pd.DataFrame, filename: str) -> None: - """ - Save the raw dataset to a pickle file. - - Args: - data (pd.DataFrame): The raw dataset to be saved. - filename (str): The filename for the pickle file. - """ - pd.to_pickle(data, open(os.path.join(self.raw_dir, filename), "wb")) - - def save_processed(self, data: pd.DataFrame, filename: str) -> None: - """ - Save the processed dataset to a pickle file. - - Args: - data (pd.DataFrame): The processed dataset to be saved. - filename (str): The filename for the pickle file. - """ - pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) - - def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: - """ - Loads a dictionary from a pickled file, yielding individual dictionaries for each row. - - Args: - input_file_path (str): The path to the file. - - Yields: - Dict[str, Any]: The dictionary, keys are `features`, `labels` and `ident`. - """ - with open(input_file_path, "rb") as input_file: - df = pd.read_pickle(input_file) - if self.single_class is not None: - single_cls_index = list(df.columns).index(int(self.single_class)) - for row in df.values: - if self.single_class is None: - labels = row[3:].astype(bool) - else: - labels = [bool(row[single_cls_index])] - yield dict(features=row[2], labels=labels, ident=row[0]) - - @staticmethod - def _get_data_size(input_file_path: str) -> int: - """ - Get the size of the data from a pickled file. - - Args: - input_file_path (str): The path to the file. - - Returns: - int: The size of the data. - """ - with open(input_file_path, "rb") as f: - return len(pd.read_pickle(f)) - - def _setup_pruned_test_set( - self, df_test_chebi_version: pd.DataFrame - ) -> pd.DataFrame: - """ - Create a test set with the same leaf nodes, but use only classes that appear in the training set. - - Args: - df_test_chebi_version (pd.DataFrame): The test dataset. - - Returns: - pd.DataFrame: The pruned test dataset. 
- """ - # TODO: find a more efficient way to do this - filename_old = "classes.txt" - # filename_new = f"classes_v{self.chebi_version_train}.txt" - # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) - - # Load original classes (from the current ChEBI version - chebi_version) - with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: - orig_classes = file.readlines() - - # Load new classes (from the training ChEBI version - chebi_version_train) - with open( - os.path.join( - self._chebi_version_train_obj.processed_dir_main, filename_old - ), - "r", - ) as file: - new_classes = file.readlines() - - # Create a mapping which give index of a class from chebi_version, if the corresponding - # class exists in chebi_version_train, Size = Number of classes in chebi_version - mapping = [ - None if or_class not in new_classes else new_classes.index(or_class) - for or_class in orig_classes - ] - - # Iterate over each data instance in the test set which is derived from chebi_version - for _, row in df_test_chebi_version.iterrows(): - # Size = Number of classes in chebi_version_train - new_labels = [False for _ in new_classes] - for ind, label in enumerate(row["labels"]): - # If the chebi_version class exists in the chebi_version_train and has a True label, - # set the corresponding label in new_labels to True - if mapping[ind] is not None and label: - new_labels[mapping[ind]] = label - # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions - row["labels"] = new_labels - - # torch.save( - # chebi_ver_test_data, - # os.path.join(self.processed_dir, self.processed_file_names_dict["test"]), - # ) - return df_test_chebi_version - + # ------------------------------ Phase: Setup data ----------------------------------- def setup_processed(self) -> None: """ Transform and prepare processed data for the ChEBI dataset. - This method sets up the processed data directories and files based on the ChEBI version - and train version (if specified). It ensures that the required processed data files exist - by loading raw data, transforming it into processed format, and saving it. - - It also handles special cases, such as generating a pruned test set if `chebi_version_train` - is specified and the test set does not already exist. This pruned test set includes only - classes that appear in the training set. + Main function of this method is to transform `data.pkl` into a model input data format (`data.pt`), + ensuring that the data is in a format compatible for input to the model. + The transformed data must contain the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + It will transform the data related to `chebi_version_train`, if specified. 
""" - print("Transform data") - os.makedirs(self.processed_dir, exist_ok=True) - # -------- Commented the code for Data Handling Restructure for Issue No.10 - # -------- https://github.com/ChEB-AI/python-chebai/issues/10 - # for k in self.processed_file_names_dict.keys(): - # processed_name = ( - # "test.pt" if k == "test" else self.processed_file_names_dict[k] - # ) - # if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - # print("transform", k) - # torch.save( - # self._load_data_from_file( - # os.path.join(self.raw_dir, self.raw_file_names_dict[k]) - # ), - # os.path.join(self.processed_dir, processed_name), - # ) - # # create second test set with classes used in train - # if self.chebi_version_train is not None and not os.path.isfile( - # os.path.join(self.processed_dir, self.processed_file_names_dict["test"]) - # ): - # print("transform test (select classes)") - # self._setup_pruned_test_set() - # - # processed_name = self.processed_file_names_dict[k] - # if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - # print( - # "Missing encoded data, transform processed data into encoded data", - # k, - # ) - # torch.save( - # self._load_data_from_file( - # os.path.join( - # self.processed_dir_main, self.raw_file_names_dict[k] - # ) - # ), - # os.path.join(self.processed_dir, processed_name), - # ) - - # Transform the processed data into encoded data - processed_name = self.processed_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir, processed_name)): - print( - f"Missing encoded data related to version {self.chebi_version}, transform processed data into encoded data:", - processed_name, - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.raw_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, processed_name), - ) + super().setup_processed() # Transform the data related to "chebi_version_train" to encoded data, if it doesn't exist if self.chebi_version_train is not None and not os.path.isfile( os.path.join( self._chebi_version_train_obj.processed_dir, - self._chebi_version_train_obj.raw_file_names_dict["data"], + self._chebi_version_train_obj.processed_file_names_dict["data"], ) ): print( f"Missing encoded data related to train version: {self.chebi_version_train}" ) - print("Call the setup method related to it") + print("Calling the setup method related to it") self._chebi_version_train_obj.setup() - def get_test_split( - self, df: pd.DataFrame, seed: Optional[int] = None - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """ - Split the input DataFrame into training and testing sets based on multilabel stratified sampling. - - This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels - in the training and testing sets is approximately the same. The split is based on the "labels" column - in the DataFrame. - - Args: - df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column - named "labels" with the multilabel data. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. - - Raises: - ValueError: If the DataFrame does not contain a column named "labels". 
+ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: """ - print("\nGet test data split") - - labels_list = df["labels"].tolist() + Loads a dictionary from a pickled file, yielding individual dictionaries for each row. - test_size = 1 - self.train_split - (1 - self.train_split) ** 2 - msss = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) + This method reads data from a specified pickled file, processes each row to extract relevant + information, and yields dictionaries containing the keys `features`, `labels`, and `ident`. + If `single_class` is specified, it only includes the label for that specific class; otherwise, + it includes labels for all classes starting from the fourth column. - train_indices, test_indices = next(msss.split(labels_list, labels_list)) + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of the chebi data instance + - Data at row index `self._DATA_REPRESENTATION_IDX` : SMILES representation for the chemical + - Data from row index `self._LABELS_START_IDX` onwards: Labels - df_train = df.iloc[train_indices] - df_test = df.iloc[test_indices] - return df_train, df_test - - def get_train_val_splits_given_test( - self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None - ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: - """ - Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another chebi version or generated in get_test_split), to avoid overlap + This method is used in `_load_data_from_file` to process each row of data and convert it + into the desired dictionary format before loading it into the model. Args: - df (pd.DataFrame): The original dataset. - test_df (pd.DataFrame): The test dataset. - seed (int, optional): The random seed to be used for reproducibility. Default is None. + input_file_path (str): The path to the input pickled file. - Returns: - Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and - validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train - and validation DataFrames. The keys are the names of the train and validation sets, and the values - are the corresponding DataFrames. + Yields: + Dict[str, Any]: A dictionary with keys `features`, `labels`, and `ident`. + `features` contains the sequence, `labels` contains the labels as boolean values, + and `ident` contains the identifier. 
""" - print(f"Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - # ---- list comprehension degrades performance, dataframe operations are faster - # mask = [trainval_id not in test_ids for trainval_id in df_trainval["ident"]] - # df_trainval = df_trainval[mask] - df_trainval = df[~df["ident"].isin(test_ids)] - labels_list_trainval = df_trainval["labels"].tolist() - - if self.use_inner_cross_validation: - folds = {} - kfold = MultilabelStratifiedKFold( - n_splits=self.inner_k_folds, random_state=seed - ) - for fold, (train_ids, val_ids) in enumerate( - kfold.split( - labels_list_trainval, - labels_list_trainval, - ) - ): - df_validation = df_trainval.iloc[val_ids] - df_train = df_trainval.iloc[train_ids] - folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train - folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( - df_validation + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + if self.single_class is not None: + single_cls_index = list(df.columns).index(int(self.single_class)) + for row in df.values: + if self.single_class is None: + labels = row[self._LABELS_START_IDX :].astype(bool) + else: + labels = [bool(row[single_cls_index])] + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], ) - return folds - - # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) - test_size = ((1 - self.train_split) ** 2) / self.train_split - msss = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - - train_indices, validation_indices = next( - msss.split(labels_list_trainval, labels_list_trainval) - ) - - df_validation = df_trainval.iloc[validation_indices] - df_train = df_trainval.iloc[train_indices] - return df_train, df_validation - - @property - def processed_dir_main(self) -> str: - """ - Return the main directory path for processed data. - - Returns: - str: The path to the main processed data directory. - """ - return os.path.join( - self.base_dir, - self._name, - "processed", - ) - - @property - def processed_dir(self) -> str: - """ - Return the directory path for processed data. - - Returns: - str: The path to the processed data directory. - """ - res = os.path.join( - self.processed_dir_main, - *self.identifier, - ) - if self.single_class is None: - return res - else: - return os.path.join(res, f"single_{self.single_class}") - - @property - def base_dir(self) -> str: - """ - Return the base directory path for data. - - Returns: - str: The base directory path for data. - """ - return os.path.join("data", f"chebi_v{self.chebi_version}") - - @property - def processed_file_names_dict(self) -> dict: - """ - Return a dictionary of processed file names. - - Returns: - dict: A dictionary where keys are file names and values are paths. - """ - train_v_str = ( - f"_v{self.chebi_version_train}" if self.chebi_version_train else "" - ) - # res = {"test": f"test{train_v_str}.pt"} - res = {} - - for set in ["train", "validation"]: - # TODO: code will be modified into CV issue for dynamic splits - if self.use_inner_cross_validation: - for i in range(self.inner_k_folds): - res[f"fold_{i}_{set}"] = os.path.join( - self.fold_dir, f"fold_{i}_{set}{train_v_str}.pt" - ) - # else: - # res[set] = f"{set}{train_v_str}.pt" - res["data"] = "data.pt" - return res - - @property - def raw_file_names_dict(self) -> dict: - """ - Return a dictionary of raw file names. 
- - Returns: - dict: A dictionary where keys are file names and values are paths. - """ - train_v_str = ( - f"_v{self.chebi_version_train}" if self.chebi_version_train else "" - ) - # res = { - # "test": f"test.pkl" - # } # no extra raw test version for chebi_version_train - use default test set and only - # adapt processed file - res = {} - for set in ["train", "validation"]: - # TODO: code will be modified into CV issue for dynamic splits - if self.use_inner_cross_validation: - for i in range(self.inner_k_folds): - res[f"fold_{i}_{set}"] = os.path.join( - self.fold_dir, f"fold_{i}_{set}{train_v_str}.pkl" - ) - # else: - # res[set] = f"{set}{train_v_str}.pkl" - res["data"] = "data.pkl" - return res - - @property - def processed_file_names(self) -> List[str]: - """ - Return a list of processed file names. - - Returns: - List[str]: A list containing processed file names. - """ - return list(self.processed_file_names_dict.values()) - - @property - def raw_file_names(self) -> List[str]: - """ - Return a list of raw file names. - - Returns: - List[str]: A list containing raw file names. + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ - return list(self.raw_file_names_dict.values()) + Loads encoded/transformed data and generates training, validation, and test splits. - def _load_chebi(self, version: int) -> str: - """ - Load the ChEBI ontology file. - - Args: - version (int): The version of the ChEBI ontology to load. - - Returns: - str: The file path of the loaded ChEBI ontology. - """ - chebi_name = ( - f"chebi.obo" if version == self.chebi_version else f"chebi_v{version}.obo" - ) - chebi_path = os.path.join(self.raw_dir, chebi_name) - if not os.path.isfile(chebi_path): - print(f"Load ChEBI ontology (v_{version})") - url = f"http://purl.obolibrary.org/obo/chebi/{version}/chebi.obo" - r = requests.get(url, allow_redirects=True) - open(chebi_path, "wb").write(r.content) - return chebi_path + This method first loads encoded data from a file named `data.pt`, which is derived from either + `chebi_version` or `chebi_version_train`. It then splits the data into training, validation, and test sets. - def prepare_data(self, *args: Any, **kwargs: Any) -> None: - """ - Prepares the data for the Chebi dataset. + If `chebi_version_train` is provided: + - Loads additional encoded data from `chebi_version_train`. + - Splits this data into training and validation sets, while using the test set from `chebi_version`. + - Prunes the test set from `chebi_version` to include only labels that exist in `chebi_version_train`. - This method checks for the presence of raw data in the specified directory. - If the raw data is missing, it fetches the ontology and creates a test set. - If the test set already exists, it loads it from the file. - Then, it creates the train/validation split based on the test set. + If `chebi_version_train` is not provided: + - Splits the data from `chebi_version` into training, validation, and test sets without modification. - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. + Raises: + FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` + and/or `setup` methods have been called to generate the dataset files. 
Returns: - None + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set """ - print("Check for processed data in", self.processed_dir_main) - if any( - not os.path.isfile(os.path.join(self.processed_dir_main, f)) - for f in self.raw_file_names - ): - os.makedirs(self.processed_dir_main, exist_ok=True) - print("Missing raw data. Go fetch...") - - # -------- Commented the code for Data Handling Restructure for Issue No.10 - # -------- https://github.com/ChEB-AI/python-chebai/issues/10 - # missing test set -> create - # if not os.path.isfile( - # os.path.join(self.raw_dir, self.raw_file_names_dict["test"]) - # ): - # chebi_path = self._load_chebi(self.chebi_version) - # g = self.extract_class_hierarchy(chebi_path) - # df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["test"]) - # _, test_df = self.get_test_split(df) - # self.save_raw(test_df, self.raw_file_names_dict["test"]) - # # load test_split from file - # else: - # with open( - # os.path.join(self.raw_dir, self.raw_file_names_dict["test"]), "rb" - # ) as input_file: - # test_df = pickle.load(input_file) - # # create train/val split based on test set - # chebi_path = self._load_chebi( - # self.chebi_version_train - # if self.chebi_version_train is not None - # else self.chebi_version - # ) - # g = self.extract_class_hierarchy(chebi_path) - # if self.use_inner_cross_validation: - # df = self.graph_to_raw_dataset( - # g, self.raw_file_names_dict[f"fold_0_train"] - # ) - # else: - # df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["train"]) - # train_val_dict = self.get_train_val_splits_given_test(df, test_df) - # for name, df in train_val_dict.items(): - # self.save_raw(df, name) - - # Data from chebi_version - chebi_path = self._load_chebi(self.chebi_version) - g = self.extract_class_hierarchy(chebi_path) - df = self.graph_to_raw_dataset(g, self.raw_file_names_dict["data"]) - self.save_processed(df, filename=self.raw_file_names_dict["data"]) - - if self.chebi_version_train is not None: - if not os.path.isfile( - os.path.join( - self._chebi_version_train_obj.processed_dir_main, - self._chebi_version_train_obj.raw_file_names_dict["data"], - ) - ): - print( - f"Missing processed data related to train version: {self.chebi_version_train}" - ) - print("Call the prepare_data method related to it") - # Generate the "chebi_version_train" data if it doesn't exist - self._chebi_version_train_obj.prepare_data(*args, **kwargs) - - def _generate_dynamic_splits(self) -> None: - """ - Generate data splits during runtime and save them in class variables. - - This method loads encoded data derived from either `chebi_version` or `chebi_version_train` - and generates train, validation, and test splits based on the loaded data. - If `chebi_version_train` is specified, the test set is pruned to include only labels that - exist in `chebi_version_train`. - - Raises: - FileNotFoundError: If the required data file (`data.pt`) for either `chebi_version` or `chebi_version_train` - does not exist. It advises calling `prepare_data` or `setup` methods to generate - the dataset files. 
- """ - print("Generate dynamic splits...") - # Load encoded data derived from "chebi_version" try: filename = self.processed_file_names_dict["data"] data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) @@ -841,130 +446,89 @@ def _generate_dynamic_splits(self) -> None: ) df_test = df_test_chebi_ver - # Generate splits.csv file to store ids of each corresponding split - split_assignment_list: List[pd.DataFrame] = [ - pd.DataFrame({"id": df_train["ident"], "split": "train"}), - pd.DataFrame({"id": df_val["ident"], "split": "validation"}), - pd.DataFrame({"id": df_test["ident"], "split": "test"}), - ] - combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) - combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, "splits.csv") - ) - - # Store the splits in class variables - self.dynamic_df_train = df_train - self.dynamic_df_val = df_val - self.dynamic_df_test = df_test + return df_train, df_val, df_test - def _retrieve_splits_from_csv(self) -> None: + def _setup_pruned_test_set( + self, df_test_chebi_version: pd.DataFrame + ) -> pd.DataFrame: """ - Retrieve previously saved data splits from splits.csv file or from provided file path. + Create a test set with the same leaf nodes, but use only classes that appear in the training set. + + Args: + df_test_chebi_version (pd.DataFrame): The test dataset. - This method loads the splits.csv file located at `self.splits_file_path`. - It then loads the encoded data (`data.pt`) derived from `chebi_version` and filters - it based on the IDs retrieved from splits.csv to reconstruct the train, validation, - and test splits. + Returns: + pd.DataFrame: The pruned test dataset. """ - print(f"Loading splits from {self.splits_file_path}...") - splits_df = pd.read_csv(self.splits_file_path) + # TODO: find a more efficient way to do this + filename_old = "classes.txt" + # filename_new = f"classes_v{self.chebi_version_train}.txt" + # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) - filename = self.processed_file_names_dict["data"] - data_chebi_version = torch.load(os.path.join(self.processed_dir, filename)) - df_chebi_version = pd.DataFrame(data_chebi_version) + # Load original classes (from the current ChEBI version - chebi_version) + with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: + orig_classes = file.readlines() - train_ids = splits_df[splits_df["split"] == "train"]["id"] - validation_ids = splits_df[splits_df["split"] == "validation"]["id"] - test_ids = splits_df[splits_df["split"] == "test"]["id"] + # Load new classes (from the training ChEBI version - chebi_version_train) + with open( + os.path.join( + self._chebi_version_train_obj.processed_dir_main, filename_old + ), + "r", + ) as file: + new_classes = file.readlines() - self.dynamic_df_train = df_chebi_version[ - df_chebi_version["ident"].isin(train_ids) - ] - self.dynamic_df_val = df_chebi_version[ - df_chebi_version["ident"].isin(validation_ids) - ] - self.dynamic_df_test = df_chebi_version[ - df_chebi_version["ident"].isin(test_ids) + # Create a mapping which give index of a class from chebi_version, if the corresponding + # class exists in chebi_version_train, Size = Number of classes in chebi_version + mapping = [ + None if or_class not in new_classes else new_classes.index(or_class) + for or_class in orig_classes ] + # Iterate over each data instance in the test set which is derived from chebi_version + for _, row in df_test_chebi_version.iterrows(): + # Size = Number of classes in 
chebi_version_train + new_labels = [False for _ in new_classes] + for ind, label in enumerate(row["labels"]): + # If the chebi_version class exists in the chebi_version_train and has a True label, + # set the corresponding label in new_labels to True + if mapping[ind] is not None and label: + new_labels[mapping[ind]] = label + # Update the labels from test instance from chebi_version to the new labels, which are compatible to both versions + row["labels"] = new_labels + + return df_test_chebi_version + + # ------------------------------ Phase: Raw Properties ----------------------------------- @property - def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: + def base_dir(self) -> str: """ - Property to retrieve dynamic train, validation, and test splits. - - This property checks if dynamic data splits (`dynamic_df_train`, `dynamic_df_val`, `dynamic_df_test`) - are already loaded. If any of them is None, it either generates them dynamically or retrieves them - from data file with help of pre-existing Split csv file (`splits_file_path`) containing splits assignments. + Return the base directory path for data. Returns: - dict: A dictionary containing the dynamic train, validation, and test DataFrames. - Keys are 'train', 'validation', and 'test'. - """ - if any( - split is None - for split in [ - self.dynamic_df_test, - self.dynamic_df_val, - self.dynamic_df_train, - ] - ): - if self.splits_file_path is None: - # Generate splits based on given seed, create csv file to records the splits - self._generate_dynamic_splits() - else: - # If user has provided splits file path, use it to get the splits from the data - self._retrieve_splits_from_csv() - return { - "train": self.dynamic_df_train, - "validation": self.dynamic_df_val, - "test": self.dynamic_df_test, - } - - def load_processed_data( - self, kind: Optional[str] = None, filename: Optional[str] = None - ) -> List[Dict[str, Any]]: + str: The base directory path for data. """ - Load processed data from a file. + return os.path.join("data", f"chebi_v{self.chebi_version}") - Args: - kind (str, optional): The kind of dataset to load such as "train", "val", or "test". Defaults to None. - filename (str, optional): The name of the file to load the dataset from. Defaults to None. + @property + def processed_dir(self) -> str: + """ + Return the directory path for processed data. Returns: - List[Dict[str, Any]] : The loaded processed data. - - Raises: - KeyError: If specified kind key doesn't exist. - FileNotFoundError: If the specified file does not exist. + str: The path to the processed data directory. 
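# --- Illustrative sketch (not part of the patch) -------------------------------
# The pruned test set keeps the same samples but re-expresses every label vector
# in the class vocabulary of chebi_version_train.  A stand-alone version of that
# remapping (function and variable names are hypothetical):
from typing import List, Optional


def remap_labels(
    labels: List[bool], orig_classes: List[str], new_classes: List[str]
) -> List[bool]:
    """Project a label vector from `orig_classes` onto `new_classes`."""
    # position of every original class in the new vocabulary, or None if absent
    position = {cls: i for i, cls in enumerate(new_classes)}
    mapping: List[Optional[int]] = [position.get(cls) for cls in orig_classes]

    new_labels = [False] * len(new_classes)
    for old_idx, is_true in enumerate(labels):
        if is_true and mapping[old_idx] is not None:
            new_labels[mapping[old_idx]] = True
    return new_labels


# remap_labels([True, True, False], ["A", "B", "C"], ["B", "C", "D"])
# -> [True, False, False]   ("A" has no counterpart and is dropped, "B" keeps True)
# --------------------------------------------------------------------------------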
""" - if kind is None and filename is None: - raise ValueError( - "Either kind or filename is required to load the correct dataset, both are None" - ) - - # If both kind and filename are given, use filename - if kind is not None and filename is None: - try: - if self.use_inner_cross_validation and kind != "test": - filename = self.processed_file_names_dict[ - f"fold_{self.fold_index}_{kind}" - ] - else: - data_df = self.dynamic_split_dfs[kind] - return data_df.to_dict(orient="records") - except KeyError: - kind = f"{kind}" - - # If filename is provided - try: - return torch.load(os.path.join(self.processed_dir, filename)) - except FileNotFoundError: - raise FileNotFoundError(f"File {filename} doesn't exist") + res = os.path.join( + self.processed_dir_main, + *self.identifier, + ) + if self.single_class is None: + return res + else: + return os.path.join(res, f"single_{self.single_class}") class JCIExtendedBase(_ChEBIDataExtractor): - LABEL_INDEX = 3 - SMILES_INDEX = 2 @property def label_number(self): @@ -981,6 +545,8 @@ def select_classes(self, g, *args, **kwargs): class ChEBIOverX(_ChEBIDataExtractor): """ A class for extracting data from the ChEBI dataset with a threshold for selecting classes. + This class is designed to filter Chebi classes based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. Attributes: LABEL_INDEX (int): The index of the label in the dataset. @@ -989,8 +555,6 @@ class ChEBIOverX(_ChEBIDataExtractor): THRESHOLD (None): The threshold for selecting classes. """ - LABEL_INDEX: int = 3 - SMILES_INDEX: int = 2 READER: dr.ChemDataReader = dr.ChemDataReader THRESHOLD: int = None @@ -1014,18 +578,30 @@ def _name(self) -> str: """ return f"ChEBI{self.THRESHOLD}" - def select_classes(self, g: nx.Graph, split_name: str, *args, **kwargs) -> List: + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: """ - Selects classes from the ChEBI dataset. + Selects classes from the ChEBI dataset based on the number of successors meeting a specified threshold. + + This method iterates over the nodes in the graph, counting the number of successors for each node. + Nodes with a number of successors greater than or equal to the defined threshold are selected. + + Note: + The input graph must be transitive closure of a directed acyclic graph. Args: g (nx.Graph): The graph representing the dataset. - split_name (str): The name of the split. - *args: Additional arguments (not used). + *args: Additional positional arguments (not used). **kwargs: Additional keyword arguments (not used). Returns: - list: The list of selected classes. + List: A sorted list of node IDs that meet the successor threshold criteria. + + Side Effects: + Writes the list of selected nodes to a file named "classes.txt" in the specified processed directory. + + Notes: + - The `THRESHOLD` attribute should be defined in the subclass of this class. + - Nodes without a 'smiles' attribute are ignored in the successor count. 
""" smiles = nx.get_node_attributes(g, "smiles") nodes = list( @@ -1041,12 +617,6 @@ def select_classes(self, g: nx.Graph, split_name: str, *args, **kwargs) -> List: ) ) filename = "classes.txt" - # if ( - # self.chebi_version_train - # is not None - # # and self.raw_file_names_dict["test"] != split_name - # ): - # filename = f"classes_v{self.chebi_version_train}.txt" with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: fout.writelines(str(node) + "\n" for node in nodes) return nodes @@ -1184,7 +754,7 @@ def processed_dir_main(self) -> str: "processed", ) - def extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: + def _extract_class_hierarchy(self, chebi_path: str) -> nx.DiGraph: """ Extracts a subset of ChEBI based on subclasses of the top class ID. diff --git a/chebai/preprocessing/datasets/go_uniprot.py b/chebai/preprocessing/datasets/go_uniprot.py new file mode 100644 index 00000000..574ecdbd --- /dev/null +++ b/chebai/preprocessing/datasets/go_uniprot.py @@ -0,0 +1,725 @@ +# Reference for this file : +# Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; +# DeepGO: Predicting protein functions from sequence and interactions +# using a deep ontology-aware classifier, Bioinformatics, 2017. +# https://doi.org/10.1093/bioinformatics/btx624 +# Github: https://github.com/bio-ontology-research-group/deepgo +# https://www.ebi.ac.uk/GOA/downloads +# https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt +# https://www.uniprot.org/uniprotkb + +__all__ = ["GOUniProtOver250", "GOUniProtOver50"] + +import gzip +import os +import shutil +from abc import ABC, abstractmethod +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import fastobo +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SwissProt +from torch.utils.data import DataLoader + +from chebai.preprocessing import reader as dr +from chebai.preprocessing.datasets.base import _DynamicDataset + + +class _GOUniProtDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments passed to XYBaseDataModule. + + Attributes: + dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + splits_file_path (Optional[str]): Path to the CSV file containing split assignments. + """ + + _GO_DATA_INIT = "GO" + _SWISS_DATA_INIT = "SWISS" + + # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` + # "swiss_id" at row index 0 + # "accession" at row index 1 + # "go_ids" at row index 2 + # "sequence" at row index 3 + # labels starting from row index 4 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column + _LABELS_START_IDX: int = 4 + + _GO_DATA_URL: str = "https://purl.obolibrary.org/obo/go/go-basic.obo" + _SWISS_DATA_URL: str = ( + "https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" + ) + + # Gene Ontology (GO) has three major branches, one for biological processes (BP), molecular functions (MF) and + # cellular components (CC). 
The value "all" will take data related to all three branches into account. + _ALL_GO_BRANCHES: str = "all" + _GO_BRANCH_NAMESPACE: Dict[str, str] = { + "BP": "biological_process", + "MF": "molecular_function", + "CC": "cellular_component", + } + + def __init__(self, **kwargs): + self.go_branch: str = self._get_go_branch(**kwargs) + super(_GOUniProtDataExtractor, self).__init__(**kwargs) + + self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002)) + assert ( + self.max_sequence_length >= 1 + ), "Max sequence length should be greater than or equal to 1." + + @classmethod + def _get_go_branch(cls, **kwargs) -> str: + """ + Retrieves the Gene Ontology (GO) branch based on provided keyword arguments. + This method checks if a valid GO branch value is provided in the keyword arguments. + + Args: + **kwargs: Arbitrary keyword arguments. Specifically looks for: + - "go_branch" (str): The desired GO branch. + Returns: + str: The GO branch value. This will be one of the allowed values. + + Raises: + ValueError: If the provided 'go_branch' value is not in the allowed list of values. + """ + + go_branch_value = kwargs.get("go_branch", cls._ALL_GO_BRANCHES) + allowed_values = list(cls._GO_BRANCH_NAMESPACE.keys()) + [cls._ALL_GO_BRANCHES] + if go_branch_value not in allowed_values: + raise ValueError( + f"Invalid value for go_branch: {go_branch_value}, Allowed values: {allowed_values}" + ) + return go_branch_value + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. + + Returns: + str: Path to the downloaded data. + """ + self._download_swiss_uni_prot_data() + return self._download_gene_ontology_data() + + def _download_gene_ontology_data(self) -> str: + """ + Download the Gene Ontology data `.obo` file. + + Note: + Quote from : https://geneontology.org/docs/download-ontology/ + Three versions of the ontology are available, the one use in this method is described below: + https://purl.obolibrary.org/obo/go/go-basic.obo + The basic version of the GO, filtered such that the graph is guaranteed to be acyclic and annotations + can be propagated up the graph. The relations included are `is a, part of, regulates, negatively` + `regulates` and `positively regulates`. This version excludes relationships that cross the 3 GO + hierarchies. This version should be used with most GO-based annotation tools. + + Returns: + str: The file path of the loaded Gene Ontology data. + """ + go_path = os.path.join(self.raw_dir, self.raw_file_names_dict["GO"]) + os.makedirs(os.path.dirname(go_path), exist_ok=True) + + if not os.path.isfile(go_path): + print("Missing Gene Ontology raw data") + print(f"Downloading Gene Ontology data....") + r = requests.get(self._GO_DATA_URL, allow_redirects=True) + r.raise_for_status() # Check if the request was successful + open(go_path, "wb").write(r.content) + return go_path + + def _download_swiss_uni_prot_data(self) -> Optional[str]: + """ + Download the Swiss-Prot data file from UniProt Knowledgebase. + + Note: + UniProt Knowledgebase is collection of functional information on proteins, with accurate, consistent + and rich annotation. + + Swiss-Prot contains manually-annotated records with information extracted from literature and + curator-evaluated computational analysis. + + Returns: + str: The file path of the loaded Swiss-Prot data file. 
+ """ + uni_prot_file_path = os.path.join( + self.raw_dir, self.raw_file_names_dict["SwissUniProt"] + ) + os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True) + + if not os.path.isfile(uni_prot_file_path): + print(f"Downloading Swiss UniProt data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._SWISS_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = uni_prot_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + return uni_prot_file_path + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from the GO ontology. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with GO term data. + + Args: + data_path (str): The path to the GO ontology. + + Returns: + nx.DiGraph: A directed graph representing the class hierarchy, where nodes are GO terms and edges + represent parent-child relationships. + """ + print("Extracting class hierarchy...") + elements = [] + for term in fastobo.load(data_path): + if isinstance(term, fastobo.typedef.TypedefFrame): + # ---- To avoid term frame of the below format/structure ---- + # [Typedef] + # id: part_of + # name: part of + # namespace: external + # xref: BFO:0000050 + # is_transitive: true + continue + + if ( + term + and isinstance(term.id, fastobo.id.PrefixedIdent) + and term.id.prefix == self._GO_DATA_INIT + ): + # Consider only terms with id in following format - GO:2001271 + term_dict = self.term_callback(term) + if term_dict: + elements.append(term_dict) + + g = nx.DiGraph() + + # Add GO term nodes to the graph and their hierarchical ontology + for n in elements: + g.add_node(n["go_id"], **n) + g.add_edges_from( + [ + (parent_id, node_id) + for node_id in g.nodes + for parent_id in g.nodes[node_id]["parents"] + if parent_id in g.nodes + ] + ) + + print("Compute transitive closure") + return nx.transitive_closure_dag(g) + + def term_callback(self, term: fastobo.term.TermFrame) -> Union[Dict, bool]: + """ + Extracts information from a Gene Ontology (GO) term document. + + Args: + term: A Gene Ontology term Frame document. + + Returns: + Optional[Dict]: A dictionary containing the extracted information if the term is not obsolete, + otherwise None. The dictionary includes: + - "id" (str): The ID of the GO term. + - "parents" (List[str]): A list of parent term IDs. + - "name" (str): The name of the GO term. + """ + parents = [] + name = None + + for clause in term: + if isinstance(clause, fastobo.term.NamespaceClause): + if ( + self.go_branch != self._ALL_GO_BRANCHES + and clause.namespace.escaped + != self._GO_BRANCH_NAMESPACE[self.go_branch] + ): + # if the term document is not related to given go branch (except `all`), skip this document. 
+ return False + + if isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + # if the term document contains clause as obsolete as true, skips this document. + return False + + if isinstance(clause, fastobo.term.IsAClause): + parents.append(self._parse_go_id(clause.term)) + elif isinstance(clause, fastobo.term.NameClause): + name = clause.name + + return { + "go_id": self._parse_go_id(term.id), + "parents": parents, + "name": name, + } + + @staticmethod + def _parse_go_id(go_id: str) -> int: + """ + Helper function to parse and normalize GO term IDs. + + Args: + go_id: The raw GO term ID string. + + Returns: + str: The parsed and normalized GO term ID. + """ + # `is_a` clause has GO id in the following formats: + # GO:0009968 ! negative regulation of signal transduction + # GO:0046780 + return int(str(go_id).split(":")[1].split("!")[0].strip()) + + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes + Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. + + Note: + - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value + indicates whether a Swiss-Prot protein is associated with that GO term. + - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins + and GO terms. + + Data Format: pd.DataFrame + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + print(f"Processing graph") + + data_df = self._get_swiss_to_go_mapping() + + # Initialize the GO term labels/columns to False + selected_classes = self.select_classes(g, data_df=data_df) + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=selected_classes + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + # Set True for the corresponding GO IDs in the DataFrame go labels/columns + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. + # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least + # one GO term from the set of the GO terms for the model` + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + def _get_swiss_to_go_mapping(self) -> pd.DataFrame: + """ + Parses the Swiss-Prot data and returns a DataFrame mapping Swiss-Prot records to Gene Ontology (GO) data. + + The DataFrame includes the following columns: + - "swiss_id": The unique identifier for each Swiss-Prot record. + - "sequence": The protein sequence. + - "accessions": Comma-separated list of accession numbers. + - "go_ids": List of GO IDs associated with the Swiss-Prot record. + + Note: + This mapping is necessary because the GO data does not include the protein sequence representation. 
+ + Quote from the DeepGo Paper: + `We select proteins with annotations having experimental evidence codes + (EXP, IDA, IPI, IMP, IGI, IEP, TAS, IC) and filter the proteins by a + maximum length of 1002, ignoring proteins with ambiguous amino acid codes + (B, O, J, U, X, Z) in their sequence.` + + Check the link below for keyword details: + https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt + + Returns: + pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with its associated GO data. + """ + + print("Parsing swiss uniprot raw data....") + + swiss_ids, sequences, accessions, go_ids_list = [], [], [], [] + + swiss_data = SwissProt.parse( + open( + os.path.join(self.raw_dir, self.raw_file_names_dict["SwissUniProt"]), + "r", + ) + ) + + EXPERIMENTAL_EVIDENCE_CODES = { + "EXP", + "IDA", + "IPI", + "IMP", + "IGI", + "IEP", + "TAS", + "IC", + } + # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 + AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "X", "Z", "*"} + + for record in swiss_data: + if record.data_class != "Reviewed": + # To consider only manually-annotated swiss data + continue + + if not record.sequence: + # Consider protein with only sequence representation + continue + + if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence): + # Skip proteins with ambiguous amino acid codes + continue + + go_ids = [] + + for cross_ref in record.cross_references: + if cross_ref[0] == self._GO_DATA_INIT: + # One swiss data protein can correspond to many GO data instances + + if len(cross_ref) <= 3: + # No evidence code + continue + + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L63-L66 + evidence_code = cross_ref[3].split(":")[0] + if evidence_code not in EXPERIMENTAL_EVIDENCE_CODES: + # Skip GO id without the required experimental evidence codes + continue + + go_ids.append(self._parse_go_id(cross_ref[1])) + + if not go_ids: + # Skip Swiss proteins without mapping to GO data + continue + + swiss_ids.append(record.entry_name) + sequences.append(record.sequence) + accessions.append(",".join(record.accessions)) + go_ids.sort() + go_ids_list.append(go_ids) + + data_dict = OrderedDict( + swiss_id=swiss_ids, # swiss_id column at index 0 + accession=accessions, # Accession column at index 1 + go_ids=go_ids_list, # Go_ids (data representation) column at index 2 + sequence=sequences, # Sequence column at index 3 + ) + + return pd.DataFrame(data_dict) + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. 
+ - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Loads encoded data and generates training, validation, and test splits. + + This method attempts to load encoded data from a file named `data.pt`. It then splits this data into + training, validation, and test sets. + + Raises: + FileNotFoundError: If the `data.pt` file does not exist. Ensure that `prepare_data` and/or + `setup` methods are called to generate the necessary dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set + """ + try: + filename = self.processed_file_names_dict["data"] + data_go = torch.load(os.path.join(self.processed_dir, filename)) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. " + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = self.get_test_split( + df_go_data, seed=self.dynamic_data_split_seed + ) + + # Get all splits + df_train, df_val = self.get_train_val_splits_given_test( + train_df_go, + df_test, + seed=self.dynamic_data_split_seed, + ) + + return df_train, df_val, df_test + + # ------------------------------ Phase: DataLoaders ----------------------------------- + def dataloader(self, kind: str, **kwargs) -> DataLoader: + """ + Returns a DataLoader object with truncated sequences for the specified kind of data (train, val, or test). + + This method overrides the dataloader method from the superclass. After fetching the dataset from the + superclass, it truncates the 'features' of each data instance to a maximum length specified by + `self.max_sequence_length`. + + Args: + kind (str): The kind of data to load (e.g., 'train', 'val', 'test'). + **kwargs: Additional keyword arguments passed to the superclass dataloader method. + + Returns: + DataLoader: A DataLoader object with the truncated sequences. + """ + dataloader = super().dataloader(kind, **kwargs) + + # Truncate the 'features' to max_sequence_length for each instance + for instance in dataloader.dataset: + instance["features"] = instance["features"][: self.max_sequence_length] + return dataloader + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def base_dir(self) -> str: + """ + Returns the base directory path for storing GO-Uniprot data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". 
+ """ + return os.path.join("data", f"GO_UniProt") + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + # overriding identifier instead of reader.name to keep same tokens.txt file, but different processed_dir folder + if not isinstance(self.reader, dr.ProteinDataReader): + raise ValueError("Need Protein DataReader for identifier") + if self.reader.n_gram is not None: + return (f"{self.reader.name()}_{self.reader.n_gram}_gram",) + return (self.reader.name(),) + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. + """ + return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} + + +class _GOUniProtOverX(_GOUniProtDataExtractor, ABC): + """ + A class for extracting data from the Gene Ontology (GO) dataset with a threshold for selecting classes based on + the number of subclasses. + + This class is designed to filter GO classes based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. + + Attributes: + READER (dr.ProteinDataReader): The reader used for reading the dataset. + THRESHOLD (int): The threshold for selecting classes based on the number of subclasses. + + Property: + label_number (int): The number of labels in the dataset. This property must be implemented by subclasses. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: int = None + + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + Returns: + str: The dataset name, formatted with the current threshold value and/or given go_branch. + """ + if self.go_branch != self._ALL_GO_BRANCHES: + return f"GO{self.THRESHOLD}_{self.go_branch}" + + return f"GO{self.THRESHOLD}" + + def select_classes( + self, g: nx.DiGraph, *args: Any, **kwargs: Dict[str, Any] + ) -> List[int]: + """ + Selects classes (GO terms) from the Gene Ontology (GO) dataset based on the number of annotations meeting a + specified threshold. + + The selection process is based on the annotations of the GO terms with its ancestors across the dataset. + + Annotations are calculated by counting how many times each GO term, along with its ancestral hierarchy, + is annotated per protein across the dataset. + This means that for each protein, the GO terms associated with it are considered, and the entire hierarchical + structure (ancestors) of each GO term is taken into account. The total count for each GO term and its ancestors + reflects how frequently these terms are annotated across all proteins in the dataset. + + Args: + g (nx.DiGraph): The directed acyclic graph representing the GO dataset, where each node corresponds to a GO term. + *args: Additional positional arguments (not used). + **kwargs: Additional keyword arguments, including: + - data_df (pd.DataFrame): A DataFrame containing the GO annotations for various proteins. + It should include a 'go_ids' column with the GO terms associated with each protein. + + Returns: + List[int]: A sorted list of selected GO term IDs that meet the annotation threshold criteria. + + Side Effects: + - Writes the list of selected GO term IDs to a file named "classes.txt" in the specified processed directory. + + Raises: + AttributeError: If the 'data_df' argument is not provided in kwargs. 
+ + Notes: + - The `THRESHOLD` attribute, which defines the minimum number of annotations required to select a GO term, should be defined in the subclass. + """ + # Retrieve the DataFrame containing GO annotations per protein from the keyword arguments + data_df = kwargs.get("data_df", None) + if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty: + raise AttributeError( + "The 'data_df' argument must be provided and must be a non-empty pandas DataFrame." + ) + + print(f"Selecting GO terms based on given threshold: {self.THRESHOLD} ...") + + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L59-L77 + go_term_annot: Dict[int, int] = {} + for idx, row in data_df.iterrows(): + # Set will contain go terms associated with the protein, along with all the ancestors of those + # associated go terms + associated_go_ids_with_ancestors = set() + + # Collect all ancestors of the GO terms associated with this protein + for go_id in row["go_ids"]: + if go_id in g.nodes: + associated_go_ids_with_ancestors.add(go_id) + associated_go_ids_with_ancestors.update( + g.predecessors(go_id) + ) # Add all predecessors (ancestors) of go_id + + # Count the annotations for each go_id **`per protein`** + for go_id in associated_go_ids_with_ancestors: + if go_id not in go_term_annot: + go_term_annot[go_id] = 0 + go_term_annot[go_id] += 1 + + # Select GO terms that meet or exceed the threshold of annotations + selected_nodes: List[int] = [ + go_id + for go_id in g.nodes + if go_id in go_term_annot and go_term_annot[go_id] >= self.THRESHOLD + ] + + # Sort the selected nodes (optional but often useful for consistent output) + selected_nodes.sort() + + # Write the selected node IDs/classes to the file + filename = "classes.txt" + with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: + fout.writelines(str(node) + "\n" for node in selected_nodes) + + return selected_nodes + + +class GOUniProtOver250(_GOUniProtOverX): + """ + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 250 for selecting classes. + + Inherits from `_GOUniProtOverX` and sets the threshold for selecting classes to 250. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (250). + """ + + THRESHOLD: int = 250 + + +class GOUniProtOver50(_GOUniProtOverX): + """ + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 50 for selecting classes. + + Inherits from `_GOUniProtOverX` and sets the threshold for selecting classes to 50. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (50). + """ + + THRESHOLD: int = 50 diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 902f1e92..46cd558a 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -146,11 +146,21 @@ def _get_token_index(self, token: str) -> int: return self.cache.index(str(token)) + EMBEDDING_OFFSET def _read_data(self, raw_data: str) -> List[int]: - """Read and tokenize raw data.""" + """ + Reads and tokenizes raw SMILES data into a list of token indices. + + Args: + raw_data (str): The raw SMILES string to be tokenized. + + Returns: + List[int]: A list of integers representing the indices of the SMILES tokens. + """ return [self._get_token_index(v[1]) for v in _tokenize(raw_data)] def on_finish(self) -> None: - """Write contents of self.cache into tokens.txt.""" + """ + Saves the current cache of tokens to the token file. 
This method is called after all data processing is complete. + """ with open(self.token_path, "w") as pk: print(f"saving {len(self.cache)} tokens to {self.token_path}...") print(f"first 10 tokens: {self.cache[:10]}") @@ -320,3 +330,140 @@ def name(cls) -> str: def _read_data(self, raw_data: str) -> List[int]: """Convert characters in raw data to their ordinal values.""" return [ord(s) for s in raw_data] + + +class ProteinDataReader(DataReader): + """ + Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format + suitable for model input by tokenizing them and assigning unique indices to each token. + + Note: + Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure + + Args: + collator_kwargs (Optional[Dict[str, Any]]): Optional dictionary of keyword arguments for configuring the collator. + token_path (Optional[str]): Path to the token file. If not provided, it will be created automatically. + kwargs: Additional keyword arguments. + """ + + COLLATOR = RaggedCollator + + # 20 natural amino acid notation + AA_LETTER = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", + ] + + @classmethod + def name(cls) -> str: + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + return "protein_token" + + def __init__(self, *args, n_gram: Optional[int] = None, **kwargs): + """ + Initializes the ProteinDataReader, loading existing tokens from the specified token file. + + Args: + *args: Additional positional arguments passed to the base class. + **kwargs: Additional keyword arguments passed to the base class. + """ + if n_gram is not None: + assert ( + int(n_gram) >= 2 + ), "Ngrams must be greater than or equal to 2 if provided." + self.n_gram = int(n_gram) + else: + self.n_gram = None + + super().__init__(*args, **kwargs) + + # Load the existing tokens from the token file into a cache + with open(self.token_path, "r") as pk: + self.cache = [x.strip() for x in pk] + + def _get_token_index(self, token: str) -> int: + """ + Returns a unique index for each token (amino acid). If the token is not already in the cache, it is added. + + Args: + token (str): The amino acid token to retrieve or add. + + Returns: + int: The index of the token, offset by the predefined EMBEDDING_OFFSET. + """ + error_str = ( + f"Please ensure that the input only contains valid amino acids " + f"20 Valid natural amino acid notation: {self.AA_LETTER}" + f"Refer to the amino acid sequence details here: " + f"https://en.wikipedia.org/wiki/Protein_primary_structure" + ) + + if self.n_gram is None: + # Single-letter amino acid token check + if str(token) not in self.AA_LETTER: + raise KeyError(f"Invalid token '{token}' encountered. " + error_str) + else: + # n-gram token validation, ensure that each component of the n-gram is valid + for aa in token: + if aa not in self.AA_LETTER: + raise KeyError( + f"Invalid token '{token}' encountered as part of n-gram {self.n_gram}. " + + error_str + ) + + if str(token) not in self.cache: + self.cache.append(str(token)) + return self.cache.index(str(token)) + EMBEDDING_OFFSET + + def _read_data(self, raw_data: str) -> List[int]: + """ + Reads and tokenizes raw protein sequence data into a list of token indices. 
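# --- Illustrative sketch (not part of the patch) -------------------------------
# In n-gram mode the reader slides a window of size n over the sequence, so a
# protein of length L yields L - n + 1 tokens; without n_gram it falls back to
# single amino-acid tokens.  Stand-alone illustration (sequence is made up):
def to_ngrams(sequence: str, n: int) -> list:
    return [sequence[i : i + n] for i in range(len(sequence) - n + 1)]


print(to_ngrams("MKTFF", 3))  # ['MKT', 'KTF', 'TFF']
print(list("MKTFF"))          # ['M', 'K', 'T', 'F', 'F']  (n_gram=None behaviour)
# --------------------------------------------------------------------------------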
+ + Args: + raw_data (str): The raw protein sequence to be tokenized (e.g., "MKTFF..."). + + Returns: + List[int]: A list of integers representing the indices of the amino acid tokens. + """ + if self.n_gram is not None: + # Tokenize the sequence into n-grams + tokens = [ + raw_data[i : i + self.n_gram] + for i in range(len(raw_data) - self.n_gram + 1) + ] + return [self._get_token_index(gram) for gram in tokens] + + # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation) + return [self._get_token_index(aa) for aa in raw_data] + + def on_finish(self) -> None: + """ + Saves the current cache of tokens to the token file. This method is called after all data processing is complete. + """ + with open(self.token_path, "w") as pk: + print(f"Saving {len(self.cache)} tokens to {self.token_path}...") + print(f"First 10 tokens: {self.cache[:10]}") + pk.writelines([f"{c}\n" for c in self.cache]) diff --git a/configs/data/go250.yml b/configs/data/go250.yml new file mode 100644 index 00000000..5598495c --- /dev/null +++ b/configs/data/go250.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver250 +init_args: + go_branch: "BP" diff --git a/configs/data/go50.yml b/configs/data/go50.yml new file mode 100644 index 00000000..2ed4d14c --- /dev/null +++ b/configs/data/go50.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.go_uniprot.GOUniProtOver50 diff --git a/setup.py b/setup.py index 0007f6ee..58bfc75b 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ "chardet", "pyyaml", "torchmetrics", + "biopython", ], extras_require={"dev": ["black", "isort", "pre-commit"]}, )
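# --- Illustrative usage sketch (not part of the patch) --------------------------
# configs/data/go250.yml wires GOUniProtOver250 up with go_branch="BP" via the
# Lightning CLI.  A roughly equivalent programmatic use is sketched below; the
# prepare_data/setup/train_dataloader calls are the usual LightningDataModule
# entry points and are assumed here rather than shown in this patch.
from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250

datamodule = GOUniProtOver250(go_branch="BP")
datamodule.prepare_data()   # download GO + Swiss-Prot and build the raw dataset
datamodule.setup()          # encode sequences for dynamic train/val/test splits
train_loader = datamodule.train_dataloader()
# --------------------------------------------------------------------------------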