Ensemble Models #77
Merged · +280 −9

Changes from 76 commits (82 commits in total)

Commits:
All commits by aditya0by0 unless noted otherwise.

e3297c6  add dummy dataset class - quick testing purpose
f0e4758  custom typehints
4fda565  model base: make forward method as abstract method rebase ensemble_f…
7f7c6a0  ensemble: abstract code
4d3f4f6  ignore lightning logs
55959de  ensemble: fix for grad runtime error
9513fea  ensemble: config for ensemble model
2875385  ensemble: add MLP layer on top ensemble models
7f892d9  base: fix import
f60b2d8  ensemble: code improvements
72a6b37  ensemble: add class path to config and load model via this class
82a96dc  ensemble: changes for out of scope labels for certain models
26f5ab4  ensemble: correct confidence val calculation
9b851c5  ensemble: update for tpv/fpv value for each label
0541ed2  ensemble: add docstrings and typehints
2f3ecc6  Merge branch 'dev' into ensemble_br
3ace30a  remove optimizer kwargs as not needed
ddcdeac  add template to ensemble config
eb1798c  Update .gitignore
ed92ac5  each model's each label has TPV, FPV
405026e  Merge branch 'dev' into ensemble_br
0ec03b1  remove ensemble learning class
dabe5ff  update code change
7db384a  add ensemble base to new python dir
65a51e0  add ensemble controller
37d46f7  add utils.print_metrics to ensemble
bc6e131  add consolidator
b9dbd97  add to needed classes to init
4d6856d  add rank_zero_info printing
825916e  add script for running ensemble
69c5263  ensemble minor changes
4bd00ac  private instance var + reader_dir_name param
50057f0  config for ensemble
ee7a166  delete models/ensemble
e9f1d95  delete old ensemble config
fca0305  add docstrings + typehints
de6a707  delete dummy dataset
b471a05  raname script with _ prefix
4c89dd3  wrapper base
1563c76  nn wrapper
2fec9ef  rename ensemble internal files with _ prefix
682801f  chemlog wrapper
2b2d458  gnn wrapper
7cbb732  move related code from ensemble base to nn wrapper
a7df384  move constants to wrappers
ee0aef1  move prop loading to base
8d8a748  move wrappers to ensemble
00bd478  nn validate model config
4f35007  utility for loading class
a1a70eb  Create _constants.py
f812cd7  update controller for wrapper
c48bfd2  update base for wrapper
bf3cf64  Update .gitignore
76d8a79  predict method implementation for data file and list of smiles
95d49c1  seperate method for evaluate and prediction
a20ce76  store collated label or any model in instance var
c0cb6c9  fix collated labels none error
9fc5d20  script to generate classes props
93e9b73  save prediction to csv for predict operation mode
954431c  use multilabel cm
6ce02a7  raise error for duplicate subclass/wrapper
549a71f  add model load kwargs and move cls path to nn wrapper
366c72b  refine chemlog wrapper
2739c64  use data class instead of explicit reader, collator
a96ae43  refine gnn wrapper
b5ea7d1  Merge branch 'dev' into ensemble_br
a6800b3  correct PPV and FPV key and rectify nn wrapper
64b3e7e  load cls, load model as utilities
7e673f2  evaluate_from_data_file not needed for gnn wrapper
0c1be27  use dataclass and utilities
e5ec383  pass config file for model, data instead of explicit params
9a3328f  use utility for scripts
f40eff9  dm should have splits_file_path or splits.csv in its dir
e89ec4f  fix gnn logits error
8d40637  fix gnn predict_from smiles list logits error
3976531  chemlog wrapper return logits
e4e9a28  save tp, fp, fn and tn as model properties  (sfluegel05)
41dd1c6  move ensemble to chebifier repo, move property calculation and utils …  (sfluegel05)
44b60dd  remove data processed dir param linking
1986a31  Merge branch 'ensemble_br' of https://github.com/ChEB-AI/python-cheba…
f16550c  fix utils imports
c33dec3  update gitignore
New file (12 lines added):

```python
from ._consolidator import WeightedMajorityVoting
from ._controller import NoActivationCondition
from ._wrappers import ChemLogWrapper, GNNWrapper, NNWrapper


class FullEnsembleWMV(NoActivationCondition, WeightedMajorityVoting):
    """Full Ensemble (no activation condition) with Weighted Majority Voting"""

    pass


__all__ = ["FullEnsembleWMV", "NNWrapper", "GNNWrapper", "ChemLogWrapper"]
```
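For orientation, here is a minimal sketch of how this composed class might be used, assuming the controller/consolidator mixins ultimately derive from `EnsembleBase` (shown in the next file). The import path, the literal config key name, the wrapper class paths, and the data directory are illustrative assumptions, not taken from the PR.

```python
# Hypothetical usage sketch -- names marked as assumptions are not from the PR.
from chebai.ensemble import FullEnsembleWMV  # import path is an assumption

# Each model entry must contain the wrapper class path key (WRAPPER_CLS_PATH in
# _constants.py); the literal key "wrapper_cls_path" below is an assumed placeholder.
model_configs = {
    "electra": {"wrapper_cls_path": "chebai.ensemble.NNWrapper"},
    "chemlog": {"wrapper_cls_path": "chebai.ensemble.ChemLogWrapper"},
}

ensemble = FullEnsembleWMV(
    model_configs=model_configs,
    # directory that holds data.pkl and classes.txt (illustrative path)
    data_processed_dir_main="data/chebi_v200/processed",
    operation_mode="evaluate",
)
ensemble.run_ensemble()
```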
New file (256 lines added):

```python
from abc import ABC, abstractmethod
from collections import deque
from pathlib import Path
from typing import Any, Deque, Dict

import pandas as pd
import torch

from chebai.result.classification import print_metrics

from ._constants import EVAL_OP, PRED_OP, WRAPPER_CLS_PATH


class EnsembleBase(ABC):
    """
    Base class for ensemble models in the Chebai framework.

    Handles loading, validating, and coordinating multiple models for ensemble prediction.
    """

    def __init__(
        self,
        model_configs: Dict[str, Dict[str, Any]],
        data_processed_dir_main: str,
        operation_mode: str = EVAL_OP,
        **kwargs: Any,
    ) -> None:
        """
        Initializes the ensemble model and loads configurations, labels, and sets up the environment.

        Args:
            model_configs (Dict[str, Dict[str, Any]]): Dictionary of model configurations.
            data_processed_dir_main (str): Path to the processed data directory.
            operation_mode (str): Operation to run, either the evaluate or the predict mode.
            **kwargs (Any): Additional arguments, such as 'input_dim' and '_perform_validation_checks'.
        """
        if bool(kwargs.get("_perform_validation_checks", True)):
            self._perform_validation_checks(
                model_configs, operation=operation_mode, **kwargs
            )

        self._model_configs: Dict[str, Dict[str, Any]] = model_configs
        self._data_processed_dir_main: str = data_processed_dir_main
        self._operation_mode: str = operation_mode
        print(f"Ensemble operation: {self._operation_mode}")

        # These instance variables will be set in method `_process_input_to_ensemble`
        self._total_data_size: int | None = None
        self._ensemble_input: list[str] | Path = self._process_input_to_ensemble(
            **kwargs
        )
        print(f"Total data size (data.pkl) is {self._total_data_size}")

        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._dm_labels: Dict[str, int] = self._load_data_module_labels()
        self._num_of_labels: int = len(self._dm_labels)
        print(f"Number of labels for this data is {self._num_of_labels}")

        self._num_models_per_label: torch.Tensor = torch.zeros(
            1, self._num_of_labels, device=self._device
        )
        self._model_queue: Deque[str] = deque()
        self._collated_labels: torch.Tensor | None = None

    @classmethod
    def _perform_validation_checks(
        cls, model_configs: Dict[str, Dict[str, Any]], operation: str, **kwargs: Any
    ) -> None:
        """
        Validates the model configuration dictionary for required keys.

        Args:
            model_configs (Dict[str, Dict[str, Any]]): Model configuration dictionary.
            operation (str): Requested operation mode, 'evaluate' or 'predict'.

        Raises:
            AttributeError: If any model config is missing required keys.
            ValueError: If the operation mode is invalid or a required argument is missing.
            FileNotFoundError: If the provided SMILES list file does not exist.
        """
        if operation not in ["evaluate", "predict"]:
            raise ValueError(
                f"Invalid operation '{operation}'. Must be 'evaluate' or 'predict'."
            )

        if operation == "predict":
            smiles_list_file_path = kwargs.get("smiles_list_file_path", None)
            if not smiles_list_file_path:
                raise ValueError(
                    "For 'predict' operation, 'smiles_list_file_path' must be provided."
                )

            if not Path(smiles_list_file_path).exists():
                raise FileNotFoundError(f"{smiles_list_file_path}")

        required_keys = {WRAPPER_CLS_PATH}

        for model_name, config in model_configs.items():
            missing_keys = required_keys - config.keys()
            if missing_keys:
                raise AttributeError(
                    f"Missing keys {missing_keys} in model '{model_name}' configuration."
                )

    def _process_input_to_ensemble(self, **kwargs: Any) -> list[str] | Path:
        if self._operation_mode == PRED_OP:
            p = Path(kwargs["smiles_list_file_path"])
            smiles_list: list[str] = []
            with open(p, "r") as f:
                for line in f:
                    # Skip empty or whitespace-only lines
                    if line.strip():
                        # Split on whitespace and take the first item as the SMILES
                        smiles = line.strip().split()[0]
                        smiles_list.append(smiles)
            self._total_data_size = len(smiles_list)
            return smiles_list
        elif self._operation_mode == EVAL_OP:
            processed_dir_path = Path(self._data_processed_dir_main)
            data_pkl_path = processed_dir_path / "data.pkl"
            if not data_pkl_path.exists():
                raise FileNotFoundError(
                    f"data.pkl does not exist in the {processed_dir_path} directory"
                )
            self._total_data_size = len(pd.read_pickle(data_pkl_path))
            return data_pkl_path
        else:
            raise ValueError("Invalid operation")

    def _load_data_module_labels(self) -> dict[str, int]:
        """
        Loads class labels from the classes.txt file and builds the internal label mapping.

        Returns:
            dict[str, int]: Mapping from label name to label index.

        Raises:
            FileNotFoundError: If the expected classes.txt file is not found.
        """
        classes_file_path = Path(self._data_processed_dir_main) / "classes.txt"
        if not classes_file_path.exists():
            raise FileNotFoundError(f"{classes_file_path} does not exist")
        print(f"Loading {classes_file_path} ....")

        dm_labels_dict = {}
        with open(classes_file_path, "r") as f:
            for line in f:
                label = line.strip()
                if label not in dm_labels_dict:
                    dm_labels_dict[label] = len(dm_labels_dict)
        return dm_labels_dict

    def run_ensemble(self) -> None:
        """
        Executes the full ensemble prediction pipeline, aggregating predictions and printing metrics.
        """
        assert self._total_data_size is not None and self._num_of_labels is not None
        true_scores = torch.zeros(
            self._total_data_size, self._num_of_labels, device=self._device
        )
        false_scores = torch.zeros(
            self._total_data_size, self._num_of_labels, device=self._device
        )

        print(
            f"Running {self.__class__.__name__} ensemble for {self._operation_mode} operation..."
        )
        while self._model_queue:
            model_name = self._model_queue.popleft()
            print(f"Processing model: {model_name}")

            print("\t Passing model to controller to generate predictions...")
            controller_output = self._controller(model_name, self._ensemble_input)

            print("\t Passing predictions to consolidator for aggregation...")
            self._consolidator(
                pred_conf_dict=controller_output["pred_conf_dict"],
                model_props=controller_output["model_props"],
                true_scores=true_scores,
                false_scores=false_scores,
            )

        final_preds = self._consolidate_on_finish(
            true_scores=true_scores, false_scores=false_scores
        )

        if self._operation_mode == EVAL_OP:
            assert (
                self._collated_labels is not None
            ), "Collated labels must be set for evaluation operation."
            print_metrics(
                final_preds,
                self._collated_labels,
                self._device,
                classes=list(self._dm_labels.keys()),
            )
        else:
            # Get SMILES and label names
            smiles_list = self._ensemble_input
            label_names = list(self._dm_labels.keys())
            # Efficient conversion from tensor to NumPy
            preds_np = final_preds.detach().cpu().numpy()

            assert (
                len(smiles_list) == preds_np.shape[0]
            ), "Length of SMILES list does not match number of predictions."
            assert (
                len(label_names) == preds_np.shape[1]
            ), "Number of label names does not match number of predictions."

            # Build DataFrame
            df = pd.DataFrame(preds_np, columns=label_names)
            df.insert(0, "SMILES", smiles_list)

            # Save to CSV
            output_path = (
                Path(self._data_processed_dir_main) / "ensemble_predictions.csv"
            )
            df.to_csv(output_path, index=False)

            print(f"Predictions saved to {output_path}")

    @abstractmethod
    def _controller(
        self,
        model_name: str,
        model_input: list[str] | Path,
        **kwargs: Any,
    ) -> Dict[str, torch.Tensor]:
        """
        Abstract method to define model-specific prediction logic.

        Returns:
            Dict[str, torch.Tensor]: Predictions or confidence scores.
        """

    @abstractmethod
    def _consolidator(
        self,
        *,
        pred_conf_dict: Dict[str, torch.Tensor],
        model_props: Dict[str, torch.Tensor],
        true_scores: torch.Tensor,
        false_scores: torch.Tensor,
        **kwargs: Any,
    ) -> None:
        """
        Abstract method to define aggregation logic.

        Should update the provided `true_scores` and `false_scores` in place.
        """

    @abstractmethod
    def _consolidate_on_finish(
        self, *, true_scores: torch.Tensor, false_scores: torch.Tensor
    ) -> torch.Tensor:
        """
        Abstract method to produce final predictions after all models have been evaluated.

        Returns:
            torch.Tensor: Final aggregated predictions.
        """
```