Commit f79c959

Merge pull request #19 from ChEB-AI/feature/ensemble-update
Ensemble update
2 parents 6aed7a1 + 6b1b44b commit f79c959

11 files changed

+344
-121
lines changed

README.md

Lines changed: 21 additions & 16 deletions
@@ -2,7 +2,7 @@
22
An AI ensemble model for predicting chemical classes in the ChEBI ontology. It integrates deep learning models,
33
rule-based models and generative AI-based models.
44

5-
A web application for the ensemble is available at https://chebifier.hastingslab.org/.
5+
A web application for Chebifier is available at https://chebifier.hastingslab.org/.
66

77
## Installation
88

@@ -38,23 +38,27 @@ The package provides a command-line interface (CLI) for making predictions using
3838
The ensemble configuration is given by a configuration file (by default, this is `chebifier/ensemble.yml`). If you
3939
want to change which models are included in the ensemble or how they are weighted, you can create your own configuration file.
4040

41-
Model weights for deep learning models are automatically downloaded from [Hugging Face](https://huggingface.co/chebai).
42-
To use specific model weights from Hugging face, add the `load_model` key in your configuration file. For example:
41+
Trained deep learning models are automatically downloaded from [Hugging Face](https://huggingface.co/chebai).
42+
To access a model from Hugging Face, add the `load_model` key to your configuration file. For example:
4343

4444
```yaml
4545
my_electra:
4646
type: electra
47-
load_model: "electra_chebi50_v241"
47+
load_model: "electra_chebi50-3star_v244"
4848
```
4949
5050
### Available model weights:
5151
52+
* `resgated-aug_chebi50-3star_v244`
53+
* `gat-aug_chebi50_v244`
54+
* `electra_chebi50-3star_v244`
55+
* `gat_chebi50_v244`
5256
* `electra_chebi50_v241`
5357
* `resgated_chebi50_v241`
5458
* `c3p_with_weights`
5559

5660

57-
However, you can also supply your own model checkpoints (see `configs/example_config.yml` for an example).
61+
You can also supply your own model checkpoints (see `configs/example_config.yml` for an example).
5862

5963
```bash
6064
# Make predictions
@@ -72,12 +76,12 @@ python -m chebifier predict --help
7276

7377
### Python API
7478

75-
You can also use the package programmatically:
79+
You can use the package programmatically as well:
7680

7781
```python
7882
from chebifier import BaseEnsemble
7983
80-
# Instantiate ensemble model. If desired, can pass
84+
# Instantiate ensemble model. Optionally, you can pass
8185
# a path to a configuration, like 'configs/example_config.yml'
8286
ensemble = BaseEnsemble()
8387
@@ -100,11 +104,12 @@ Currently, the following models are supported:
100104

101105
| Model | Description | #Classes | Publication | Repository |
102106
|-------|-------------|----------|-----------------------------------------------------------------------|----------------------------------------------------------------------------------------|
103-
| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings. | 1522 | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) |
104-
| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1522 | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
107+
| `electra` | A transformer-based deep learning model trained on ChEBI SMILES strings. | 1531* | [Glauer, Martin, et al., 2024: Chebifier: Automating semantic classification in ChEBI to accelerate data-driven discovery, Digital Discovery 3 (2024) 896-907](https://pubs.rsc.org/en/content/articlehtml/2024/dd/d3dd00238a) | [python-chebai](https://github.com/ChEB-AI/python-chebai) |
108+
| `resgated` | A Residual Gated Graph Convolutional Network trained on ChEBI molecules. | 1531* | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
109+
| `gat` | A Graph Attention Network trained on ChEBI molecules. | 1531* | | [python-chebai-graph](https://github.com/ChEB-AI/python-chebai-graph) |
105110
| `chemlog_peptides` | A rule-based model specialised on peptide classes. | 18 | [Flügel, Simon, et al., 2025: ChemLog: Making MSOL Viable for Ontological Classification and Learning, arXiv](https://arxiv.org/abs/2507.13987) | [chemlog-peptides](https://github.com/sfluegel05/chemlog-peptides) |
106111
| `chemlog_element`, `chemlog_organox` | Extensions of ChemLog for classes that are defined either by the presence of a specific element or by the presence of an organic bond. | 118 + 37 | | [chemlog-extra](https://github.com/ChEB-AI/chemlog-extra) |
107-
| `c3p` | A collection _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, arXiv](https://arxiv.org/abs/2505.18470) | [c3p](https://github.com/chemkg/c3p) |
112+
| `c3p` | A collection of _Chemical Classifier Programs_, generated by LLMs based on the natural language definitions of ChEBI classes. | 338 | [Mungall, Christopher J., et al., 2025: Chemical classification program synthesis using generative artificial intelligence, Journal of Cheminformatics](https://link.springer.com/article/10.1186/s13321-025-01092-3) | [c3p](https://github.com/chemkg/c3p) |
108113

109114
In addition, Chebifier also includes a ChEBI lookup that automatically retrieves the ChEBI superclasses for a class
110115
matched by a SMILES string. This is not activated by default, but can be included by adding
@@ -116,6 +121,8 @@ chebi_lookup:
116121
to your configuration file.
117122

118123
### The ensemble
124+
For an extended description of the ensemble, see [Flügel, Simon, et al., 2025: Chebifier 2: An Ensemble for Chemistry](https://ceur-ws.org/Vol-4064/SymGenAI4Sci-paper4.pdf).
125+
119126
<img width="700" alt="ensemble_architecture" src="https://github.com/user-attachments/assets/9275d3cd-ac88-466f-a1e9-27d20d67543b" />
120127

121128
Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows:
@@ -146,20 +153,18 @@ Therefore, if in doubt, we are more confident in the negative prediction.
146153

147154
Confidence can be disabled by the `use_confidence` parameter of the predict method (default: True).
148155

149-
The model_weight can be set for each model in the configuration file (default: 1). This is used to favor a certain
156+
The `model_weight` can be set for each model in the configuration file (default: 1). This is used to favor a certain
150157
model independently of a given class.
151-
Trust is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models
152-
on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as 1 + the F1 score.
158+
`Trust` is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models
159+
on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as $\text{F1}^{6.25}$ (the F1 score raised to the power of 6.25).
153160
If the `ensemble_type` is set to `mv` (the default), the trust is set to 1 for all models.
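
As a rough illustration of the two trust settings described above, the following sketch computes a per-class trust value from a validation F1 score. The function name `trust_from_f1` and its signature are hypothetical and not part of the chebifier API; only the exponent 6.25 and the `mv` / `wmv-f1` names come from the README text.

```python
def trust_from_f1(f1_score: float, ensemble_type: str) -> float:
    """Return an illustrative trust value for one model on one class.

    Hypothetical helper for illustration; not part of chebifier.
    """
    if not 0.0 <= f1_score <= 1.0:
        raise ValueError("F1 score must lie in [0, 1]")
    if ensemble_type == "wmv-f1":
        # Weighted majority voting: trust is the F1 score raised to 6.25.
        return f1_score ** 6.25
    if ensemble_type == "mv":
        # Plain majority voting: every model gets the same trust.
        return 1.0
    raise ValueError(f"unknown ensemble type: {ensemble_type}")
```

Because the exponent is larger than 1, the trust drops quickly as the F1 score decreases: an F1 score of 0.9 gives a trust of roughly 0.52, while an F1 score of 0.5 gives about 0.013.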
154161

155162
### Inconsistency resolution
156163
After a decision has been made for each class independently, the consistency of the predictions with regard to the ChEBI hierarchy
157164
and disjointness axioms is checked. This is
158165
done in 3 steps:
159166
- (1) First, the hierarchy is corrected. For each pair of classes $A$ and $B$ where $A$ is a subclass of $B$ (following
160-
the is-a relation in ChEBI), we set the ensemble prediction of $B$ to 1 if the prediction of $A$ is 1. Intuitively
161-
speaking, if we have determined that a molecule belongs to a specific class (e.g., aromatic primary alcohol), it also
162-
belongs to the direct and indirect superclasses (e.g., primary alcohol, aromatic alcohol, alcohol).
167+
the is-a relation in ChEBI), we set the ensemble prediction of $A$ to $0$ if the _absolute value_ of $B$'s score is larger than that of $A$. For example, if $A$ has a net score of $3$ and $B$ has a net score of $-4$, the ensemble will set $A$ to $0$ (i.e., predict neither $A$ nor $B$).
163168
- (2) Next, we check for disjointness. This is not specified directly in ChEBI, but in an additional ChEBI module ([chebi-disjoints.owl](https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/)).
164169
We have extracted these disjointness axioms into a CSV file and added some more disjointness axioms ourselves (see
165170
`data>disjoint_chebi.csv` and `data>disjoint_additional.csv`). If two classes $A$ and $B$ are disjoint and we predict
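
The hierarchy step (1) of the updated README can be sketched as follows. This is a simplified illustration, not the `ScoreBasedPredictionSmoother` implementation: the function name, the dictionary-based data layout, and the condition that only a positive subclass score paired with a non-positive superclass score counts as an inconsistency are assumptions. The excerpt does not say what happens when the subclass score has the larger absolute value, so the sketch leaves those scores untouched.

```python
def resolve_hierarchy(net_scores, subclass_pairs):
    """Zero out a positive subclass score that is outweighed by a
    negative superclass score.

    net_scores: dict mapping class name -> net score
    subclass_pairs: iterable of (subclass, superclass) tuples
    """
    resolved = dict(net_scores)  # work on a copy, leave the input untouched
    for sub, sup in subclass_pairs:
        if sub not in net_scores or sup not in net_scores:
            continue
        # Assumption: an inconsistency exists only if the subclass is
        # predicted (score > 0) while its superclass is not (score <= 0).
        inconsistent = net_scores[sub] > 0 and net_scores[sup] <= 0
        if inconsistent and abs(net_scores[sup]) > abs(net_scores[sub]):
            resolved[sub] = 0
    return resolved


# The README example: A has a net score of 3, B (its superclass) has -4.
scores = {"A": 3, "B": -4}
print(resolve_hierarchy(scores, [("A", "B")]))  # {'A': 0, 'B': -4}
```

With the scores swapped in magnitude (for instance, 5 for $A$ and $-4$ for $B$), this sketch returns the input unchanged, which reflects the gap in the excerpt rather than a statement about the actual behaviour.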

chebifier/cli.py

Lines changed: 11 additions & 12 deletions
@@ -37,13 +37,6 @@ def cli():
3737
default="wmv-f1",
3838
help="Type of ensemble to use (default: Weighted Majority Voting)",
3939
)
40-
@click.option(
41-
"--chebi-version",
42-
"-v",
43-
type=int,
44-
default=241,
45-
help="ChEBI version to use for checking consistency (default: 241)",
46-
)
4740
@click.option(
4841
"--use-confidence",
4942
"-c",
@@ -58,23 +51,31 @@ def cli():
5851
default=True,
5952
help="Resolve inconsistencies in predictions automatically (default: True)",
6053
)
54+
@click.option(
55+
"--verbose",
56+
"-v",
57+
is_flag=True,
58+
default=False,
59+
help="Enable verbose output",
60+
)
6161
def predict(
6262
ensemble_config,
6363
smiles,
6464
smiles_file,
6565
output,
6666
ensemble_type,
67-
chebi_version,
6867
use_confidence,
6968
resolve_inconsistencies=True,
69+
verbose=False,
7070
):
7171
"""Predict ChEBI classes for SMILES strings using an ensemble model."""
7272

7373
# Instantiate ensemble model
7474
ensemble = ENSEMBLES[ensemble_type](
7575
ensemble_config,
76-
chebi_version=chebi_version,
7776
resolve_inconsistencies=resolve_inconsistencies,
77+
verbose_output=verbose,
78+
use_confidence=use_confidence,
7879
)
7980

8081
# Collect SMILES strings from arguments and/or file
@@ -88,9 +89,7 @@ def predict(
8889
return
8990

9091
# Make predictions
91-
predictions = ensemble.predict_smiles_list(
92-
smiles_list, use_confidence=use_confidence
93-
)
92+
predictions = ensemble.predict_smiles_list(smiles_list)
9493

9594
if output:
9695
# save as json

chebifier/ensemble.yml

Lines changed: 10 additions & 12 deletions
@@ -1,15 +1,13 @@
1-
electra:
2-
load_model: electra_chebi50_v241
3-
resgated:
4-
load_model: resgated_chebi50_v241
5-
chemlog_peptides:
6-
type: chemlog_peptides
7-
model_weight: 100
8-
chemlog_element:
9-
type: chemlog_element
10-
model_weight: 100
11-
chemlog_organox:
12-
type: chemlog_organox
1+
electra_chebi50-3star_v244:
2+
load_model: electra_chebi50-3star_v244
3+
gat_chebi50_v244:
4+
load_model: gat_chebi50_v244
5+
gat-aug_chebi50_v244:
6+
load_model: gat-aug_chebi50_v244
7+
resgated-aug_chebi50-3star_v244:
8+
load_model: resgated-aug_chebi50-3star_v244
9+
chemlog:
10+
type: chemlog
1311
model_weight: 100
1412
c3p:
1513
load_model: c3p_with_weights

chebifier/ensemble/base_ensemble.py

Lines changed: 60 additions & 38 deletions
@@ -1,5 +1,4 @@
11
import importlib
2-
import os
32
import time
43
from pathlib import Path
54
from typing import Union
@@ -10,7 +9,7 @@
109

1110
from chebifier.check_env import check_package_installed
1211
from chebifier.hugging_face import download_model_files
13-
from chebifier.inconsistency_resolution import PredictionSmoother
12+
from chebifier.inconsistency_resolution import ScoreBasedPredictionSmoother
1413
from chebifier.prediction_models.base_predictor import BasePredictor
1514
from chebifier.utils import (
1615
get_default_configs,
@@ -24,8 +23,9 @@ class BaseEnsemble:
2423
def __init__(
2524
self,
2625
model_configs: Union[str, Path, dict, None] = None,
27-
chebi_version: int = 241,
2826
resolve_inconsistencies: bool = True,
27+
verbose_output: bool = False,
28+
use_confidence: bool = True,
2929
):
3030
# Deferred Import: To avoid circular import error
3131
from chebifier.model_registry import MODEL_TYPES
@@ -48,6 +48,8 @@ def __init__(
4848
model_registry = yaml.safe_load(f)
4949

5050
processed_configs = process_config(config, model_registry)
51+
self.verbose_output = verbose_output
52+
self.use_confidence = use_confidence
5153

5254
self.chebi_graph = load_chebi_graph()
5355
self.disjoint_files = get_disjoint_files()
@@ -73,10 +75,11 @@ def __init__(
7375
self.models.append(model_instance)
7476

7577
if resolve_inconsistencies:
76-
self.smoother = PredictionSmoother(
78+
self.smoother = ScoreBasedPredictionSmoother(
7779
self.chebi_graph,
7880
label_names=None,
7981
disjoint_files=self.disjoint_files,
82+
verbose=self.verbose_output,
8083
)
8184
else:
8285
self.smoother = None
@@ -92,7 +95,8 @@ def gather_predictions(self, smiles_list):
9295
if logits_for_smiles is not None:
9396
for cls in logits_for_smiles:
9497
predicted_classes.add(cls)
95-
print(f"Sorting predictions from {len(model_predictions)} models...")
98+
if self.verbose_output:
99+
print(f"Sorting predictions from {len(model_predictions)} models...")
96100
predicted_classes = sorted(list(predicted_classes))
97101
predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
98102
ordered_logits = (
@@ -114,7 +118,11 @@ def gather_predictions(self, smiles_list):
114118
return ordered_logits, predicted_classes
115119

116120
def consolidate_predictions(
117-
self, predictions, classwise_weights, predicted_classes, **kwargs
121+
self,
122+
predictions,
123+
classwise_weights,
124+
return_intermediate_results=False,
125+
**kwargs,
118126
):
119127
"""
120128
Aggregates predictions from multiple models using weighted majority voting.
@@ -137,7 +145,9 @@ def consolidate_predictions(
137145
predictions < self.positive_prediction_threshold
138146
) & valid_predictions
139147

140-
if "use_confidence" in kwargs and kwargs["use_confidence"]:
148+
# if use_confidence is passed in kwargs, it overrides the ensemble setting
149+
use_confidence = kwargs.get("use_confidence", self.use_confidence)
150+
if use_confidence:
141151
confidence = 2 * torch.abs(
142152
predictions.nan_to_num() - self.positive_prediction_threshold
143153
)
@@ -164,22 +174,39 @@ def consolidate_predictions(
164174

165175
# Determine which classes to include for each SMILES
166176
net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
177+
if return_intermediate_results:
178+
return (
179+
net_score,
180+
has_valid_predictions,
181+
{
182+
"positive_mask": positive_mask,
183+
"negative_mask": negative_mask,
184+
"confidence": confidence,
185+
"positive_sum": positive_sum,
186+
"negative_sum": negative_sum,
187+
},
188+
)
189+
190+
return net_score, has_valid_predictions
167191

192+
def apply_inconsistency_resolution(
193+
self, net_score, class_names, has_valid_predictions
194+
):
168195
# Smooth predictions
169196
start_time = time.perf_counter()
170-
class_names = list(predicted_classes.keys())
171197
if self.smoother is not None:
172198
self.smoother.set_label_names(class_names)
173199
smooth_net_score = self.smoother(net_score)
174200
class_decisions = (
175-
smooth_net_score > 0.5
201+
smooth_net_score > 0
176202
) & has_valid_predictions # Shape: (num_smiles, num_classes)
177203
else:
178204
class_decisions = (
179205
net_score > 0
180206
) & has_valid_predictions # Shape: (num_smiles, num_classes)
181207
end_time = time.perf_counter()
182-
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
208+
if self.verbose_output:
209+
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
183210

184211
complete_failure = torch.all(~has_valid_predictions, dim=1)
185212
return class_decisions, complete_failure
@@ -192,38 +219,28 @@ def calculate_classwise_weights(self, predicted_classes):
192219
return positive_weights, negative_weights
193220

194221
def predict_smiles_list(
195-
self, smiles_list, load_preds_if_possible=False, **kwargs
222+
self, smiles_list, return_intermediate_results=False, **kwargs
196223
) -> list:
197-
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
198-
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
199-
if not load_preds_if_possible or not os.path.isfile(preds_file):
200-
ordered_predictions, predicted_classes = self.gather_predictions(
201-
smiles_list
202-
)
203-
if len(predicted_classes) == 0:
204-
print(
205-
"Warning: No classes have been predicted for the given SMILES list."
224+
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
225+
if len(predicted_classes) == 0:
226+
print("Warning: No classes have been predicted for the given SMILES list.")
227+
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
228+
229+
classwise_weights = self.calculate_classwise_weights(predicted_classes)
230+
if return_intermediate_results:
231+
net_score, has_valid_predictions, intermediate_results_dict = (
232+
self.consolidate_predictions(
233+
ordered_predictions,
234+
classwise_weights,
235+
return_intermediate_results=return_intermediate_results,
206236
)
207-
# save predictions
208-
if load_preds_if_possible:
209-
torch.save(ordered_predictions, preds_file)
210-
with open(predicted_classes_file, "w") as f:
211-
for cls in predicted_classes:
212-
f.write(f"{cls}\n")
213-
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
237+
)
214238
else:
215-
print(
216-
f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
239+
net_score, has_valid_predictions = self.consolidate_predictions(
240+
ordered_predictions, classwise_weights
217241
)
218-
ordered_predictions = torch.load(preds_file)
219-
with open(predicted_classes_file, "r") as f:
220-
predicted_classes = {
221-
line.strip(): i for i, line in enumerate(f.readlines())
222-
}
223-
224-
classwise_weights = self.calculate_classwise_weights(predicted_classes)
225-
class_decisions, is_failure = self.consolidate_predictions(
226-
ordered_predictions, classwise_weights, predicted_classes, **kwargs
242+
class_decisions, is_failure = self.apply_inconsistency_resolution(
243+
net_score, list(predicted_classes.keys()), has_valid_predictions
227244
)
228245

229246
class_names = list(predicted_classes.keys())
@@ -239,6 +256,11 @@ def predict_smiles_list(
239256
)
240257
for i, failure in zip(class_decisions, is_failure)
241258
]
259+
if return_intermediate_results:
260+
intermediate_results_dict["predicted_classes"] = predicted_classes
261+
intermediate_results_dict["classwise_weights"] = classwise_weights
262+
intermediate_results_dict["net_score"] = net_score
263+
return result, intermediate_results_dict
242264

243265
return result
244266
