84 changes: 84 additions & 0 deletions README.md
@@ -1,2 +1,86 @@
# python-chebifier
An AI ensemble model for predicting chemical classes.

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/python-chebifier.git
cd python-chebifier

# Install the package
pip install -e .
```
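
Installing with `pip install -e .` should also pull in the dependencies used by the package — notably `click`, `PyYAML`, `torch`, `tqdm` and `rdkit` (visible from the imports under `chebifier/`) — provided they are declared in the project's packaging metadata.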

## Usage

### Command Line Interface

The package provides a command-line interface (CLI) for making predictions using an ensemble model.

```bash
# Get help
python -m chebifier.cli --help

# Make predictions using a configuration file
python -m chebifier.cli predict example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"

# Make predictions using SMILES from a file
python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt
```
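
The `--smiles-file` option expects a plain-text file with one SMILES string per line. The `predict` command also accepts `--ensemble-type` / `-e` to choose the voting scheme (`mv` for plain majority voting, the default, or `wmv-ppvnpv` / `wmv-f1` for the weighted variants) and `--output` / `-o` to write the predictions to a JSON file instead of printing them, for example:

```bash
# Weighted majority voting with F1-based weights, results written to JSON
python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt -e wmv-f1 -o predictions.json
```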

### Configuration File

The CLI requires a YAML configuration file that defines the ensemble model. Here's an example:

```yaml
# Example configuration file for Chebifier ensemble model

# Each key in the top-level dictionary is a model name
model1:
  # Required: type of model (must be one of the keys in MODEL_TYPES)
  type: electra
  # Required: name of the model
  model_name: electra_model1
  # Required: path to the checkpoint file
  ckpt_path: /path/to/checkpoint1.ckpt
  # Required: path to the target labels file
  target_labels_path: /path/to/target_labels1.txt
  # Optional: batch size for predictions (default is likely defined in the model)
  batch_size: 32

model2:
  type: electra
  model_name: electra_model2
  ckpt_path: /path/to/checkpoint2.ckpt
  target_labels_path: /path/to/target_labels2.txt
  batch_size: 64
```
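
The `type` field must be one of the model types registered in `chebifier/ensemble/base_ensemble.py` (`electra`, `resgated` or `chemlog`). Each model entry may additionally set `model_weight` and `classwise_weights_path` (a JSON file with per-class metrics), which the weighted ensemble variants use to scale each model's votes (see the Python API section below).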

### Python API

You can also use the package programmatically:

```python
from chebifier.ensemble.base_ensemble import BaseEnsemble
import yaml

# Load configuration from YAML file
with open('configs/example_config.yml', 'r') as f:
    config = yaml.safe_load(f)

# Instantiate ensemble model
ensemble = BaseEnsemble(config)

# Make predictions
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
predictions = ensemble.predict_smiles_list(smiles_list)

# Print results
for smile, prediction in zip(smiles_list, predictions):
    print(f"SMILES: {smile}")
    if prediction:
        print(f"Predicted classes: {prediction}")
    else:
        print("No predictions")
```
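
Besides plain majority voting (`BaseEnsemble`), the package provides two weighted ensembles, `WMVwithPPVNPVEnsemble` and `WMVwithF1Ensemble`, which scale each model's vote per class using the metrics loaded from its `classwise_weights_path` JSON file (`"PPV"`/`"NPV"` entries for the former, `"TP"`/`"FP"`/`"FN"` for the latter). A minimal sketch, assuming the same configuration file as above:

```python
from chebifier.ensemble.weighted_majority_ensemble import WMVwithF1Ensemble
import yaml

with open('configs/example_config.yml', 'r') as f:
    config = yaml.safe_load(f)

# Each model's vote is weighted by its per-class validation F1 score
ensemble = WMVwithF1Ensemble(config)
predictions = ensemble.predict_smiles_list(["CC(=O)OC1=CC=CC=C1C(=O)O"])
print(predictions)
```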
Empty file added chebifier/__init__.py
Empty file.
70 changes: 70 additions & 0 deletions chebifier/cli.py
@@ -0,0 +1,70 @@
import json

import click
import yaml

from chebifier.ensemble.base_ensemble import BaseEnsemble
from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble


@click.group()
def cli():
    """Command line interface for Chebifier."""
    pass


ENSEMBLES = {
    "mv": BaseEnsemble,
    "wmv-ppvnpv": WMVwithPPVNPVEnsemble,
    "wmv-f1": WMVwithF1Ensemble
}


@cli.command()
@click.argument('config_file', type=click.Path(exists=True))
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
def predict(config_file, smiles, smiles_file, output, ensemble_type):
    """Predict ChEBI classes for SMILES strings using an ensemble model.

    CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
    """
    # Load configuration from YAML file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)

    # Instantiate ensemble model
    ensemble = ENSEMBLES[ensemble_type](config)

    # Collect SMILES strings from arguments and/or file
    smiles_list = list(smiles)
    if smiles_file:
        with open(smiles_file, 'r') as f:
            smiles_list.extend([line.strip() for line in f if line.strip()])

    if not smiles_list:
        click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
        return

    # Make predictions
    predictions = ensemble.predict_smiles_list(smiles_list)

    if output:
        # Save predictions as JSON, keyed by SMILES string
        with open(output, 'w') as f:
            json.dump(dict(zip(smiles_list, predictions)), f, indent=2)
    else:
        # Print results
        for smi, prediction in zip(smiles_list, predictions):
            click.echo(f"Result for: {smi}")
            if prediction:
                click.echo(f"  Predicted classes: {', '.join(map(str, prediction))}")
            else:
                click.echo("  No predictions")


if __name__ == '__main__':
    cli()
Empty file added chebifier/ensemble/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions chebifier/ensemble/base_ensemble.py
@@ -0,0 +1,131 @@
import os
from abc import ABC
import torch
import tqdm
from rdkit import Chem

from chebifier.prediction_models.base_predictor import BasePredictor
from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor
from chebifier.prediction_models.electra_predictor import ElectraPredictor
from chebifier.prediction_models.gnn_predictor import ResGatedPredictor

MODEL_TYPES = {
    "electra": ElectraPredictor,
    "resgated": ResGatedPredictor,
    "chemlog": ChemLogPredictor
}


class BaseEnsemble(ABC):

    def __init__(self, model_configs: dict):
        self.models = []
        self.positive_prediction_threshold = 0.5
        for model_name, model_config in model_configs.items():
            model_cls = MODEL_TYPES[model_config["type"]]
            model_instance = model_cls(**model_config)
            assert isinstance(model_instance, BasePredictor)
            self.models.append(model_instance)

    def gather_predictions(self, smiles_list):
        # Get predictions from all models for the SMILES list and
        # order them alphabetically by label class
        model_predictions = []
        predicted_classes = set()
        for model in self.models:
            model_predictions.append(model.predict_smiles_list(smiles_list))
            for logits_for_smiles in model_predictions[-1]:
                if logits_for_smiles is not None:
                    for cls in logits_for_smiles:
                        predicted_classes.add(cls)
        print("Sorting predictions...")
        predicted_classes = sorted(list(predicted_classes))
        predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
        # NaN marks "no prediction from this model for this class"
        ordered_logits = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan
        for i, model_prediction in enumerate(model_predictions):
            for j, logits_for_smiles in tqdm.tqdm(enumerate(model_prediction),
                                                  total=len(model_prediction),
                                                  desc=f"Sorting predictions for {self.models[i].model_name}"):
                if logits_for_smiles is not None:
                    for cls in logits_for_smiles:
                        ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls]

        return ordered_logits, predicted_classes

    def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
        """
        Aggregates predictions from multiple models using weighted majority voting.
        Optimized version using tensor operations instead of for loops.
        """
        num_smiles, num_classes, num_models = predictions.shape

        # Create a mapping from class indices to class names for faster lookup
        class_names = list(predicted_classes.keys())
        class_indices = {predicted_classes[cls]: cls for cls in class_names}

        # Get predictions for all classes
        valid_predictions = ~torch.isnan(predictions)
        valid_counts = valid_predictions.sum(dim=2)  # Sum over models dimension

        # Skip classes with no valid predictions
        has_valid_predictions = valid_counts > 0

        # Calculate positive and negative predictions for all classes at once
        positive_mask = (predictions > 0.5) & valid_predictions
        negative_mask = (predictions < 0.5) & valid_predictions

        confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)

        # Extract positive and negative weights
        pos_weights = classwise_weights[0]  # Shape: (num_classes, num_models)
        neg_weights = classwise_weights[1]  # Shape: (num_classes, num_models)

        # Calculate weighted predictions using broadcasting
        # predictions shape: (num_smiles, num_classes, num_models)
        # weights shape: (num_classes, num_models)
        positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0)
        negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0)

        # Sum over models dimension
        positive_sum = positive_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
        negative_sum = negative_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)

        # Determine which classes to include for each SMILES
        net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
        class_decisions = (net_score > 0) & has_valid_predictions  # Shape: (num_smiles, num_classes)

        # Convert tensor decisions to result list using list comprehension for efficiency
        result = [
            [class_indices[idx.item()] for idx in torch.nonzero(class_decisions[i], as_tuple=True)[0]]
            for i in range(num_smiles)
        ]

        return result

    def calculate_classwise_weights(self, predicted_classes):
        """No weights, simple majority voting"""
        positive_weights = torch.ones(len(predicted_classes), len(self.models))
        negative_weights = torch.ones(len(predicted_classes), len(self.models))

        return positive_weights, negative_weights

    def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
        preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
        predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
        if not load_preds_if_possible or not os.path.isfile(preds_file):
            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
            # Save predictions
            torch.save(ordered_predictions, preds_file)
            with open(predicted_classes_file, "w") as f:
                for cls in predicted_classes:
                    f.write(f"{cls}\n")
        else:
            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
            ordered_predictions = torch.load(preds_file)
            with open(predicted_classes_file, "r") as f:
                predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}

        classwise_weights = self.calculate_classwise_weights(predicted_classes)
        aggregated_predictions = self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights)
        return aggregated_predictions
54 changes: 54 additions & 0 deletions chebifier/ensemble/weighted_majority_ensemble.py
@@ -0,0 +1,54 @@
import torch

from chebifier.ensemble.base_ensemble import BaseEnsemble


class WMVwithPPVNPVEnsemble(BaseEnsemble):

    def calculate_classwise_weights(self, predicted_classes):
        """
        Given the positions of predicted classes in the predictions tensor, assign weights to each class.
        The result is two tensors of shape (num_predicted_classes, num_models). The weight for each class
        is the model_weight (default: 1) multiplied by the class-specific positive / negative weight (default: 1).
        """
        positive_weights = torch.ones(len(predicted_classes), len(self.models))
        negative_weights = torch.ones(len(predicted_classes), len(self.models))
        for j, model in enumerate(self.models):
            positive_weights[:, j] *= model.model_weight
            negative_weights[:, j] *= model.model_weight
            if model.classwise_weights is None:
                continue
            for cls, weights in model.classwise_weights.items():
                positive_weights[predicted_classes[cls], j] *= weights["PPV"]
                negative_weights[predicted_classes[cls], j] *= weights["NPV"]

        print("Calculated model weightings. The averages for positive / negative weights are:")
        for i, model in enumerate(self.models):
            print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}")

        return positive_weights, negative_weights


class WMVwithF1Ensemble(BaseEnsemble):

    def calculate_classwise_weights(self, predicted_classes):
        """
        Given the positions of predicted classes in the predictions tensor, assign weights to each class.
        The result is two tensors of shape (num_predicted_classes, num_models). The weight for each class
        is the model_weight (default: 1) multiplied by the class-specific validation F1 score (default: 1).
        """
        weights_by_cls = torch.ones(len(predicted_classes), len(self.models))
        for j, model in enumerate(self.models):
            weights_by_cls[:, j] *= model.model_weight
            if model.classwise_weights is None:
                continue
            for cls, weights in model.classwise_weights.items():
                f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
                weights_by_cls[predicted_classes[cls], j] *= f1

        print("Calculated model weightings. The average weights are:")
        for i, model in enumerate(self.models):
            print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")

        return weights_by_cls, weights_by_cls
Empty file added chebifier/prediction_models/__init__.py
Empty file.
16 changes: 16 additions & 0 deletions chebifier/prediction_models/base_predictor.py
@@ -0,0 +1,16 @@
from abc import ABC
import json


class BasePredictor(ABC):

    def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs):
        self.model_name = model_name
        self.model_weight = model_weight
        if classwise_weights_path is not None:
            with open(classwise_weights_path, encoding="utf-8") as f:
                self.classwise_weights = json.load(f)
        else:
            self.classwise_weights = None

    def predict_smiles_list(self, smiles_list: list[str]) -> list:
        # Returns one entry per SMILES string: a dict mapping class labels to
        # prediction scores, or None if the model cannot process the input.
        raise NotImplementedError