
Commit 6300bff

Merge pull request #1 from ChEB-AI/feature-ensemble
feature-ensemble
2 parents 7ea61d0 + 5cc2eba commit 6300bff

File tree

12 files changed: +587 -0 lines changed


README.md

Lines changed: 84 additions & 0 deletions
@@ -1,2 +1,86 @@
# python-chebifier

An AI ensemble model for predicting chemical classes.

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/python-chebifier.git
cd python-chebifier

# Install the package
pip install -e .
```
## Usage

### Command Line Interface

The package provides a command-line interface (CLI) for making predictions with an ensemble model.

```bash
# Get help
python -m chebifier.cli --help

# Make predictions using a configuration file
python -m chebifier.cli predict example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" --smiles "C1=CC=C(C=C1)C(=O)O"

# Make predictions using SMILES from a file
python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt
```
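The `predict` command added in this commit also accepts `--ensemble-type` / `-e` (one of `mv`, `wmv-ppvnpv`, `wmv-f1`; the default `mv` is plain majority voting) and `--output` / `-o` to write predictions to a JSON file instead of printing them, for example `python -m chebifier.cli predict example_config.yml -f smiles.txt -e wmv-f1 -o predictions.json` (see `chebifier/cli.py` below).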
### Configuration File

The CLI requires a YAML configuration file that defines the ensemble model. Here's an example:

```yaml
# Example configuration file for Chebifier ensemble model

# Each key in the top-level dictionary is a model name
model1:
  # Required: type of model (must be one of the keys in MODEL_TYPES)
  type: electra
  # Required: name of the model
  model_name: electra_model1
  # Required: path to the checkpoint file
  ckpt_path: /path/to/checkpoint1.ckpt
  # Required: path to the target labels file
  target_labels_path: /path/to/target_labels1.txt
  # Optional: batch size for predictions (default is likely defined in the model)
  batch_size: 32

model2:
  type: electra
  model_name: electra_model2
  ckpt_path: /path/to/checkpoint2.ckpt
  target_labels_path: /path/to/target_labels2.txt
  batch_size: 64
```
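Each model block is passed verbatim to the corresponding predictor class, so the optional keys `model_weight` and `classwise_weights_path` accepted by `BasePredictor` (introduced in this commit) can also be set per model; they only affect the weighted ensembles. A Python sketch using them follows the next section.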
### Python API

You can also use the package programmatically:

```python
import yaml

from chebifier.ensemble.base_ensemble import BaseEnsemble

# Load configuration from YAML file
with open('configs/example_config.yml', 'r') as f:
    config = yaml.safe_load(f)

# Instantiate ensemble model
ensemble = BaseEnsemble(config)

# Make predictions
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
predictions = ensemble.predict_smiles_list(smiles_list)

# Print results
for smiles, prediction in zip(smiles_list, predictions):
    print(f"SMILES: {smiles}")
    if prediction:
        print(f"Predicted classes: {prediction}")
    else:
        print("No predictions")
```
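
The weighted ensembles added in this commit can be used the same way. A minimal sketch, assuming hypothetical paths and the optional per-model keys mentioned above:

```python
from chebifier.ensemble.weighted_majority_ensemble import WMVwithF1Ensemble

# Hypothetical configuration: the same required keys as in the YAML above, plus the
# optional per-model keys read by BasePredictor and used by the weighted ensembles.
config = {
    "model1": {
        "type": "electra",
        "model_name": "electra_model1",
        "ckpt_path": "/path/to/checkpoint1.ckpt",
        "target_labels_path": "/path/to/target_labels1.txt",
        "model_weight": 2,  # hypothetical: overall weight of this model's votes (default 1)
        "classwise_weights_path": "/path/to/classwise_weights1.json",  # per-class metrics
    },
}

ensemble = WMVwithF1Ensemble(config)
predictions = ensemble.predict_smiles_list(["CC(=O)OC1=CC=CC=C1C(=O)O"])
```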

chebifier/__init__.py

Whitespace-only changes.

chebifier/cli.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
import json

import click
import yaml

from chebifier.ensemble.base_ensemble import BaseEnsemble
from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble


@click.group()
def cli():
    """Command line interface for Chebifier."""
    pass


ENSEMBLES = {
    "mv": BaseEnsemble,
    "wmv-ppvnpv": WMVwithPPVNPVEnsemble,
    "wmv-f1": WMVwithF1Ensemble,
}


@cli.command()
@click.argument('config_file', type=click.Path(exists=True))
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
@click.option('--ensemble-type', '-e', type=click.Choice(list(ENSEMBLES)), default='mv', help='Type of ensemble to use (default: Majority Voting)')
def predict(config_file, smiles, smiles_file, output, ensemble_type):
    """Predict ChEBI classes for SMILES strings using an ensemble model.

    CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
    """
    # Load configuration from YAML file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)

    # Instantiate ensemble model
    ensemble = ENSEMBLES[ensemble_type](config)

    # Collect SMILES strings from arguments and/or file
    smiles_list = list(smiles)
    if smiles_file:
        with open(smiles_file, 'r') as f:
            smiles_list.extend([line.strip() for line in f if line.strip()])

    if not smiles_list:
        click.echo("No SMILES strings provided. Use the --smiles or --smiles-file options.")
        return

    # Make predictions
    predictions = ensemble.predict_smiles_list(smiles_list)

    if output:
        # Save as JSON, mapping each input SMILES to its predicted classes
        with open(output, 'w') as f:
            json.dump({smi: pred for smi, pred in zip(smiles_list, predictions)}, f, indent=2)
    else:
        # Print results
        for smi, prediction in zip(smiles_list, predictions):
            click.echo(f"Result for: {smi}")
            if prediction:
                click.echo(f"  Predicted classes: {', '.join(map(str, prediction))}")
            else:
                click.echo("  No predictions")


if __name__ == '__main__':
    cli()
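
With `--output`, the command dumps a single JSON object mapping each input SMILES to its prediction list (see the `json.dump` call above). A minimal sketch of reading it back, with a hypothetical file name:

```python
import json

# Reading back a hypothetical output file written with
#   python -m chebifier.cli predict example_config.yml -s "CCO" -o out.json
with open("out.json") as f:
    results = json.load(f)

for smiles, predicted_classes in results.items():
    print(smiles, predicted_classes)  # the list the ensemble returned for this SMILES
```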

chebifier/ensemble/__init__.py

Whitespace-only changes.

chebifier/ensemble/base_ensemble.py

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
import os
from abc import ABC

import torch
import tqdm

from chebifier.prediction_models.base_predictor import BasePredictor
from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor
from chebifier.prediction_models.electra_predictor import ElectraPredictor
from chebifier.prediction_models.gnn_predictor import ResGatedPredictor

MODEL_TYPES = {
    "electra": ElectraPredictor,
    "resgated": ResGatedPredictor,
    "chemlog": ChemLogPredictor,
}


class BaseEnsemble(ABC):

    def __init__(self, model_configs: dict):
        self.models = []
        self.positive_prediction_threshold = 0.5
        for model_name, model_config in model_configs.items():
            model_cls = MODEL_TYPES[model_config["type"]]
            model_instance = model_cls(**model_config)
            assert isinstance(model_instance, BasePredictor)
            self.models.append(model_instance)

    def gather_predictions(self, smiles_list):
        # Get predictions from all models for the SMILES list and
        # order them alphabetically by class label
        model_predictions = []
        predicted_classes = set()
        for model in self.models:
            model_predictions.append(model.predict_smiles_list(smiles_list))
            for logits_for_smiles in model_predictions[-1]:
                if logits_for_smiles is not None:
                    for cls in logits_for_smiles:
                        predicted_classes.add(cls)
        print("Sorting predictions...")
        predicted_classes = sorted(predicted_classes)
        predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
        ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan
        for i, model_prediction in enumerate(model_predictions):
            for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction),
                                                  total=len(model_prediction),
                                                  desc=f"Sorting predictions for {self.models[i].model_name}"):
                if logits_for_smiles is not None:
                    for cls in logits_for_smiles:
                        ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls]

        return ordered_logits, predicted_classes

    def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
        """
        Aggregates predictions from multiple models using weighted majority voting.
        Optimized version using tensor operations instead of for loops.
        """
        num_smiles, num_classes, num_models = predictions.shape

        # Map class positions back to class names for the final output
        class_names = list(predicted_classes.keys())
        class_indices = {predicted_classes[cls]: cls for cls in class_names}

        # A prediction is valid if the model produced a value for that class
        valid_predictions = ~torch.isnan(predictions)
        valid_counts = valid_predictions.sum(dim=2)  # Sum over the models dimension

        # Skip classes with no valid predictions
        has_valid_predictions = valid_counts > 0

        # Positive and negative votes for all classes at once
        positive_mask = (predictions > 0.5) & valid_predictions
        negative_mask = (predictions < 0.5) & valid_predictions

        confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)

        # Extract positive and negative weights
        pos_weights = classwise_weights[0]  # Shape: (num_classes, num_models)
        neg_weights = classwise_weights[1]  # Shape: (num_classes, num_models)

        # Calculate weighted votes using broadcasting
        # predictions shape: (num_smiles, num_classes, num_models)
        # weights shape: (num_classes, num_models)
        positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0)
        negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0)

        # Sum over the models dimension
        positive_sum = positive_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
        negative_sum = negative_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)

        # Decide which classes to include for each SMILES
        net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
        class_decisions = (net_score > 0) & has_valid_predictions  # Shape: (num_smiles, num_classes)

        # Convert tensor decisions back to lists of class names
        result = [
            [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]]
            for i in range(num_smiles)
        ]

        return result

    def calculate_classwise_weights(self, predicted_classes):
        """No weights, simple majority voting."""
        positive_weights = torch.ones(len(predicted_classes), len(self.models))
        negative_weights = torch.ones(len(predicted_classes), len(self.models))

        return positive_weights, negative_weights

    def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
        preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
        predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
        if not load_preds_if_possible or not os.path.isfile(preds_file):
            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
            # Save predictions and the class index so they can be reused
            torch.save(ordered_predictions, preds_file)
            with open(predicted_classes_file, "w") as f:
                for cls in predicted_classes:
                    f.write(f"{cls}\n")
        else:
            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
            ordered_predictions = torch.load(preds_file)
            with open(predicted_classes_file, "r") as f:
                predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}

        classwise_weights = self.calculate_classwise_weights(predicted_classes)
        aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights)
        return aggregated_predictions
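
To make the aggregation concrete, here is a small standalone sketch of the vote arithmetic used in `consolidate_predictions`, with all weights set to 1 as in `calculate_classwise_weights` (the tensor values are made up):

```python
import torch

# Toy input: 1 SMILES, 2 classes, 3 models; NaN marks "this model made no
# prediction for this class".
predictions = torch.tensor([[[0.9, 0.2, float("nan")],
                             [0.4, 0.6, 0.3]]])  # shape (num_smiles, num_classes, num_models)
threshold = 0.5

valid = ~torch.isnan(predictions)
positive = (predictions > threshold) & valid
negative = (predictions < threshold) & valid
confidence = 2 * torch.abs(predictions.nan_to_num() - threshold)

# Plain majority voting: all positive / negative weights are 1
pos_weights = torch.ones(2, 3)
neg_weights = torch.ones(2, 3)

pos_sum = (positive.float() * confidence * pos_weights).sum(dim=2)
neg_sum = (negative.float() * confidence * neg_weights).sum(dim=2)
decisions = (pos_sum - neg_sum > 0) & (valid.sum(dim=2) > 0)
print(decisions)  # tensor([[ True, False]]) -> only the first class is assigned
```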

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 54 additions & 0 deletions

@@ -0,0 +1,54 @@
import torch

from chebifier.ensemble.base_ensemble import BaseEnsemble


class WMVwithPPVNPVEnsemble(BaseEnsemble):

    def calculate_classwise_weights(self, predicted_classes):
        """
        Given the positions of the predicted classes in the predictions tensor, assign weights to each class.
        The result is two tensors of shape (num_predicted_classes, num_models). The weight for each class is
        the model_weight (default: 1) multiplied by the class-specific positive / negative weight (default: 1).
        """
        positive_weights = torch.ones(len(predicted_classes), len(self.models))
        negative_weights = torch.ones(len(predicted_classes), len(self.models))
        for j, model in enumerate(self.models):
            positive_weights[:, j] *= model.model_weight
            negative_weights[:, j] *= model.model_weight
            if model.classwise_weights is None:
                continue
            for cls, weights in model.classwise_weights.items():
                positive_weights[predicted_classes[cls], j] *= weights["PPV"]
                negative_weights[predicted_classes[cls], j] *= weights["NPV"]

        print("Calculated model weightings. The averages for positive / negative weights are:")
        for i, model in enumerate(self.models):
            print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}")

        return positive_weights, negative_weights


class WMVwithF1Ensemble(BaseEnsemble):

    def calculate_classwise_weights(self, predicted_classes):
        """
        Given the positions of the predicted classes in the predictions tensor, assign weights to each class.
        The result is one tensor of shape (num_predicted_classes, num_models), used for both positive and
        negative votes. The weight for each class is the model_weight (default: 1) multiplied by the
        class-specific validation F1 score (default: 1).
        """
        weights_by_cls = torch.ones(len(predicted_classes), len(self.models))
        for j, model in enumerate(self.models):
            weights_by_cls[:, j] *= model.model_weight
            if model.classwise_weights is None:
                continue
            for cls, weights in model.classwise_weights.items():
                f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
                weights_by_cls[predicted_classes[cls], j] *= f1

        print("Calculated model weightings. The average weights are:")
        for i, model in enumerate(self.models):
            print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")

        return weights_by_cls, weights_by_cls
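
Both classes read per-class metrics from the `classwise_weights` JSON loaded by `BasePredictor`. The exact file format is not part of this commit, but the accesses above imply roughly the following shape (class labels and values here are hypothetical):

```python
# Hypothetical contents of a classwise_weights JSON file, mirrored as a Python dict.
# WMVwithPPVNPVEnsemble reads "PPV" / "NPV"; WMVwithF1Ensemble reads "TP" / "FP" / "FN".
# The top-level keys must match the class labels returned by the predictors.
classwise_weights = {
    "class_A": {"PPV": 0.91, "NPV": 0.98, "TP": 120, "FP": 12, "FN": 9},
    "class_B": {"PPV": 0.84, "NPV": 0.99, "TP": 310, "FP": 60, "FN": 25},
}

# Per-class F1 as computed by WMVwithF1Ensemble:
tp, fp, fn = 120, 12, 9
f1 = 2 * tp / (2 * tp + fp + fn)  # = 240 / 261 ≈ 0.92
```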

chebifier/prediction_models/__init__.py

Whitespace-only changes.

chebifier/prediction_models/base_predictor.py

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
import json
from abc import ABC


class BasePredictor(ABC):

    def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs):
        self.model_name = model_name
        self.model_weight = model_weight
        if classwise_weights_path is not None:
            with open(classwise_weights_path, encoding="utf-8") as f:
                self.classwise_weights = json.load(f)
        else:
            self.classwise_weights = None

    def predict_smiles_list(self, smiles_list: list[str]) -> list:
        # Expected to return one entry per SMILES: a dict mapping class labels to
        # scores, or None if the model cannot handle that SMILES.
        raise NotImplementedError
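
As a hypothetical illustration of the interface the ensemble expects (not part of this commit), a predictor subclass only has to return one dict of class scores, or None, per input SMILES:

```python
from chebifier.prediction_models.base_predictor import BasePredictor


class ConstantPredictor(BasePredictor):
    """Toy predictor that assigns the same scores to every SMILES (illustration only)."""

    def predict_smiles_list(self, smiles_list: list[str]) -> list:
        # One dict per SMILES, mapping class labels to scores in [0, 1];
        # a real predictor would run its model here.
        return [{"class_A": 0.9, "class_B": 0.1} for _ in smiles_list]


predictor = ConstantPredictor(model_name="constant_baseline")
print(predictor.predict_smiles_list(["CCO"]))  # [{'class_A': 0.9, 'class_B': 0.1}]
```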

0 commit comments
