Skip to content

Commit 219b73f

Browse files
committed
add single chemlog model
1 parent 5055403 commit 219b73f

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

chebifier/model_registry.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
)
1212
from chebifier.prediction_models.c3p_predictor import C3PPredictor
1313
from chebifier.prediction_models.chemlog_predictor import (
14+
ChemlogAllPredictor,
1415
ChemlogOrganoXCompoundPredictor,
1516
ChemlogXMolecularEntityPredictor,
1617
)
@@ -27,6 +28,7 @@
2728
"electra": ElectraPredictor,
2829
"resgated": ResGatedPredictor,
2930
"gat": GATPredictor,
31+
"chemlog": ChemlogAllPredictor,
3032
"chemlog_peptides": ChemlogPeptidesPredictor,
3133
"chebi_lookup": ChEBILookupPredictor,
3234
"chemlog_element": ChemlogXMolecularEntityPredictor,

chebifier/prediction_models/chemlog_predictor.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,31 @@
3232
}
3333

3434

35+
class ChemlogAllPredictor(BasePredictor):
36+
def __init__(self, model_name: str, **kwargs):
37+
super().__init__(model_name, **kwargs)
38+
self.chebi_graph = kwargs.get("chebi_graph", None)
39+
self.predictors = [
40+
ChemlogXMolecularEntityPredictor("chemlog_x_molecular_entity", **kwargs),
41+
ChemlogOrganoXCompoundPredictor("chemlog_organo_x_compound", **kwargs),
42+
ChemlogPeptidesPredictor("chemlog_peptides", **kwargs),
43+
]
44+
45+
@modelwise_smiles_lru_cache.batch_decorator
46+
def predict_smiles_list(self, smiles_list: list[str]) -> list:
47+
results = []
48+
for predictor in self.predictors:
49+
predictor_results = predictor._predict_smiles_list(smiles_list)
50+
for i, res in enumerate(predictor_results):
51+
if i >= len(results):
52+
results.append(dict())
53+
results[i].update(res)
54+
return results
55+
56+
def explain_smiles(self, smiles):
57+
return self.predictors[2].explain_smiles(smiles)
58+
59+
3560
class ChemlogExtraPredictor(BasePredictor):
3661

3762
def __init__(self, model_name: str, **kwargs):
@@ -41,6 +66,9 @@ def __init__(self, model_name: str, **kwargs):
4166

4267
@modelwise_smiles_lru_cache.batch_decorator
4368
def predict_smiles_list(self, smiles_list: list[str]) -> list:
69+
return self._predict_smiles_list(smiles_list)
70+
71+
def _predict_smiles_list(self, smiles_list: list[str]) -> list:
4472
from chemlog.cli import _smiles_to_mol
4573

4674
mol_list = [_smiles_to_mol(smiles) for smiles in smiles_list]
@@ -124,6 +152,9 @@ def predict_smiles(self, smiles: str) -> Optional[dict]:
124152

125153
@modelwise_smiles_lru_cache.batch_decorator
126154
def predict_smiles_list(self, smiles_list: list[str]) -> list:
155+
return self._predict_smiles_list(smiles_list)
156+
157+
def _predict_smiles_list(self, smiles_list: list[str]) -> list:
127158
results = []
128159
for i, smiles in tqdm.tqdm(enumerate(smiles_list)):
129160
results.append(self.predict_smiles(smiles))

0 commit comments

Comments
 (0)