Skip to content
Merged
Show file tree
Hide file tree
Changes from 76 commits
Commits
Show all changes
82 commits
Select commit Hold shift + click to select a range
e3297c6
add dummy dataset class - quick testing purpose
aditya0by0 Mar 17, 2025
f0e4758
custom typehints
aditya0by0 Mar 17, 2025
4fda565
model base: make forward method as abstract method rebase ensemble_f…
aditya0by0 Mar 17, 2025
7f7c6a0
ensemble: abstract code
aditya0by0 Mar 17, 2025
4d3f4f6
ignore lightning logs
aditya0by0 Mar 17, 2025
55959de
ensemble: fix for grad runtime error
aditya0by0 Mar 17, 2025
9513fea
ensemble: config for ensemble model
aditya0by0 Mar 17, 2025
2875385
ensemble: add MLP layer on top ensemble models
aditya0by0 Mar 21, 2025
7f892d9
base: fix import
aditya0by0 Mar 21, 2025
f60b2d8
ensemble: code improvements
aditya0by0 Mar 21, 2025
72a6b37
ensemble: add class path to config and load model via this class
aditya0by0 Mar 24, 2025
82a96dc
ensemble: changes for out of scope labels for certain models
aditya0by0 Mar 25, 2025
26f5ab4
ensemble: correct confidence val calculation
aditya0by0 Mar 31, 2025
9b851c5
ensemble: update for tpv/fpv value for each label
aditya0by0 Mar 31, 2025
0541ed2
ensemble: add docstrings and typehints
aditya0by0 Apr 1, 2025
2f3ecc6
Merge branch 'dev' into ensemble_br
aditya0by0 May 4, 2025
3ace30a
remove optimizer kwargs as not needed
aditya0by0 May 4, 2025
ddcdeac
add template to ensemble config
aditya0by0 May 4, 2025
eb1798c
Update .gitignore
aditya0by0 May 4, 2025
ed92ac5
each model's each label has TPV, FPV
aditya0by0 May 5, 2025
405026e
Merge branch 'dev' into ensemble_br
aditya0by0 May 7, 2025
0ec03b1
remove ensemble learning class
aditya0by0 May 15, 2025
dabe5ff
update code change
aditya0by0 May 15, 2025
7db384a
add ensemble base to new python dir
aditya0by0 May 16, 2025
65a51e0
add ensemble controller
aditya0by0 May 16, 2025
37d46f7
add utils.print_metrics to ensemble
aditya0by0 May 18, 2025
bc6e131
add consolidator
aditya0by0 May 18, 2025
b9dbd97
add to needed classes to init
aditya0by0 May 18, 2025
4d6856d
add rank_zero_info printing
aditya0by0 May 18, 2025
825916e
add script for running ensemble
aditya0by0 May 18, 2025
69c5263
ensemble minor changes
aditya0by0 May 18, 2025
4bd00ac
private instance var + reader_dir_name param
aditya0by0 May 19, 2025
50057f0
config for ensemble
aditya0by0 May 19, 2025
ee7a166
delete models/ensemble
aditya0by0 May 19, 2025
e9f1d95
delete old ensemble config
aditya0by0 May 19, 2025
fca0305
add docstrings + typehints
aditya0by0 May 19, 2025
de6a707
delete dummy dataset
aditya0by0 May 20, 2025
b471a05
raname script with _ prefix
aditya0by0 May 20, 2025
4c89dd3
wrapper base
aditya0by0 May 21, 2025
1563c76
nn wrapper
aditya0by0 May 21, 2025
2fec9ef
rename ensemble internal files with _ prefix
aditya0by0 May 21, 2025
682801f
chemlog wrapper
aditya0by0 May 22, 2025
2b2d458
gnn wrapper
aditya0by0 May 22, 2025
7cbb732
move related code from ensemble base to nn wrapper
aditya0by0 May 22, 2025
a7df384
move constants to wrappers
aditya0by0 May 22, 2025
ee0aef1
move prop loading to base
aditya0by0 May 22, 2025
8d8a748
move wrappers to ensemble
aditya0by0 May 22, 2025
00bd478
nn validate model config
aditya0by0 May 22, 2025
4f35007
utility for loading class
aditya0by0 May 22, 2025
a1a70eb
Create _constants.py
aditya0by0 May 22, 2025
f812cd7
update controller for wrapper
aditya0by0 May 22, 2025
c48bfd2
update base for wrapper
aditya0by0 May 22, 2025
bf3cf64
Update .gitignore
aditya0by0 Jun 1, 2025
76d8a79
predict method implementation for data file and list of smiles
aditya0by0 Jun 1, 2025
95d49c1
seperate method for evaluate and prediction
aditya0by0 Jun 1, 2025
a20ce76
store collated label or any model in instance var
aditya0by0 Jun 1, 2025
c0cb6c9
fix collated labels none error
aditya0by0 Jun 1, 2025
9fc5d20
script to generate classes props
aditya0by0 Jun 2, 2025
93e9b73
save prediction to csv for predict operation mode
aditya0by0 Jun 2, 2025
954431c
use multilabel cm
aditya0by0 Jun 2, 2025
6ce02a7
raise error for duplicate subclass/wrapper
aditya0by0 Jun 9, 2025
549a71f
add model load kwargs and move cls path to nn wrapper
aditya0by0 Jun 9, 2025
366c72b
refine chemlog wrapper
aditya0by0 Jun 9, 2025
2739c64
use data class instead of explicit reader, collator
aditya0by0 Jun 10, 2025
a96ae43
refine gnn wrapper
aditya0by0 Jun 10, 2025
b5ea7d1
Merge branch 'dev' into ensemble_br
aditya0by0 Jun 12, 2025
a6800b3
correct PPV and FPV key and rectify nn wrapper
aditya0by0 Jun 12, 2025
64b3e7e
load cls, load model as utilities
aditya0by0 Jun 12, 2025
7e673f2
evaluate_from_data_file not needed for gnn wrapper
aditya0by0 Jun 12, 2025
0c1be27
use dataclass and utilities
aditya0by0 Jun 12, 2025
e5ec383
pass config file for model, data instead of explicit params
aditya0by0 Jun 12, 2025
9a3328f
use utility for scripts
aditya0by0 Jun 12, 2025
f40eff9
dm should have splits_file_path or splits.csv in its dir
aditya0by0 Jun 12, 2025
e89ec4f
fix gnn logits error
aditya0by0 Jun 15, 2025
8d40637
fix gnn predict_from smiles list logits error
aditya0by0 Jun 15, 2025
3976531
chemlog wrapper return logits
aditya0by0 Jun 15, 2025
e4e9a28
save tp, fp, fn and tn as model properties
sfluegel05 Jun 24, 2025
41dd1c6
move ensemble to chebifier repo, move property calculation and utils …
sfluegel05 Jun 24, 2025
44b60dd
remove data processed dir param linking
aditya0by0 Jun 24, 2025
1986a31
Merge branch 'ensemble_br' of https://github.com/ChEB-AI/python-cheba…
aditya0by0 Jun 24, 2025
f16550c
fix utils imports
aditya0by0 Jun 24, 2025
c33dec3
update gitignore
aditya0by0 Jun 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,6 @@ cython_debug/
/logs
/results_buffer
electra_pretrained.ckpt

