Ensemble Models #77
Merged
Commits (82):
e3297c6  add dummy dataset class - quick testing purpose  (aditya0by0)
f0e4758  custom typehints  (aditya0by0)
4fda565  model base: make forward method as abstract method rebase ensemble_f…  (aditya0by0)
7f7c6a0  ensemble: abstract code  (aditya0by0)
4d3f4f6  ignore lightning logs  (aditya0by0)
55959de  ensemble: fix for grad runtime error  (aditya0by0)
9513fea  ensemble: config for ensemble model  (aditya0by0)
2875385  ensemble: add MLP layer on top ensemble models  (aditya0by0)
7f892d9  base: fix import  (aditya0by0)
f60b2d8  ensemble: code improvements  (aditya0by0)
72a6b37  ensemble: add class path to config and load model via this class  (aditya0by0)
82a96dc  ensemble: changes for out of scope labels for certain models  (aditya0by0)
26f5ab4  ensemble: correct confidence val calculation  (aditya0by0)
9b851c5  ensemble: update for tpv/fpv value for each label  (aditya0by0)
0541ed2  ensemble: add docstrings and typehints  (aditya0by0)
2f3ecc6  Merge branch 'dev' into ensemble_br  (aditya0by0)
3ace30a  remove optimizer kwargs as not needed  (aditya0by0)
ddcdeac  add template to ensemble config  (aditya0by0)
eb1798c  Update .gitignore  (aditya0by0)
ed92ac5  each model's each label has TPV, FPV  (aditya0by0)
405026e  Merge branch 'dev' into ensemble_br  (aditya0by0)
0ec03b1  remove ensemble learning class  (aditya0by0)
dabe5ff  update code change  (aditya0by0)
7db384a  add ensemble base to new python dir  (aditya0by0)
65a51e0  add ensemble controller  (aditya0by0)
37d46f7  add utils.print_metrics to ensemble  (aditya0by0)
bc6e131  add consolidator  (aditya0by0)
b9dbd97  add to needed classes to init  (aditya0by0)
4d6856d  add rank_zero_info printing  (aditya0by0)
825916e  add script for running ensemble  (aditya0by0)
69c5263  ensemble minor changes  (aditya0by0)
4bd00ac  private instance var + reader_dir_name param  (aditya0by0)
50057f0  config for ensemble  (aditya0by0)
ee7a166  delete models/ensemble  (aditya0by0)
e9f1d95  delete old ensemble config  (aditya0by0)
fca0305  add docstrings + typehints  (aditya0by0)
de6a707  delete dummy dataset  (aditya0by0)
b471a05  raname script with _ prefix  (aditya0by0)
4c89dd3  wrapper base  (aditya0by0)
1563c76  nn wrapper  (aditya0by0)
2fec9ef  rename ensemble internal files with _ prefix  (aditya0by0)
682801f  chemlog wrapper  (aditya0by0)
2b2d458  gnn wrapper  (aditya0by0)
7cbb732  move related code from ensemble base to nn wrapper  (aditya0by0)
a7df384  move constants to wrappers  (aditya0by0)
ee0aef1  move prop loading to base  (aditya0by0)
8d8a748  move wrappers to ensemble  (aditya0by0)
00bd478  nn validate model config  (aditya0by0)
4f35007  utility for loading class  (aditya0by0)
a1a70eb  Create _constants.py  (aditya0by0)
f812cd7  update controller for wrapper  (aditya0by0)
c48bfd2  update base for wrapper  (aditya0by0)
bf3cf64  Update .gitignore  (aditya0by0)
76d8a79  predict method implementation for data file and list of smiles  (aditya0by0)
95d49c1  seperate method for evaluate and prediction  (aditya0by0)
a20ce76  store collated label or any model in instance var  (aditya0by0)
c0cb6c9  fix collated labels none error  (aditya0by0)
9fc5d20  script to generate classes props  (aditya0by0)
93e9b73  save prediction to csv for predict operation mode  (aditya0by0)
954431c  use multilabel cm  (aditya0by0)
6ce02a7  raise error for duplicate subclass/wrapper  (aditya0by0)
549a71f  add model load kwargs and move cls path to nn wrapper  (aditya0by0)
366c72b  refine chemlog wrapper  (aditya0by0)
2739c64  use data class instead of explicit reader, collator  (aditya0by0)
a96ae43  refine gnn wrapper  (aditya0by0)
b5ea7d1  Merge branch 'dev' into ensemble_br  (aditya0by0)
a6800b3  correct PPV and FPV key and rectify nn wrapper  (aditya0by0)
64b3e7e  load cls, load model as utilities  (aditya0by0)
7e673f2  evaluate_from_data_file not needed for gnn wrapper  (aditya0by0)
0c1be27  use dataclass and utilities  (aditya0by0)
e5ec383  pass config file for model, data instead of explicit params  (aditya0by0)
9a3328f  use utility for scripts  (aditya0by0)
f40eff9  dm should have splits_file_path or splits.csv in its dir  (aditya0by0)
e89ec4f  fix gnn logits error  (aditya0by0)
8d40637  fix gnn predict_from smiles list logits error  (aditya0by0)
3976531  chemlog wrapper return logits  (aditya0by0)
e4e9a28  save tp, fp, fn and tn as model properties  (sfluegel05)
41dd1c6  move ensemble to chebifier repo, move property calculation and utils …  (sfluegel05)
44b60dd  remove data processed dir param linking  (aditya0by0)
1986a31  Merge branch 'ensemble_br' of https://github.com/ChEB-AI/python-cheba…  (aditya0by0)
f16550c  fix utils imports  (aditya0by0)
c33dec3  update gitignore  (aditya0by0)
New file (191 lines added), the classes-properties generator script:

