Skip to content

Commit a48b659

Browse files
committed
update model registry and example config files
1 parent 71d2507 commit a48b659

File tree

4 files changed

+112
-27
lines changed

4 files changed

+112
-27
lines changed

chebifier/cli.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,6 @@ def cli():
3737
default="wmv-f1",
3838
help="Type of ensemble to use (default: Weighted Majority Voting)",
3939
)
40-
@click.option(
41-
"--chebi-version",
42-
"-v",
43-
type=int,
44-
default=241,
45-
help="ChEBI version to use for checking consistency (default: 241)",
46-
)
4740
@click.option(
4841
"--use-confidence",
4942
"-c",
@@ -58,23 +51,31 @@ def cli():
5851
default=True,
5952
help="Resolve inconsistencies in predictions automatically (default: True)",
6053
)
54+
@click.option(
55+
"--verbose",
56+
"-v",
57+
is_flag=True,
58+
default=False,
59+
help="Enable verbose output",
60+
)
6161
def predict(
6262
ensemble_config,
6363
smiles,
6464
smiles_file,
6565
output,
6666
ensemble_type,
67-
chebi_version,
6867
use_confidence,
6968
resolve_inconsistencies=True,
69+
verbose=False,
7070
):
7171
"""Predict ChEBI classes for SMILES strings using an ensemble model."""
7272

7373
# Instantiate ensemble model
7474
ensemble = ENSEMBLES[ensemble_type](
7575
ensemble_config,
76-
chebi_version=chebi_version,
7776
resolve_inconsistencies=resolve_inconsistencies,
77+
verbose_output=verbose,
78+
use_confidence=use_confidence,
7879
)
7980