build
.virtual_documents
.jupyter
chebai.egg-info
lightning_logs
logs
/lightning_logs
.isort.cfg
/.vscode
6 changes: 6 additions & 0 deletions chebai/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def call_data_methods(data: Type[XYBaseDataModule]):
"data.num_of_labels", "trainer.callbacks.init_args.num_labels"
)

parser.link_arguments(
"data.processed_dir_main",
"model.init_args.data_processed_dir_main",
apply_on="instantiate",
)

@staticmethod
def subcommands() -> Dict[str, Set[str]]:
"""
Expand Down
12 changes: 12 additions & 0 deletions chebai/ensemble/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from ._consolidator import WeightedMajorityVoting
from ._controller import NoActivationCondition
from ._wrappers import ChemLogWrapper, GNNWrapper, NNWrapper


class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting):
"""Full Ensemble (no activation condition) with Weighted Majority Voting"""

pass


__all__ = ["FullEnsembleWMV", "NNWrapper", "GNNWrapper", "ChemLogWrapper"]
256 changes: 256 additions & 0 deletions chebai/ensemble/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
from abc import ABC, abstractmethod
from collections import deque
from pathlib import Path
from typing import Any, Deque, Dict

import pandas as pd
import torch

from chebai.result.classification import print_metrics

from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH


class EnsembleBase(ABC):
"""
Base class for ensemble models in the Chebai framework.

