Skip to content

Commit 2ca72e9

Browse files
authored
Merge pull request #77 from ChEB-AI/ensemble_br
Ensemble Models
2 parents 1ef2ff7 + c33dec3 commit 2ca72e9

File tree

4 files changed

+280
-9
lines changed

4 files changed

+280
-9
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,4 @@ chebai.egg-info
175175
lightning_logs
176176
logs
177177
.isort.cfg
178+
/.vscode

chebai/models/base.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
2-
from typing import Any, Dict, Optional, Union, Iterable
2+
from abc import ABC, abstractmethod
3+
from typing import Any, Dict, Iterable, Optional, Union
34

45
import torch
56
from lightning.pytorch.core.module import LightningModule
6-
from torchmetrics import Metric
77

88
from chebai.preprocessing.structures import XYData
99

@@ -12,7 +12,7 @@
1212
_MODEL_REGISTRY = dict()
1313

1414

15-
class ChebaiBaseNet(LightningModule):
15+
class ChebaiBaseNet(LightningModule, ABC):
1616
"""
1717
Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule.
1818
@@ -353,6 +353,7 @@ def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int):
353353
logger=True,
354354
)
355355

356+
@abstractmethod
356357
def forward(self, x: Dict[str, Any]) -> torch.Tensor:
357358
"""
358359
Defines the forward pass.
@@ -363,7 +364,7 @@ def forward(self, x: Dict[str, Any]) -> torch.Tensor:
363364
Returns:
364365
torch.Tensor: The model output.
365366
"""
366-
raise NotImplementedError
367+
pass
367368

368369
def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
369370
"""
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import json
2+
from pathlib import Path
3+
4+
import torch
5+
from jsonargparse import CLI
6+
from sklearn.metrics import multilabel_confusion_matrix
7+
8+
from chebai.preprocessing.datasets.base import XYBaseDataModule
9+
from chebai.result.utils import (
10+
load_data_instance,
11+
load_model_for_inference,
12+
parse_config_file,
13+
)
14+
15+
16+
class ClassesPropertiesGenerator:
17+
"""
18+
Computes PPV (Positive Predictive Value) and NPV (Negative Predictive Value)
19+
for each class in a multi-label classification problem using a PyTorch Lightning model.
20+
"""
21+
22+
@staticmethod
23+
def load_class_labels(path: Path) -> list[str]:
24+
"""
25+
Load a list of class names from a .json or .txt file.
26+
27+
Args:
28+
path: Path to the class labels file (txt or json).
29+
30+
Returns:
31+
A list of class names, one per line.
32+
"""
33+
path = Path(path)
34+
with path.open() as f:
35+
return [line.strip() for line in f if line.strip()]
36+
37+
@staticmethod
38+
def compute_tpv_npv(
39+
y_true: list[torch.Tensor],
40+
y_pred: list[torch.Tensor],
41+
class_names: list[str],
42+
) -> dict[str, dict[str, float]]:
43+
"""
44+
Compute TPV (precision) and NPV for each class in a multi-label setting.
45+
46+
Args:
47+
y_true: List of binary ground-truth label tensors, one tensor per sample.
48+
y_pred: List of binary prediction tensors, one tensor per sample.
49+
class_names: Ordered list of class names corresponding to class indices.
50+
51+
Returns:
52+
Dictionary mapping each class name to its TPV and NPV metrics:
53+
{
54+
"class_name": {"PPV": float, "NPV": float},
55+
...
56+
}
57+
"""
58+
# Stack per-sample tensors into (n_samples, n_classes) numpy arrays
59+
true_np = torch.stack(y_true).cpu().numpy().astype(int)
60+
pred_np = torch.stack(y_pred).cpu().numpy().astype(int)
61+
62+
# Compute confusion matrix for each class
63+
cm = multilabel_confusion_matrix(true_np, pred_np)
64+
65+
results: dict[str, dict[str, float]] = {}
66+
for idx, cls_name in enumerate(class_names):
67+
tn, fp, fn, tp = cm[idx].ravel()
68+
tpv = tp / (tp + fp) if (tp + fp) > 0 else 0.0
69+
npv = tn / (tn + fn) if (tn + fn) > 0 else 0.0
70+
results[cls_name] = {
71+
"PPV": round(tpv, 4),
72+
"NPV": round(npv, 4),
73+
"TN": int(tn),
74+
"FP": int(fp),
75+
"FN": int(fn),
76+
"TP": int(tp),
77+
}
78+
return results
79+
80+
def generate_props(
81+
self,
82+
model_ckpt_path: str,
83+
model_config_file_path: str,
84+
data_config_file_path: str,
85+
output_path: str | None = None,
86+
) -> None:
87+
"""
88+
Run inference on validation set, compute TPV/NPV per class, and save to JSON.
89+
90+
Args:
91+
model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
92+
model_config_file_path: Path to yaml config file of the model.
93+
data_config_file_path: Path to yaml config file of the data.
94+
output_path: Optional path where to write the JSON metrics file.
95+
Defaults to '<processed_dir_main>/classes.json'.
96+
"""
97+
print("Extracting validation data for computation...")
98+
99+
data_cls_path, data_cls_kwargs = parse_config_file(data_config_file_path)
100+
data_module: XYBaseDataModule = load_data_instance(
101+
data_cls_path, data_cls_kwargs
102+
)
103+
104+
splits_file_path = Path(data_module.processed_dir_main, "splits.csv")
105+
if data_module.splits_file_path is None:
106+
if not splits_file_path.exists():
107+
raise RuntimeError(
108+
"Either the data module should be initialized with a `splits_file_path`, "
109+
f"or the file `{splits_file_path}` must exists.\n"
110+
"This is to prevent the data module from dynamically generating the splits."
111+
)
112+
113+
print(
114+
f"`splits_file_path` is not provided as an initialization parameter to the data module\n"
115+
f"Using splits from the file {splits_file_path}"
116+
)
117+
data_module.splits_file_path = splits_file_path
118+
119+
model_class_path, model_kwargs = parse_config_file(model_config_file_path)
120+
model = load_model_for_inference(
121+
model_ckpt_path, model_class_path, model_kwargs
122+
)
123+
124+
val_loader = data_module.val_dataloader()
125+
print("Running inference on validation data...")
126+
127+
y_true, y_pred = [], []
128+
for batch_idx, batch in enumerate(val_loader):
129+
data = model._process_batch( # pylint: disable=W0212
130+
batch, batch_idx=batch_idx
131+
)
132+
labels = data["labels"]
133+
outputs = model(data, **data.get("model_kwargs", {}))
134+
logits = outputs["logits"] if isinstance(outputs, dict) else outputs
135+
preds = torch.sigmoid(logits) > 0.5
136+
y_pred.extend(preds)
137+
y_true.extend(labels)
138+
139+
print("Computing TPV and NPV metrics...")
140+
classes_file = Path(data_module.processed_dir_main) / "classes.txt"
141+
if output_path is None:
142+
output_file = Path(data_module.processed_dir_main) / "classes.json"
143+
else:
144+
output_file = Path(output_path)
145+
146+
class_names = self.load_class_labels(classes_file)
147+
metrics = self.compute_tpv_npv(y_true, y_pred, class_names)
148+
149+
with output_file.open("w") as f:
150+
json.dump(metrics, f, indent=2)
151+
print(f"Saved TPV/NPV metrics to {output_file}")
152+
153+
154+
class Main:
    """
    Command-line entry points that delegate to ClassesPropertiesGenerator.
    """

    def generate(
        self,
        model_ckpt_path: str,
        model_config_file_path: str,
        data_config_file_path: str,
        output_path: str | None = None,
    ) -> None:
        """
        Generate the per-class TPV/NPV JSON file from the command line.

        Args:
            model_ckpt_path: Path to the PyTorch Lightning checkpoint file.
            model_config_file_path: Path to yaml config file of the model.
            data_config_file_path: Path to yaml config file of the data.
            output_path: Optional path where to write the JSON metrics file.
                Defaults to '<processed_dir_main>/classes.json'.
        """
        # All arguments are forwarded unchanged to the generator.
        ClassesPropertiesGenerator().generate_props(
            model_ckpt_path=model_ckpt_path,
            model_config_file_path=model_config_file_path,
            data_config_file_path=data_config_file_path,
            output_path=output_path,
        )
