Merged
Commits
82 commits
e3297c6
add dummy dataset class - quick testing purpose
aditya0by0 Mar 17, 2025
f0e4758
custom typehints
aditya0by0 Mar 17, 2025
4fda565
model base: make forward method as abstract method rebase ensemble_f…
aditya0by0 Mar 17, 2025
7f7c6a0
ensemble: abstract code
aditya0by0 Mar 17, 2025
4d3f4f6
ignore lightning logs
aditya0by0 Mar 17, 2025
55959de
ensemble: fix for grad runtime error
aditya0by0 Mar 17, 2025
9513fea
ensemble: config for ensemble model
aditya0by0 Mar 17, 2025
2875385
ensemble: add MLP layer on top ensemble models
aditya0by0 Mar 21, 2025
7f892d9
base: fix import
aditya0by0 Mar 21, 2025
f60b2d8
ensemble: code improvements
aditya0by0 Mar 21, 2025
72a6b37
ensemble: add class path to config and load model via this class
aditya0by0 Mar 24, 2025
82a96dc
ensemble: changes for out of scope labels for certain models
aditya0by0 Mar 25, 2025
26f5ab4
ensemble: correct confidence val calculation
aditya0by0 Mar 31, 2025
9b851c5
ensemble: update for tpv/fpv value for each label
aditya0by0 Mar 31, 2025
0541ed2
ensemble: add docstrings and typehints
aditya0by0 Apr 1, 2025
2f3ecc6
Merge branch 'dev' into ensemble_br
aditya0by0 May 4, 2025
3ace30a
remove optimizer kwargs as not needed
aditya0by0 May 4, 2025
ddcdeac
add template to ensemble config
aditya0by0 May 4, 2025
eb1798c
Update .gitignore
aditya0by0 May 4, 2025
ed92ac5
each model's each label has TPV, FPV
aditya0by0 May 5, 2025
405026e
Merge branch 'dev' into ensemble_br
aditya0by0 May 7, 2025
0ec03b1
remove ensemble learning class
aditya0by0 May 15, 2025
dabe5ff
update code change
aditya0by0 May 15, 2025
7db384a
add ensemble base to new python dir
aditya0by0 May 16, 2025
65a51e0
add ensemble controller
aditya0by0 May 16, 2025
37d46f7
add utils.print_metrics to ensemble
aditya0by0 May 18, 2025
bc6e131
add consolidator
aditya0by0 May 18, 2025
b9dbd97
add to needed classes to init
aditya0by0 May 18, 2025
4d6856d
add rank_zero_info printing
aditya0by0 May 18, 2025
825916e
add script for running ensemble
aditya0by0 May 18, 2025
69c5263
ensemble minor changes
aditya0by0 May 18, 2025
4bd00ac
private instance var + reader_dir_name param
aditya0by0 May 19, 2025
50057f0
config for ensemble
aditya0by0 May 19, 2025
ee7a166
delete models/ensemble
aditya0by0 May 19, 2025
e9f1d95
delete old ensemble config
aditya0by0 May 19, 2025
fca0305
add docstrings + typehints
aditya0by0 May 19, 2025
de6a707
delete dummy dataset
aditya0by0 May 20, 2025
b471a05
rename script with _ prefix
aditya0by0 May 20, 2025
4c89dd3
wrapper base
aditya0by0 May 21, 2025
1563c76
nn wrapper
aditya0by0 May 21, 2025
2fec9ef
rename ensemble internal files with _ prefix
aditya0by0 May 21, 2025
682801f
chemlog wrapper
aditya0by0 May 22, 2025
2b2d458
gnn wrapper
aditya0by0 May 22, 2025
7cbb732
move related code from ensemble base to nn wrapper
aditya0by0 May 22, 2025
a7df384
move constants to wrappers
aditya0by0 May 22, 2025
ee0aef1
move prop loading to base
aditya0by0 May 22, 2025
8d8a748
move wrappers to ensemble
aditya0by0 May 22, 2025
00bd478
nn validate model config
aditya0by0 May 22, 2025
4f35007
utility for loading class
aditya0by0 May 22, 2025
a1a70eb
Create _constants.py
aditya0by0 May 22, 2025
f812cd7
update controller for wrapper
aditya0by0 May 22, 2025
c48bfd2
update base for wrapper
aditya0by0 May 22, 2025
bf3cf64
Update .gitignore
aditya0by0 Jun 1, 2025
76d8a79
predict method implementation for data file and list of smiles
aditya0by0 Jun 1, 2025
95d49c1
separate method for evaluate and prediction
aditya0by0 Jun 1, 2025
a20ce76
store collated label or any model in instance var
aditya0by0 Jun 1, 2025
c0cb6c9
fix collated labels none error
aditya0by0 Jun 1, 2025
9fc5d20
script to generate classes props
aditya0by0 Jun 2, 2025
93e9b73
save prediction to csv for predict operation mode
aditya0by0 Jun 2, 2025
954431c
use multilabel cm
aditya0by0 Jun 2, 2025
6ce02a7
raise error for duplicate subclass/wrapper
aditya0by0 Jun 9, 2025
549a71f
add model load kwargs and move cls path to nn wrapper
aditya0by0 Jun 9, 2025
366c72b
refine chemlog wrapper
aditya0by0 Jun 9, 2025
2739c64
use data class instead of explicit reader, collator
aditya0by0 Jun 10, 2025
a96ae43
refine gnn wrapper
aditya0by0 Jun 10, 2025
b5ea7d1
Merge branch 'dev' into ensemble_br
aditya0by0 Jun 12, 2025
a6800b3
correct PPV and FPV key and rectify nn wrapper
aditya0by0 Jun 12, 2025
64b3e7e
load cls, load model as utilities
aditya0by0 Jun 12, 2025
7e673f2
evaluate_from_data_file not needed for gnn wrapper
aditya0by0 Jun 12, 2025
0c1be27
use dataclass and utilities
aditya0by0 Jun 12, 2025
e5ec383
pass config file for model, data instead of explicit params
aditya0by0 Jun 12, 2025
9a3328f
use utility for scripts
aditya0by0 Jun 12, 2025
f40eff9
dm should have splits_file_path or splits.csv in its dir
aditya0by0 Jun 12, 2025
e89ec4f
fix gnn logits error
aditya0by0 Jun 15, 2025
8d40637
fix gnn predict_from smiles list logits error
aditya0by0 Jun 15, 2025
3976531
chemlog wrapper return logits
aditya0by0 Jun 15, 2025
e4e9a28
save tp, fp, fn and tn as model properties
sfluegel05 Jun 24, 2025
41dd1c6
move ensemble to chebifier repo, move property calculation and utils …
sfluegel05 Jun 24, 2025
44b60dd
remove data processed dir param linking
aditya0by0 Jun 24, 2025
1986a31
Merge branch 'ensemble_br' of https://github.com/ChEB-AI/python-cheba…
aditya0by0 Jun 24, 2025
f16550c
fix utils imports
aditya0by0 Jun 24, 2025
c33dec3
update gitignore
aditya0by0 Jun 24, 2025
1 change: 1 addition & 0 deletions .gitignore
@@ -175,3 +175,4 @@ chebai.egg-info
lightning_logs
logs
.isort.cfg
/.vscode
9 changes: 5 additions & 4 deletions chebai/models/base.py
@@ -1,9 +1,9 @@
import logging
from typing import Any, Dict, Optional, Union, Iterable
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional, Union

import torch
from lightning.pytorch.core.module import LightningModule
from torchmetrics import Metric

from chebai.preprocessing.structures import XYData

@@ -12,7 +12,7 @@
_MODEL_REGISTRY = dict()


class ChebaiBaseNet(LightningModule):
class ChebaiBaseNet(LightningModule, ABC):
"""
Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.

@@ -353,6 +353,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
logger=True,
)

@abstractmethod
def forward(self, x: Dict[str, Any]) -> torch.Tensor:
"""
Defines the forward pass.
@@ -363,7 +364,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
Returns:
torch.Tensor: The model output.
"""
raise NotImplementedError
pass

def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
"""
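With this change, ChebaiBaseNet inherits from ABC and declares forward as an @abstractmethod, so the base class can no longer be instantiated directly and every concrete model must supply its own forward pass. A minimal sketch of what a subclass now has to provide — the class name, constructor arguments, and the "features" batch key are illustrative assumptions, not part of this PR:

from typing import Any, Dict

import torch

from chebai.models.base import ChebaiBaseNet


class TinyChebaiNet(ChebaiBaseNet):
    # Hypothetical subclass: under the new ABC contract, forward() is the only
    # method that must be overridden before the class can be instantiated.

    def __init__(self, input_dim: int = 128, n_classes: int = 997, **kwargs):
        super().__init__(**kwargs)
        self.linear = torch.nn.Linear(input_dim, n_classes)

    def forward(self, x: Dict[str, Any]) -> torch.Tensor:
        # Assumes the processed batch dict exposes a "features" tensor; the
        # actual key depends on the data module and reader in use.
        return self.linear(x["features"].float())


# ChebaiBaseNet() itself would now raise TypeError, since an abstract class
# with an unimplemented forward() cannot be instantiated.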
191 changes: 191 additions & 0 deletions chebai/result/_generate_classes_props_json.py
@@ -0,0 +1,191 @@
import json
from pathlib import Path

import torch
from jsonargparse import CLI
from sklearn.metrics import multilabel_confusion_matrix

from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.result.utils import (
load_data_instance,
load_model_for_inference,
parse_config_file,
)


class ClassesPropertiesGenerator:
"""
Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value)
for each class in a multi-label classification problem using a PyTorch Lightning model.
"""

@staticmethod
def load_class_labels(path: Path) -> list[str]:
"""
Load a list of class names from a .json or .txt file.

Args:
path: Path to the class labels file (txt or json).

Returns:
A list of class names, one per line.
"""
path = Path(path)
with path.open() as f:
return [line.strip() for line in f if line.strip()]

@staticmethod
def compute_tpv_npv(
y_true: list[torch.Tensor],
y_pred: list[torch.Tensor],
class_names: list[str],
) -> dict[str, dict[str, float]]:
"""
Compute TPV (precision) and NPV for each class in a multi-label setting.

Args:
y_true: List of binary ground-truth label tensors, one tensor per sample.
y_pred: List of binary prediction tensors, one tensor per sample.
class_names: Ordered list of class names corresponding to class indices.

Returns:
Dictionary mapping each class name to its TPV and NPV metrics:
{
"class_name": {"PPV": float, "NPV": float},
...
}
"""
# Stack per-sample tensors into (n_samples, n_classes) numpy arrays
true_np = torch.stack(y_true).cpu().numpy().astype(int)
pred_np = torch.stack(y_pred).cpu().numpy().astype(int)

# Compute confusion matrix for each class
cm = multilabel_confusion_matrix(true_np, pred_np)

results: dict[str, dict[str, float]] = {}
for idx, cls_name in enumerate(class_names):
tn, fp, fn, tp = cm[idx].ravel()
tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
results[cls_name] = {
"PPV": round(tpv, 4),
"NPV": round(npv, 4),
"TN": int(tn),
"FP": int(fp),
"FN": int(fn),
"TP": int(tp),
}
return results

def generate_props(
self,
model_ckpt_path: str,
model_config_file_path: str,
data_config_file_path: str,
output_path: str | None = None,
) -> None:
"""
Run inference on validation set, compute TPV/NPV per class, and save to JSON.

Args:
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
model_config_file_path: Path to yaml config file of the model.
data_config_file_path: Path to yaml config file of the data.
output_path: Optional path where to write the JSON metrics file.
Defaults to '<processed_dir_main>/classes.json'.
"""
print("Extracting validation data for computation...")

data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
data_module: XYBaseDataModule = load_data_instance(
data_cls_path, data_cls_kwargs
)

splits_file_path = Path(data_module.processed_dir_main, "splits.csv")
if data_module.splits_file_path is None:
if not splits_file_path.exists():
raise RuntimeError(
"Either the data module should be initialized with a `splits_file_path`, "
f"or the file `{splits_file_path}` must exists.\n"
"This is to prevent the data module from dynamically generating the splits."
)

print(
f"`splits_file_path` is not provided as an initialization parameter to the data module\n"
f"Using splits from the file {splits_file_path}"
)
data_module.splits_file_path = splits_file_path

model_class_path, model_kwargs = parse_config_file(model_config_file_path)
model = load_model_for_inference(
model_ckpt_path, model_class_path, model_kwargs
)

val_loader = data_module.val_dataloader()
print("Running inference on validation data...")

y_true, y_pred = [], []
for batch_idx, batch in enumerate(val_loader):
data = model._process_batch( # pylint: disable=W0212
batch, batch_idx=batch_idx
)
labels = data["labels"]
outputs = model(data, **data.get("model_kwargs", {}))
logits = outputs["logits"] if isinstance(outputs, dict) else outputs
preds = torch.sigmoid(logits) > 0.5
y_pred.extend(preds)
y_true.extend(labels)

print("Computing TPV and NPV metrics...")
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
if output_path is None:
output_file = Path(data_module.processed_dir_main) / "classes.json"
else:
output_file = Path(output_path)

class_names = self.load_class_labels(classes_file)
metrics = self.compute_tpv_npv(y_true, y_pred, class_names)

with output_file.open("w") as f:
json.dump(metrics, f, indent=2)
print(f"Saved TPV/NPV metrics to {output_file}")


class Main:
"""
CLI wrapper for ClassesPropertiesGenerator.
"""

def generate(
self,
model_ckpt_path: str,
model_config_file_path: str,
data_config_file_path: str,
output_path: str | None = None,
) -> None:
"""
CLI command to generate TPV/NPV JSON.

Args:
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
model_config_file_path: Path to yaml config file of the model.
data_config_file_path: Path to yaml config file of the data.
output_path: Optional path where to write the JSON metrics file.
Defaults to '<processed_dir_main>/classes.json'.
"""
generator = ClassesPropertiesGenerator()
generator.generate_props(
model_ckpt_path,
model_config_file_path,
data_config_file_path,
output_path,
)


if __name__ == "__main__":
# _generate_classes_props_json.py generate \
# --model_ckpt_path "model/ckpt/path" \
# --model_config_file_path "model/config/file/path" \
# --data_config_file_path "data/config/file/path" \
# --output_path "output/file/path" # Optional
CLI(Main, as_positional=False)
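The JSON written by this script maps each class to its PPV (= TP / (TP + FP)) and NPV (= TN / (TN + FN)) together with the raw confusion-matrix counts, substituting 0.0 whenever a denominator is zero. A small self-contained sketch of that arithmetic on toy labels (not tied to any real ChEBI class), mirroring how compute_tpv_npv uses sklearn's multilabel_confusion_matrix:

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

# Four samples, two classes: rows are samples, columns are classes.
y_true = np.array([[1, 0], [1, 1], [0, 1], [0, 0]])
y_pred = np.array([[1, 0], [0, 1], [0, 1], [1, 0]])

# One 2x2 matrix per class, laid out as [[TN, FP], [FN, TP]].
cm = multilabel_confusion_matrix(y_true, y_pred)

for idx, per_class in enumerate(cm):
    tn, fp, fn, tp = per_class.ravel()
    ppv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
    print(f"class {idx}: PPV={ppv:.2f}, NPV={npv:.2f}, TN={tn}, FP={fp}, FN={fn}, TP={tp}")

# For these toy labels: class 0 gets PPV=0.50, NPV=0.50; class 1 gets PPV=1.00, NPV=1.00.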
88 changes: 83 additions & 5 deletions chebai/result/utils.py
@@ -1,14 +1,16 @@
import importlib
import os
import shutil
from typing import Optional, Tuple, Union
from pathlib import Path
from typing import Optional, Tuple

import torch
import tqdm
import wandb
import wandb.util as wandb_util
import yaml

from chebai.models.base import ChebaiBaseNet
from chebai.models.electra import Electra
from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor

@@ -121,7 +123,7 @@ def evaluate_model(
save_batch_size = 128
n_saved = 1

print(f"")
print("")
for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
if not (
skip_existing_preds
@@ -222,6 +224,82 @@ def load_results_from_buffer(
return test_preds, test_labels


def load_class(class_path: str) -> type:
module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path)
return getattr(module, class_name)


def load_data_instance(data_cls_path: str, data_cls_kwargs: dict):
assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict"
data_cls = load_class(data_cls_path)
assert isinstance(data_cls, type), f"{data_cls} is not a class."
assert issubclass(
data_cls, XYBaseDataModule
), f"{data_cls} must inherit from XYBaseDataModule"
return data_cls(**data_cls_kwargs)


def load_model_for_inference(
model_ckpt_path: str, model_cls_path: str, model_load_kwargs: dict, **kwargs
) -> ChebaiBaseNet:
"""
Loads a model checkpoint and its label-related properties.

Returns:
Tuple[LightningModule, Dict[str, torch.Tensor]]: The model and its label properties.
"""
assert isinstance(model_load_kwargs, dict), "model_load_kwargs must be a dict"

model_name = kwargs.get("model_name", model_ckpt_path)

if not Path(model_ckpt_path).exists():
raise FileNotFoundError(
f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
)

lightning_cls = load_class(model_cls_path)

assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class."
assert issubclass(
lightning_cls, ChebaiBaseNet
), f"{lightning_cls} must inherit from ChebaiBaseNet"
try:
model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs)
except Exception as e:
raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e

assert isinstance(
model, ChebaiBaseNet
), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance."
model.eval()
model.freeze()
return model


def parse_config_file(config_path: str) -> tuple[str, dict]:
path = Path(config_path)

# Check file existence
if not path.exists():
raise FileNotFoundError(f"Config file not found: {config_path}")

# Check file extension
if path.suffix.lower() not in [".yml", ".yaml"]:
raise ValueError(
f"Unsupported config file type: {path.suffix}. Expected .yaml or .yml"
)

# Load YAML content
with open(path, "r") as f:
config: dict = yaml.safe_load(f)

class_path: str = config["class_path"]
init_args: dict = config.get("init_args", {})
assert isinstance(init_args, dict), "init_args must be a dictionary"
return class_path, init_args


if __name__ == "__main__":
import sys

@@ -231,5 +309,5 @@ def load_results_from_buffer(
)
os.makedirs(buffer_dir_concat, exist_ok=True)
preds, labels = load_results_from_buffer(buffer_dir, "cpu")
torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt"))
torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt"))
torch.save(preds, os.path.join(buffer_dir_concat, "preds000.pt"))
torch.save(labels, os.path.join(buffer_dir_concat, "labels000.pt"))
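The new helpers expect Lightning-CLI style YAML configs with a class_path entry and an optional init_args mapping: parse_config_file extracts those, and load_data_instance / load_model_for_inference then import and instantiate the referenced classes. A hedged usage sketch — the config file paths, init_args keys, checkpoint path, and model_name label are illustrative assumptions, not values taken from this PR:

# model config (e.g. configs/model/electra.yml), illustrative contents:
#   class_path: chebai.models.electra.Electra
#   init_args: {}
#
# data config (e.g. configs/data/chebi50.yml), illustrative contents:
#   class_path: chebai.preprocessing.datasets.chebi.ChEBIOver50
#   init_args:
#     batch_size: 32

from chebai.result.utils import (
    load_data_instance,
    load_model_for_inference,
    parse_config_file,
)

model_cls_path, model_kwargs = parse_config_file("configs/model/electra.yml")
data_cls_path, data_kwargs = parse_config_file("configs/data/chebi50.yml")

data_module = load_data_instance(data_cls_path, data_kwargs)
model = load_model_for_inference(
    "checkpoints/best.ckpt",       # hypothetical checkpoint path
    model_cls_path,
    model_load_kwargs=model_kwargs,
    model_name="electra-chebi50",  # optional label used only in error messages
)
val_loader = data_module.val_dataloader()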