Handles loading, validating, and coordinating multiple models for ensemble prediction.
"""

def __init__(
self,
model_configs: Dict[str, Dict[str, Any]],
data_processed_dir_main: str,
operation_mode: str = EVAL_OP,
**kwargs: Any,
) -> None:
"""
Initializes the ensemble model and loads configurations, labels, and sets up the environment.

Args:
model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
data_processed_dir_main (str): Path to the processed data directory.
**kwargs (Any): Additional arguments, such as 'input_dim' and '_validate_configs'.
"""
if bool(kwargs.get("_perform_validation_checks", True)):
self._perform_validation_checks(
model_configs, operation=operation_mode, **kwargs
)

self._model_configs: Dict[str, Dict[str, Any]] = model_configs
self._data_processed_dir_main: str = data_processed_dir_main
self._operation_mode: str = operation_mode
print(f"Ensemble operation: {self._operation_mode}")

# These instance variable will be set in method `_process_input_to_ensemble`
self._total_data_size: int | None = None
self._ensemble_input: list[str] | Path = self._process_input_to_ensemble(
**kwargs
)
print(f"Total data size (data.pkl) is {self._total_data_size}")

self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self._dm_labels: Dict[str, int] = self._load_data_module_labels()
self._num_of_labels: int = len(self._dm_labels)
print(f"Number of labels for this data is {self._num_of_labels} ")

self._num_models_per_label: torch.Tensor = torch.zeros(
1, self._num_of_labels, device=self._device
)
self._model_queue: Deque[str] = deque()
self._collated_labels: torch.Tensor | None = None

@classmethod
def _perform_validation_checks(
cls, model_configs: Dict[str, Dict[str, Any]], operation, **kwargs
) -> None:
"""
Validates model configuration dictionary for required keys and uniqueness.

Args:
model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary.

Raises:
AttributeError: If any model config is missing required keys.
ValueError: If duplicate paths are found for model checkpoint, class, or labels.
"""
if operation not in ["evaluate", "predict"]:
raise ValueError(
f"Invalid operation '{operation}'. Must be 'evaluate' or 'predict'."
)

if operation == "predict":
if kwargs.get("smiles_list_file_path", None):
raise ValueError(
"For 'predict' operation, 'smiles_list_file_path' must be provided."
)

if not Path(kwargs.get("smiles_list_file_path")).exists():
raise FileNotFoundError(f"{kwargs.get('smiles_list_file_path')}")

required_keys = {WRAPPER_CLS_PATH}

for model_name, config in model_configs.items():
missing_keys = required_keys - config.keys()
if missing_keys:
raise AttributeError(
f"Missing keys {missing_keys} in model '{model_name}' configuration."
)

def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path:
if self._operation_mode == PRED_OP:
p = Path(kwargs["smiles_list_file_path"])
smiles_list: list[str] = []
with open(p, "r") as f:
for line in f:
# Skip empty or whitespace-only lines
if line.strip():
# Split on whitespace and take the first item as the SMILES
smiles = line.strip().split()[0]
smiles_list.append(smiles)
self._total_data_size = len(smiles_list)
return smiles_list
elif self._operation_mode == EVAL_OP:
processed_dir_path = Path(self._data_processed_dir_main)
data_pkl_path = processed_dir_path / "data.pkl"
if not data_pkl_path.exists():
raise FileNotFoundError(
f"data.pkl does not exist in the {processed_dir_path} directory"
)
self._total_data_size = len(pd.read_pickle(data_pkl_path))
return data_pkl_path
else:
raise ValueError("Invalid operation")

def _load_data_module_labels(self) -> dict[str, int]:
"""
Loads class labels from the classes.txt file and sets internal label mapping.