183+
184+
185+
if __name__ == "__main__":
    # Example invocation (the last option may be omitted):
    #
    #   _generate_classes_props_json.py generate \
    #       --model_ckpt_path "model/ckpt/path" \
    #       --model_config_file_path "model/config/file/path" \
    #       --data_config_file_path "data/config/file/path" \
    #       --output_path "output/file/path"
    CLI(Main, as_positional=False)

chebai/result/utils.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1+
import importlib
12
import os
23
import shutil
3-
from typing import Optional, Tuple, Union
4+
from pathlib import Path
5+
from typing import Optional, Tuple
46

57
import torch
68
import tqdm
79
import wandb
810
import wandb.util as wandb_util
11+
import yaml
912

1013
from chebai.models.base import ChebaiBaseNet
11-
from chebai.models.electra import Electra
1214
from chebai.preprocessing.datasets.base import XYBaseDataModule
1315
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
1416

@@ -121,7 +123,7 @@ def evaluate_model(
121123
save_batch_size = 128
122124
n_saved = 1
123125

124-
print(f"")
126+
print("")
125127
for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
126128
if not (
127129
skip_existing_preds
@@ -222,6 +224,82 @@ def load_results_from_buffer(
222224
return test_preds, test_labels
223225

224226

227+
def load_class(class_path: str) -> type:
    """
    Import and return the object addressed by a fully qualified dotted path.

    Args:
        class_path: Dotted path of the form ``"package.module.ClassName"``.
            Everything before the last dot is imported as a module; the last
            component is looked up as an attribute on that module.

    Returns:
        The attribute found on the imported module. No check is made here
        that it is actually a class; callers validate that themselves.

    Raises:
        ValueError: If ``class_path`` has no module part (no dot, or nothing
            before the last dot).
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the given name.
    """
    module_path, sep, class_name = class_path.rpartition(".")
    if not sep or not module_path:
        raise ValueError(
            f"Invalid class path '{class_path}': "
            "expected a dotted path like 'package.module.ClassName'"
        )
    module = importlib.import_module(module_path)
    return getattr(module, class_name)
231+
232+
233+
def load_data_instance(data_cls_path: str, data_cls_kwargs: dict):
    """
    Instantiate a data module from its dotted class path and init arguments.

    Validation uses explicit exceptions rather than ``assert`` so the checks
    still run when Python is started with ``-O``.

    Args:
        data_cls_path: Fully qualified dotted path of the data module class.
        data_cls_kwargs: Keyword arguments passed to the class constructor.

    Returns:
        An instance of the class at ``data_cls_path``, constructed as
        ``data_cls(**data_cls_kwargs)``.

    Raises:
        TypeError: If ``data_cls_kwargs`` is not a dict, or if the loaded
            object is not a class inheriting from ``XYBaseDataModule``.
    """
    if not isinstance(data_cls_kwargs, dict):
        raise TypeError("data_cls_kwargs must be a dict")

    data_cls = load_class(data_cls_path)
    # Check "is a class" first: issubclass() itself raises on non-class input.
    if not isinstance(data_cls, type):
        raise TypeError(f"{data_cls} is not a class.")
    if not issubclass(data_cls, XYBaseDataModule):
        raise TypeError(f"{data_cls} must inherit from XYBaseDataModule")
    return data_cls(**data_cls_kwargs)
241+
242+
243+
def load_model_for_inference(
    model_ckpt_path: str, model_cls_path: str, model_load_kwargs: dict, **kwargs
) -> ChebaiBaseNet:
    """
    Load a model from a checkpoint and put it into inference mode.

    The class at ``model_cls_path`` is imported, its ``load_from_checkpoint``
    is called with ``model_ckpt_path`` and ``model_load_kwargs``, and the
    resulting model has ``eval()`` and ``freeze()`` called on it.

    Validation uses explicit exceptions rather than ``assert`` so the checks
    still run when Python is started with ``-O``.

    Args:
        model_ckpt_path: Path to the checkpoint file.
        model_cls_path: Fully qualified dotted path of the model class.
        model_load_kwargs: Keyword arguments forwarded to ``load_from_checkpoint``.
        **kwargs: Only ``model_name`` is read; it is used in error messages
            and defaults to ``model_ckpt_path``.

    Returns:
        ChebaiBaseNet: The loaded model, after ``eval()`` and ``freeze()``.

    Raises:
        TypeError: If ``model_load_kwargs`` is not a dict, if the loaded object
            is not a class inheriting from ``ChebaiBaseNet``, or if the
            checkpoint does not yield a ``ChebaiBaseNet`` instance.
        FileNotFoundError: If ``model_ckpt_path`` does not exist.
        RuntimeError: If ``load_from_checkpoint`` raises; the original
            exception is chained as the cause.
    """
    if not isinstance(model_load_kwargs, dict):
        raise TypeError("model_load_kwargs must be a dict")

    model_name = kwargs.get("model_name", model_ckpt_path)

    if not Path(model_ckpt_path).exists():
        raise FileNotFoundError(
            f"Model path '{model_ckpt_path}' for '{model_name}' does not exist."
        )

    lightning_cls = load_class(model_cls_path)

    # Check "is a class" first: issubclass() itself raises on non-class input.
    if not isinstance(lightning_cls, type):
        raise TypeError(f"{lightning_cls} is not a class.")
    if not issubclass(lightning_cls, ChebaiBaseNet):
        raise TypeError(f"{lightning_cls} must inherit from ChebaiBaseNet")

    try:
        model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs)
    except Exception as e:
        raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e

    if not isinstance(model, ChebaiBaseNet):
        raise TypeError(
            f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance."
        )
    model.eval()
    model.freeze()
    return model
278+
279+
280+
def parse_config_file(config_path: str) -> tuple[str, dict]:
    """
    Read a YAML config file and extract its ``class_path`` and ``init_args``.

    Args:
        config_path: Path to a ``.yaml`` / ``.yml`` file whose top level is a
            mapping with a required ``class_path`` key and an optional
            ``init_args`` mapping.

    Returns:
        A ``(class_path, init_args)`` tuple. ``init_args`` is an empty dict
        when the key is absent or explicitly null.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file extension is not ``.yaml`` / ``.yml``, or the
            file is empty or its top level is not a mapping.
        KeyError: If the mapping has no ``class_path`` key.
        TypeError: If ``init_args`` is present but not a mapping.
    """
    path = Path(config_path)

    # Check file existence
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    # Check file extension
    if path.suffix.lower() not in (".yml", ".yaml"):
        raise ValueError(
            f"Unsupported config file type: {path.suffix}. Expected .yaml or .yml"
        )

    # Load YAML content
    with open(path, "r") as f:
        config = yaml.safe_load(f)

    # safe_load yields None for an empty document and may yield a list or
    # scalar; only a mapping can carry the keys read below.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping at the top "
            f"level, got {type(config).__name__}"
        )
    if "class_path" not in config:
        raise KeyError(
            f"Config file {config_path} is missing the required key 'class_path'"
        )

    class_path: str = config["class_path"]
    # Treat an explicit `init_args: null` the same as an absent key.
    init_args = config.get("init_args")
    if init_args is None:
        init_args = {}
    if not isinstance(init_args, dict):
        raise TypeError("init_args must be a dictionary")
    return class_path, init_args
301+
302+
225303
if __name__ == "__main__":
226304
import sys
227305

@@ -231,5 +309,5 @@ def load_results_from_buffer(
231309
)
232310
os.makedirs(buffer_dir_concat, exist_ok=True)
233311
preds, labels = load_results_from_buffer(buffer_dir, "cpu")
234-
torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt"))
235-
torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt"))
312+
torch.save(preds, os.path.join(buffer_dir_concat, "preds000.pt"))
313+
torch.save(labels, os.path.join(buffer_dir_concat, "labels000.pt"))

0 commit comments

Comments
 (0)