8081
# Collect SMILES strings from arguments and/or file
@@ -88,9 +89,7 @@ def predict(
8889
return
8990

9091
# Make predictions
91-
predictions = ensemble.predict_smiles_list(
92-
smiles_list, use_confidence=use_confidence
93-
)
92+
predictions = ensemble.predict_smiles_list(smiles_list)
9493

9594
if output:
9695
# save as json

chebifier/ensemble.yml

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
electra:
2-
load_model: electra_chebi50_v241
3-
resgated:
4-
load_model: resgated_chebi50_v241
5-
chemlog_peptides:
6-
type: chemlog_peptides
7-
model_weight: 100
8-
chemlog_element:
9-
type: chemlog_element
10-
model_weight: 100
11-
chemlog_organox:
12-
type: chemlog_organox
1+
electra_chebi50-3star_v244:
2+
load_model: electra_chebi50-3star_v244
3+
gat_chebi50_v244:
4+
load_model: gat_chebi50_v244
5+
gat-aug_chebi50_v244:
6+
load_model: gat-aug_chebi50_v244
7+
resgated-aug_chebi50-3star_v244:
8+
load_model: resgated-aug_chebi50-3star_v244
9+
chemlog:
10+
type: chemlog
1311
model_weight: 100
1412
c3p:
1513
load_model: c3p_with_weights

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
class WMVwithPPVNPVEnsemble(BaseEnsemble):
77

88
def __init__(
9-
self, config_path=None, weighting_strength=0.5, weighting_exponent=1.0, **kwargs
9+
self, config_path=None, weighting_strength=1, weighting_exponent=1, **kwargs
1010
):
1111
"""WMV ensemble that weights models based on their class-wise positive / negative predictive values. For each class, the weight is calculated as:
1212
weight = (weighting_strength * PPV + (1 - weighting_strength)) ** weighting_exponent
@@ -57,7 +57,7 @@ def calculate_classwise_weights(self, predicted_classes):
5757
class WMVwithF1Ensemble(BaseEnsemble):
5858

5959
def __init__(
60-
self, config_path=None, weighting_strength=0.5, weighting_exponent=1.0, **kwargs
60+
self, config_path=None, weighting_strength=1, weighting_exponent=6.25, **kwargs
6161
):
6262
"""WMV ensemble that weights models based on their class-wise F1 scores. For each class, the weight is calculated as:
6363
weight = model_weight * (weighting_strength * F1 + (1 - weighting_strength)) ** weighting_exponent

chebifier/model_registry.yml

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,91 @@
1+
electra_chebi50-3star_v244:
2+
type: electra
3+
hugging_face:
4+
repo_id: chebai/electra_chebi50-3star_v244
5+
files:
6+
ckpt_path: electra_chebi50-3star_v244_x2mngani_epoch=180.ckpt
7+
target_labels_path: classes.txt
8+
classwise_weights_path: electra_chebi50-3star_v244_x2mngani_epoch=180_trust_3star.json
9+
gat_chebi50_v244:
10+
type: gat
11+
hugging_face:
12+
repo_id: chebai/gat_chebi50_v244
13+
files:
14+
ckpt_path: gat_chebi50_v244_0nfi19qt_epoch=198.ckpt
15+
target_labels_path: classes.txt
16+
classwise_weights_path: gat_chebi50_v244_0nfi19qt_epoch=198_trust_3star.json
17+
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50GraphProperties
18+
molecular_properties:
19+
- chebai_graph.preprocessing.properties.AtomType
20+
- chebai_graph.preprocessing.properties.NumAtomBonds
21+
- chebai_graph.preprocessing.properties.AtomCharge
22+
- chebai_graph.preprocessing.properties.AtomAromaticity
23+
- chebai_graph.preprocessing.properties.AtomHybridization
24+
- chebai_graph.preprocessing.properties.AtomNumHs
25+
- chebai_graph.preprocessing.properties.BondType
26+
- chebai_graph.preprocessing.properties.BondInRing
27+
- chebai_graph.preprocessing.properties.BondAromaticity
28+
- chebai_graph.preprocessing.properties.RDKit2DNormalized
29+
gat-aug_chebi50_v244:
30+
type: gat
31+
hugging_face:
32+
repo_id: chebai/gat-aug_chebi50_v244
33+
files:
34+
ckpt_path: gat-aug_chebi50_v244_8fky8tru_epoch=192.ckpt
35+
target_labels_path: classes.txt
36+
classwise_weights_path: gat-aug_chebi50_v244_8fky8tru_epoch=192_trust_3star.json
37+
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType
38+
molecular_properties:
39+
- chebai_graph.preprocessing.properties.AtomNodeLevel
40+
# Atom Node type properties
41+
- chebai_graph.preprocessing.properties.AugAtomAromaticity
42+
- chebai_graph.preprocessing.properties.AugAtomCharge
43+
- chebai_graph.preprocessing.properties.AugAtomHybridization
44+
- chebai_graph.preprocessing.properties.AugAtomNumHs
45+
- chebai_graph.preprocessing.properties.AugAtomType
46+
- chebai_graph.preprocessing.properties.AugNumAtomBonds
47+
# FG Node type properties
48+
- chebai_graph.preprocessing.properties.AtomFunctionalGroup
49+
- chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG
50+
- chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG
51+
- chebai_graph.preprocessing.properties.IsFGAlkyl
52+
# Graph Node type properties
53+
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized
54+
# Bond properties
55+
- chebai_graph.preprocessing.properties.BondLevel
56+
- chebai_graph.preprocessing.properties.AugBondAromaticity
57+
- chebai_graph.preprocessing.properties.AugBondInRing
58+
- chebai_graph.preprocessing.properties.AugBondType
59+
resgated-aug_chebi50-3star_v244:
60+
type: resgated
61+
hugging_face:
62+
repo_id: chebai/resgated-aug_chebi50-3star_v244
63+
files:
64+
ckpt_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190.ckpt
65+
target_labels_path: classes.txt
66+
classwise_weights_path: resgated-aug_chebi50-3star_v244_w0rhmajx_epoch=190_trust_3star.json
67+
dataset_cls: chebai_graph.preprocessing.datasets.ChEBI50_WFGE_WGN_AsPerNodeType
68+
molecular_properties:
69+
- chebai_graph.preprocessing.properties.AtomNodeLevel
70+
# Atom Node type properties
71+
- chebai_graph.preprocessing.properties.AugAtomAromaticity
72+
- chebai_graph.preprocessing.properties.AugAtomCharge
73+
- chebai_graph.preprocessing.properties.AugAtomHybridization
74+
- chebai_graph.preprocessing.properties.AugAtomNumHs
75+
- chebai_graph.preprocessing.properties.AugAtomType
76+
- chebai_graph.preprocessing.properties.AugNumAtomBonds
77+
# FG Node type properties
78+
- chebai_graph.preprocessing.properties.AtomFunctionalGroup
79+
- chebai_graph.preprocessing.properties.IsHydrogenBondDonorFG
80+
- chebai_graph.preprocessing.properties.IsHydrogenBondAcceptorFG
81+
- chebai_graph.preprocessing.properties.IsFGAlkyl
82+
# Graph Node type properties
83+
- chebai_graph.preprocessing.properties.AugRDKit2DNormalized
84+
# Bond properties
85+
- chebai_graph.preprocessing.properties.BondLevel
86+
- chebai_graph.preprocessing.properties.AugBondAromaticity
87+
- chebai_graph.preprocessing.properties.AugBondInRing
88+
- chebai_graph.preprocessing.properties.AugBondType
189
electra_chebi50_v241:
290
type: electra
391
hugging_face:
@@ -31,4 +119,4 @@ c3p_with_weights:
31119
repo_id: chebai/chebifier
32120
repo_type: dataset
33121
files:
34-
classwise_weights_path: c3p_trust.json
122+
classwise_weights_path: c3p_trust.json

0 commit comments

Comments
 (0)