Raises:
FileNotFoundError: If the expected classes.txt file is not found.
"""
classes_file_path = Path(self._data_processed_dir_main) / "classes.txt"
if not classes_file_path.exists():
raise FileNotFoundError(f"{classes_file_path} does not exist")
print(f"Loading {classes_file_path} ....")

dm_labels_dict = {}
with open(classes_file_path, "r") as f:
for line in f:
label = line.strip()
if label not in dm_labels_dict:
dm_labels_dict[label] = len(dm_labels_dict)
return dm_labels_dict

def run_ensemble(self) -> None:
"""
Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics.
"""
assert self._total_data_size is not None and self._num_of_labels is not None
true_scores = torch.zeros(
self._total_data_size, self._num_of_labels, device=self._device
)
false_scores = torch.zeros(
self._total_data_size, self._num_of_labels, device=self._device
)

print(
f"Running {self.__class__.__name__} ensemble for {self._operation_mode} operation..."
)
while self._model_queue:
model_name = self._model_queue.popleft()
print(f"Processing model: {model_name}")

print("\t Passing model to controller to generate predictions...")
controller_output = self._controller(model_name, self._ensemble_input)

print("\t Passing predictions to consolidator for aggregation...")
self._consolidator(
pred_conf_dict=controller_output["pred_conf_dict"],
model_props=controller_output["model_props"],
true_scores=true_scores,
false_scores=false_scores,
)

final_preds = self._consolidate_on_finish(
true_scores=true_scores, false_scores=false_scores
)

if self._operation_mode == EVAL_OP:
assert (
self._collated_labels is not None
), "Collated labels must be set for evaluation operation."
print_metrics(
final_preds,
self._collated_labels,
self._device,
classes=list(self._dm_labels.keys()),
)
else:
# Get SMILES and label names
smiles_list = self._ensemble_input
label_names = list(self._dm_labels.keys())
# Efficient conversion from tensor to NumPy
preds_np = final_preds.detach().cpu().numpy()

assert (
len(smiles_list) == preds_np.shape[0]
), "Length of SMILES list does not match number of predictions."
assert (
len(label_names) == preds_np.shape[1]
), "Number of label names does not match number of predictions."

# Build DataFrame
df = pd.DataFrame(preds_np, columns=label_names)
df.insert(0, "SMILES", smiles_list)

# Save to CSV
output_path = (
Path(self._data_processed_dir_main) / "ensemble_predictions.csv"
)
df.to_csv(output_path, index=False)

print(f"Predictions saved to {output_path}")

@abstractmethod
def _controller(
self,
model_name: str,
model_input: list[str] | Path,
**kwargs: Any,
) -> Dict[str, torch.Tensor]:
"""
Abstract method to define model-specific prediction logic.

Returns:
Dict[str, torch.Tensor]: Predictions or confidence scores.
"""

@abstractmethod
def _consolidator(
self,
*,
pred_conf_dict: Dict[str, torch.Tensor],
model_props: Dict[str, torch.Tensor],
true_scores: torch.Tensor,
false_scores: torch.Tensor,
**kwargs: Any,
) -> None:
"""
Abstract method to define aggregation logic.

Should update the provided `true_scores` and `false_scores`.
"""

@abstractmethod
def _consolidate_on_finish(
self, *, true_scores: torch.Tensor, false_scores: torch.Tensor
) -> torch.Tensor:
"""
Abstract method to produce final predictions after all models have been evaluated.

Returns:
torch.Tensor: Final aggregated predictions.
"""
Loading