diff --git a/.gitignore b/.gitignore index af998906..bafec1d9 100644 --- a/.gitignore +++ b/.gitignore @@ -175,3 +175,4 @@ chebai.egg-info lightning_logs logs .isort.cfg +/.vscode diff --git a/chebai/models/base.py b/chebai/models/base.py index fd02c6ce..cb254570 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -1,9 +1,9 @@ import logging -from typing import Any, Dict, Optional, Union, Iterable +from abc import ABC, abstractmethod +from typing import Any, Dict, Iterable, Optional, Union import torch from lightning.pytorch.core.module import LightningModule -from torchmetrics import Metric from chebai.preprocessing.structures import XYData @@ -12,7 +12,7 @@ _MODEL_REGISTRY = dict() -class ChebaiBaseNet(LightningModule): +class ChebaiBaseNet(LightningModule, ABC): """ Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule. @@ -353,6 +353,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int): logger=True, ) + @abstractmethod def forward(self, x: Dict[str, Any]) -> torch.Tensor: """ Defines the forward pass. @@ -363,7 +364,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor: Returns: torch.Tensor: The model output. 
""" - raise NotImplementedError + pass def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer: """ diff --git a/chebai/result/_generate_classes_props_json.py b/chebai/result/_generate_classes_props_json.py new file mode 100644 index 00000000..b8704591 --- /dev/null +++ b/chebai/result/_generate_classes_props_json.py @@ -0,0 +1,191 @@ +import json +from pathlib import Path + +import torch +from jsonargparse import CLI +from sklearn.metrics import multilabel_confusion_matrix + +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.result.utils import ( + load_data_instance, + load_model_for_inference, + parse_config_file, +) + + +class ClassesPropertiesGenerator: + """ + Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value) + for each class in a multi-label classification problem using a PyTorch Lightning model. + """ + + @staticmethod + def load_class_labels(path: Path) -> list[str]: + """ + Load class names from a text file with one name per line (the file is read line by line; JSON is not parsed). + + Args: + path: Path to the class labels file (plain text, one class name per line). + + Returns: + A list of the stripped, non-empty lines of the file. + """ + path = Path(path) + with path.open() as f: + return [line.strip() for line in f if line.strip()] + + @staticmethod + def compute_tpv_npv( + y_true: list[torch.Tensor], + y_pred: list[torch.Tensor], + class_names: list[str], + ) -> dict[str, dict[str, float]]: + """ + Compute PPV (precision; named TPV in this code) and NPV for each class in a multi-label setting. + + Args: + y_true: List of binary ground-truth label tensors, one tensor per sample. + y_pred: List of binary prediction tensors, one tensor per sample. + class_names: Ordered list of class names corresponding to class indices. + + Returns: + Dictionary mapping each class name to its PPV and NPV (rounded to 4 decimals) and its confusion-matrix counts: + { + "class_name": {"PPV": float, "NPV": float, "TN": int, "FP": int, "FN": int, "TP": int}, + ... 
+ } + """ + # Stack per-sample tensors into (n_samples, n_classes) numpy arrays + true_np = torch.stack(y_true).cpu().numpy().astype(int) + pred_np = torch.stack(y_pred).cpu().numpy().astype(int) + + # Compute confusion matrix for each class + cm = multilabel_confusion_matrix(true_np, pred_np) + + results: dict[str, dict[str, float]] = {} + for idx, cls_name in enumerate(class_names): + tn, fp, fn, tp = cm[idx].ravel() + tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0 + npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0 + results[cls_name] = { + "PPV": round(tpv, 4), + "NPV": round(npv, 4), + "TN": int(tn), + "FP": int(fp), + "FN": int(fn), + "TP": int(tp), + } + return results + + def generate_props( + self, + model_ckpt_path: str, + model_config_file_path: str, + data_config_file_path: str, + output_path: str | None = None, + ) -> None: + """ + Run inference on validation set, compute TPV/NPV per class, and save to JSON. + + Args: + model_ckpt_path: Path to the PyTorch Lightning checkpoint file. + model_config_file_path: Path to yaml config file of the model. + data_config_file_path: Path to yaml config file of the data. + output_path: Optional path where to write the JSON metrics file. + Defaults to 'classes.json' inside the data module's `processed_dir_main`. + """ + print("Extracting validation data for computation...") + + data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path) + data_module: XYBaseDataModule = load_data_instance( + data_cls_path, data_cls_kwargs + ) + + splits_file_path = Path(data_module.processed_dir_main, "splits.csv") + if data_module.splits_file_path is None: + if not splits_file_path.exists(): + raise RuntimeError( + "Either the data module should be initialized with a `splits_file_path`, " + f"or the file `{splits_file_path}` must exists.\n" + "This is to prevent the data module from dynamically generating the splits." 
+ ) + + print( + f"`splits_file_path` is not provided as an initialization parameter to the data module\n" + f"Using splits from the file {splits_file_path}" + ) + data_module.splits_file_path = splits_file_path + + model_class_path, model_kwargs = parse_config_file(model_config_file_path) + model = load_model_for_inference( + model_ckpt_path, model_class_path, model_kwargs + ) + + val_loader = data_module.val_dataloader() + print("Running inference on validation data...") + + y_true, y_pred = [], [] + for batch_idx, batch in enumerate(val_loader): + data = model._process_batch( # pylint: disable=W0212 + batch, batch_idx=batch_idx + ) + labels = data["labels"] + outputs = model(data, **data.get("model_kwargs", {})) + logits = outputs["logits"] if isinstance(outputs, dict) else outputs + preds = torch.sigmoid(logits) > 0.5 + y_pred.extend(preds) + y_true.extend(labels) + + print("Computing TPV and NPV metrics...") + classes_file = Path(data_module.processed_dir_main) / "classes.txt" + if output_path is None: + output_file = Path(data_module.processed_dir_main) / "classes.json" + else: + output_file = Path(output_path) + + class_names = self.load_class_labels(classes_file) + metrics = self.compute_tpv_npv(y_true, y_pred, class_names) + + with output_file.open("w") as f: + json.dump(metrics, f, indent=2) + print(f"Saved TPV/NPV metrics to {output_file}") + + +class Main: + """ + CLI wrapper for ClassesPropertiesGenerator. + """ + + def generate( + self, + model_ckpt_path: str, + model_config_file_path: str, + data_config_file_path: str, + output_path: str | None = None, + ) -> None: + """ + CLI command to generate TPV/NPV JSON. + + Args: + model_ckpt_path: Path to the PyTorch Lightning checkpoint file. + model_config_file_path: Path to yaml config file of the model. + data_config_file_path: Path to yaml config file of the data. + output_path: Optional path where to write the JSON metrics file. + Defaults to 'classes.json' inside the data module's `processed_dir_main`. 
+ """ + generator = ClassesPropertiesGenerator() + generator.generate_props( + model_ckpt_path, + model_config_file_path, + data_config_file_path, + output_path, + ) + + +if __name__ == "__main__": + # _generate_classes_props_json.py generate \ + # --model_ckpt_path "model/ckpt/path" \ + # --model_config_file_path "model/config/file/path" \ + # --data_config_file_path "data/config/file/path" \ + # --output_path "output/file/path" # Optional + CLI(Main, as_positional=False) diff --git a/chebai/result/utils.py b/chebai/result/utils.py index 991960d6..78d20013 100644 --- a/chebai/result/utils.py +++ b/chebai/result/utils.py @@ -1,14 +1,16 @@ +import importlib import os import shutil -from typing import Optional, Tuple, Union +from pathlib import Path +from typing import Optional, Tuple import torch import tqdm import wandb import wandb.util as wandb_util +import yaml from chebai.models.base import ChebaiBaseNet -from chebai.models.electra import Electra from chebai.preprocessing.datasets.base import XYBaseDataModule from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor @@ -121,7 +123,7 @@ def evaluate_model( save_batch_size = 128 n_saved = 1 - print(f"") + print("") for i in tqdm.tqdm(range(0, len(data_list), batch_size)): if not ( skip_existing_preds @@ -222,6 +224,82 @@ def load_results_from_buffer( return test_preds, test_labels +def load_class(class_path: str) -> type: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + + +def load_data_instance(data_cls_path: str, data_cls_kwargs: dict): + assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict" + data_cls = load_class(data_cls_path) + assert isinstance(data_cls, type), f"{data_cls} is not a class." 
+ assert issubclass( + data_cls, XYBaseDataModule + ), f"{data_cls} must inherit from XYBaseDataModule" + return data_cls(**data_cls_kwargs) + + +def load_model_for_inference( + model_ckpt_path: str, model_cls_path: str, model_load_kwargs: dict, **kwargs +) -> ChebaiBaseNet: + """ + Loads a model from a checkpoint and puts it into frozen evaluation mode. + + Returns: + ChebaiBaseNet: The loaded model, after `eval()` and `freeze()` have been called on it. + """ + assert isinstance(model_load_kwargs, dict), "model_load_kwargs must be a dict" + + model_name = kwargs.get("model_name", model_ckpt_path) + + if not Path(model_ckpt_path).exists(): + raise FileNotFoundError( + f"Model path '{model_ckpt_path}' for '{model_name}' does not exist." + ) + + lightning_cls = load_class(model_cls_path) + + assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class." + assert issubclass( + lightning_cls, ChebaiBaseNet + ), f"{lightning_cls} must inherit from ChebaiBaseNet" + try: + model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs) + except Exception as e: + raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e + + assert isinstance( + model, ChebaiBaseNet + ), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance." + model.eval() + model.freeze() + return model + + +def parse_config_file(config_path: str) -> tuple[str, dict]: + path = Path(config_path) + + # Check file existence + if not path.exists(): + raise FileNotFoundError(f"Config file not found: {config_path}") + + # Check file extension + if path.suffix.lower() not in [".yml", ".yaml"]: + raise ValueError( + f"Unsupported config file type: {path.suffix}. 
Expected .yaml or .yml" + ) + + # Load YAML content + with open(path, "r") as f: + config: dict = yaml.safe_load(f) + + class_path: str = config["class_path"] + init_args: dict = config.get("init_args", {}) + assert isinstance(init_args, dict), "init_args must be a dictionary" + return class_path, init_args + + if __name__ == "__main__": import sys @@ -231,5 +309,5 @@ def load_results_from_buffer( ) os.makedirs(buffer_dir_concat, exist_ok=True) preds, labels = load_results_from_buffer(buffer_dir, "cpu") - torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt")) - torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt")) + torch.save(preds, os.path.join(buffer_dir_concat, "preds000.pt")) + torch.save(labels, os.path.join(buffer_dir_concat, "labels000.pt"))