```python
import json
from pathlib import Path

import torch
from jsonargparse import CLI
from sklearn.metrics import multilabel_confusion_matrix

from chebai.preprocessing.datasets.base import XYBaseDataModule
from chebai.result.utils import (
    load_data_instance,
    load_model_for_inference,
    parse_config_file,
)


class ClassesPropertiesGenerator:
    """
    Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value)
    for each class in a multi-label classification problem using a PyTorch Lightning model.
    """

    @staticmethod
    def load_class_labels(path: Path) -> list[str]:
        """
        Load a list of class names from a .json or .txt file.

        Args:
            path: Path to the class labels file (txt or json).

        Returns:
            A list of class names, one per line.
        """
        path = Path(path)
        with path.open() as f:
            return [line.strip() for line in f if line.strip()]

    @staticmethod
    def compute_tpv_npv(
        y_true: list[torch.Tensor],
        y_pred: list[torch.Tensor],
        class_names: list[str],
    ) -> dict[str, dict[str, float]]:
        """
        Compute TPV (precision) and NPV for each class in a multi-label setting.

        Args:
            y_true: List of binary ground-truth label tensors, one tensor per sample.
            y_pred: List of binary prediction tensors, one tensor per sample.
            class_names: Ordered list of class names corresponding to class indices.

        Returns:
            Dictionary mapping each class name to its TPV and NPV metrics:
                {
                    "class_name": {"PPV": float, "NPV": float},
                    ...
                }
        """
        # Stack per-sample tensors into (n_samples, n_classes) numpy arrays
        true_np = torch.stack(y_true).cpu().numpy().astype(int)
        pred_np = torch.stack(y_pred).cpu().numpy().astype(int)

        # Compute a confusion matrix for each class
        cm = multilabel_confusion_matrix(true_np, pred_np)

        results: dict[str, dict[str, float]] = {}
        for idx, cls_name in enumerate(class_names):
            tn, fp, fn, tp = cm[idx].ravel()
            tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
            npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
            results[cls_name] = {
                "PPV": round(tpv, 4),
                "NPV": round(npv, 4),
                "TN": int(tn),
                "FP": int(fp),
                "FN": int(fn),
                "TP": int(tp),
            }
        return results

    def generate_props(
        self,
        model_ckpt_path: str,
        model_config_file_path: str,
        data_config_file_path: str,
        output_path: str | None = None,
    ) -> None:
        """
        Run inference on the validation set, compute TPV/NPV per class, and save to JSON.

        Args:
            model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
            model_config_file_path: Path to the yaml config file of the model.
            data_config_file_path: Path to the yaml config file of the data.
            output_path: Optional path where to write the JSON metrics file.
                Defaults to '<processed_dir_main>/classes.json'.
        """
        print("Extracting validation data for computation...")

        data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
        data_module: XYBaseDataModule = load_data_instance(
            data_cls_path, data_cls_kwargs
        )

        splits_file_path = Path(data_module.processed_dir_main, "splits.csv")
        if data_module.splits_file_path is None:
            if not splits_file_path.exists():
                raise RuntimeError(
                    "Either the data module should be initialized with a `splits_file_path`, "
                    f"or the file `{splits_file_path}` must exist.\n"
                    "This is to prevent the data module from dynamically generating the splits."
                )

            print(
                f"`splits_file_path` was not provided as an initialization parameter to the data module.\n"
                f"Using splits from the file {splits_file_path}"
            )
            data_module.splits_file_path = splits_file_path

        model_class_path, model_kwargs = parse_config_file(model_config_file_path)
        model = load_model_for_inference(
            model_ckpt_path, model_class_path, model_kwargs
        )

        val_loader = data_module.val_dataloader()
        print("Running inference on validation data...")

        y_true, y_pred = [], []
        for batch_idx, batch in enumerate(val_loader):
            data = model._process_batch(  # pylint: disable=W0212
                batch, batch_idx=batch_idx
            )
            labels = data["labels"]
            outputs = model(data, **data.get("model_kwargs", {}))
            logits = outputs["logits"] if isinstance(outputs, dict) else outputs
            preds = torch.sigmoid(logits) > 0.5
            y_pred.extend(preds)
            y_true.extend(labels)

        print("Computing TPV and NPV metrics...")
        classes_file = Path(data_module.processed_dir_main) / "classes.txt"
        if output_path is None:
            output_file = Path(data_module.processed_dir_main) / "classes.json"
        else:
            output_file = Path(output_path)

        class_names = self.load_class_labels(classes_file)
        metrics = self.compute_tpv_npv(y_true, y_pred, class_names)

        with output_file.open("w") as f:
            json.dump(metrics, f, indent=2)
        print(f"Saved TPV/NPV metrics to {output_file}")


class Main:
    """
    CLI wrapper for ClassesPropertiesGenerator.
    """

    def generate(
        self,
        model_ckpt_path: str,
        model_config_file_path: str,
        data_config_file_path: str,
        output_path: str | None = None,
    ) -> None:
        """
        CLI command to generate the TPV/NPV JSON.

        Args:
            model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
            model_config_file_path: Path to the yaml config file of the model.
            data_config_file_path: Path to the yaml config file of the data.
            output_path: Optional path where to write the JSON metrics file.
                Defaults to '<processed_dir_main>/classes.json'.
        """
        generator = ClassesPropertiesGenerator()
        generator.generate_props(
            model_ckpt_path,
            model_config_file_path,
            data_config_file_path,
            output_path,
        )


if __name__ == "__main__":
    # _generate_classes_props_json.py generate \
    #     --model_ckpt_path "model/ckpt/path" \
    #     --model_config_file_path "model/config/file/path" \
    #     --data_config_file_path "data/config/file/path" \
    #     --output_path "output/file/path"  # Optional
    CLI(Main, as_positional=False)
```