4
4
5
5
import logging
6
6
from copy import deepcopy
7
- from functools import lru_cache
7
+ from functools import cache
8
+ from math import isfinite
8
9
9
- import numpy as np
10
10
from rdkit import ForceField # noqa: F401
11
11
from rdkit .Chem .inchi import InchiReadWriteError , MolFromInchi
12
12
from rdkit .Chem .rdchem import Mol
23
23
24
24
logger = logging .getLogger (__name__ )
25
25
26
+
26
27
_warning_prefix = "WARNING: Energy ratio module "
27
28
# Placeholder payload that check_energy_ratio returns on its early-exit path.
# It carries every key the success path reports, all set to NaN, so callers
# see the same set of result keys whichever path produced the dictionary.
_empty_results = {
    "results": {
        "ensemble_avg_energy": float("nan"),
        "mol_pred_energy": float("nan"),
        "energy_ratio": float("nan"),
        "relative_deviation": float("nan"),
        "z_value": float("nan"),
        "energy_ratio_passes": float("nan"),
        "relative_deviation_passes": float("nan"),
        "z_value_passes": float("nan"),
    }
}
35
36
@@ -39,7 +40,8 @@ def check_energy_ratio(
39
40
threshold_energy_ratio : float = 7.0 ,
40
41
ensemble_number_conformations : int = 100 ,
41
42
inchi_strict : bool = False ,
42
- ):
43
+ epsilon = 1e-10 ,
44
+ ) -> dict [str , dict [str , float | bool ]]:
43
45
"""Check whether the energy of the docked ligand is within user defined range.
44
46
45
47
Args:
@@ -72,40 +74,65 @@ def check_energy_ratio(
72
74
return _empty_results
73
75
74
76
try :
75
- conf_energy = get_conf_energy (mol_pred )
77
+ observed_energy = get_conf_energy (mol_pred )
76
78
except Exception as e :
77
79
logger .warning (_warning_prefix + "failed to calculate conformation energy for %s: %s" , inchi , e )
78
- conf_energy = np . nan
80
+ observed_energy = float ( " nan" )
79
81
80
82
try :
81
- avg_energy = float (get_average_energy (inchi , ensemble_number_conformations ))
83
+ energies = get_energies (inchi , ensemble_number_conformations )
84
+ mean_energy = sum (energies ) / len (energies )
85
+ std_energy = sum ((energy - mean_energy ) ** 2 for energy in energies ) / (len (energies ) - 1 )
86
+ std_energy = max (epsilon , std_energy ) # clipping
82
87
except Exception as e :
83
88
logger .warning (_warning_prefix + "failed to calculate ensemble conformation energy for %s: %s" , inchi , e )
84
- avg_energy = np .nan
89
+ mean_energy = float ("nan" )
90
+ std_energy = float ("nan" )
85
91
86
- if avg_energy == 0 :
92
+ if mean_energy == 0 :
87
93
logger .warning (_warning_prefix + "calculated average energy of molecule 0 for %s" , inchi )
88
- avg_energy = np .nan
94
+ mean_energy = epsilon # clipping
95
+
96
+ # simple ratio
97
+ ratio = observed_energy / mean_energy
98
+ ratio_passes = ratio <= threshold_energy_ratio if isfinite (ratio ) else float ("nan" )
89
99
90
- pred_factor = conf_energy / avg_energy
91
- ratio_passes = pred_factor <= threshold_energy_ratio
100
+ # ratio after subtracting mean
101
+ deviation = observed_energy - mean_energy
102
+ relative_deviation = deviation / mean_energy
103
+ relative_deviation_passes = (
104
+ relative_deviation <= threshold_energy_ratio if isfinite (relative_deviation ) else float ("nan" )
105
+ )
106
+
107
+ # standard score (ratio after subtracting by population mean and dividing by population std)
108
+ z_value = (observed_energy - mean_energy ) / std_energy
109
+ z_value_passes = z_value <= threshold_energy_ratio if isfinite (z_value ) else float ("nan" )
92
110
93
111
results = {
94
- "ensemble_avg_energy" : avg_energy ,
95
- "mol_pred_energy" : conf_energy ,
96
- "energy_ratio" : pred_factor ,
112
+ "ensemble_avg_energy" : mean_energy ,
113
+ "mol_pred_energy" : observed_energy ,
114
+ "energy_ratio" : ratio ,
115
+ "relative_deviation" : relative_deviation ,
116
+ "z_value" : z_value ,
97
117
"energy_ratio_passes" : ratio_passes ,
118
+ "relative_deviation_passes" : relative_deviation_passes ,
119
+ "z_value_passes" : z_value_passes ,
98
120
}
99
121
return {"results" : results }
100
122
101
123
102
- @lru_cache (maxsize = None )
103
124
def get_average_energy(inchi: str, n_confs: int = 50, num_threads: int = 0) -> float:
    """Get average energy of an ensemble of molecule conformations.

    Args:
        inchi: InChI string identifying the molecule.
        n_confs: Number of conformations, forwarded to ``get_energies``.
        num_threads: Thread count, forwarded to ``get_energies``.

    Returns:
        Arithmetic mean of the energies returned by ``get_energies``.

    Raises:
        ZeroDivisionError: If ``get_energies`` returns an empty sequence.
    """
    energies = get_energies(inchi, n_confs, num_threads)
    return sum(energies) / len(energies)
128
+
129
+
130
@cache
def get_energies(inchi: str, n_confs: int = 50, num_threads: int = 0) -> list[float]:
    """Get energies of an ensemble of molecule conformations.

    Results are memoized per argument combination with ``functools.cache``,
    so repeated calls hand back the same list object; callers should treat
    it as read-only.

    Args:
        inchi: InChI string that is converted to a molecule.
        n_confs: Number of conformations, forwarded to ``new_conformation``.
        num_threads: Thread count, forwarded to ``new_conformation``.

    Returns:
        The ``"energies"`` entry of the ``new_conformation`` result.

    Raises:
        ValueError: If the InChI string could not be converted to a molecule.
    """
    with CaptureLogger():
        mol = MolFromInchi(inchi)
    # NOTE(review): indentation was not recoverable in the reviewed view; the
    # conformer generation is assumed to sit outside the log-capturing
    # context -- confirm against the real file.
    if mol is None:
        # Fail with a clear message instead of an opaque error further down.
        raise ValueError(f"could not build a molecule from InChI {inchi!r}")
    return new_conformation(mol, n_confs, num_threads)["energies"]
109
136
110
137
111
138
def new_conformation (mol : Mol , n_confs : int = 1 , num_threads : int = 0 , energy_minimization = True ) -> Mol :
0 commit comments