11import importlib
2- import os
32import time
43from pathlib import Path
54from typing import Union
@@ -24,8 +23,9 @@ class BaseEnsemble:
2423 def __init__ (
2524 self ,
2625 model_configs : Union [str , Path , dict , None ] = None ,
27- chebi_version : int = 241 ,
2826 resolve_inconsistencies : bool = True ,
27+ verbose_output : bool = False ,
28+ use_confidence : bool = True ,
2929 ):
3030 # Deferred Import: To avoid circular import error
3131 from chebifier .model_registry import MODEL_TYPES
@@ -48,6 +48,8 @@ def __init__(
4848 model_registry = yaml .safe_load (f )
4949
5050 processed_configs = process_config (config , model_registry )
51+ self .verbose_output = verbose_output
52+ self .use_confidence = use_confidence
5153
5254 self .chebi_graph = load_chebi_graph ()
5355 self .disjoint_files = get_disjoint_files ()
@@ -92,7 +94,8 @@ def gather_predictions(self, smiles_list):
9294 if logits_for_smiles is not None :
9395 for cls in logits_for_smiles :
9496 predicted_classes .add (cls )
95- print (f"Sorting predictions from { len (model_predictions )} models..." )
97+ if self .verbose_output :
98+ print (f"Sorting predictions from { len (model_predictions )} models..." )
9699 predicted_classes = sorted (list (predicted_classes ))
97100 predicted_classes_dict = {cls : i for i , cls in enumerate (predicted_classes )}
98101 ordered_logits = (
@@ -114,7 +117,11 @@ def gather_predictions(self, smiles_list):
114117 return ordered_logits , predicted_classes
115118
116119 def consolidate_predictions (
117- self , predictions , classwise_weights , predicted_classes , ** kwargs
120+ self ,
121+ predictions ,
122+ classwise_weights ,
123+ return_intermediate_results = False ,
124+ ** kwargs ,
118125 ):
119126 """
120127 Aggregates predictions from multiple models using weighted majority voting.
@@ -137,7 +144,9 @@ def consolidate_predictions(
137144 predictions < self .positive_prediction_threshold
138145 ) & valid_predictions
139146
140- if "use_confidence" in kwargs and kwargs ["use_confidence" ]:
147+ # if use_confidence is passed in kwargs, it overrides the ensemble setting
148+ use_confidence = kwargs .get ("use_confidence" , self .use_confidence )
149+ if use_confidence :
141150 confidence = 2 * torch .abs (
142151 predictions .nan_to_num () - self .positive_prediction_threshold
143152 )
@@ -164,10 +173,27 @@ def consolidate_predictions(
164173
165174 # Determine which classes to include for each SMILES
166175 net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
176+ if return_intermediate_results :
177+ return (
178+ net_score ,
179+ has_valid_predictions ,
180+ {
181+ "positive_mask" : positive_mask ,
182+ "negative_mask" : negative_mask ,
183+ "confidence" : confidence ,
184+ "positive_sum" : positive_sum ,
185+ "negative_sum" : negative_sum ,
186+ },
187+ )
188+
189+ return net_score , has_valid_predictions
167190
191+ def apply_inconsistency_resolution (
192+ self , net_score , class_names , has_valid_predictions
193+ ):
194+ # todo - this could be more elegant
168195 # Smooth predictions
169196 start_time = time .perf_counter ()
170- class_names = list (predicted_classes .keys ())
171197 if self .smoother is not None :
172198 self .smoother .set_label_names (class_names )
173199 smooth_net_score = self .smoother (net_score )
@@ -179,7 +205,8 @@ def consolidate_predictions(
179205 net_score > 0
180206 ) & has_valid_predictions # Shape: (num_smiles, num_classes)
181207 end_time = time .perf_counter ()
182- print (f"Prediction smoothing took { end_time - start_time :.2f} seconds" )
208+ if self .verbose_output :
209+ print (f"Prediction smoothing took { end_time - start_time :.2f} seconds" )
183210
184211 complete_failure = torch .all (~ has_valid_predictions , dim = 1 )
185212 return class_decisions , complete_failure
@@ -192,38 +219,28 @@ def calculate_classwise_weights(self, predicted_classes):
192219 return positive_weights , negative_weights
193220
194221 def predict_smiles_list (
195- self , smiles_list , load_preds_if_possible = False , ** kwargs
222+ self , smiles_list , return_intermediate_results = False , ** kwargs
196223 ) -> list :
197- preds_file = f"predictions_by_model_{ '_' .join (model .model_name for model in self .models )} .pt"
198- predicted_classes_file = f"predicted_classes_{ '_' .join (model .model_name for model in self .models )} .txt"
199- if not load_preds_if_possible or not os .path .isfile (preds_file ):
200- ordered_predictions , predicted_classes = self .gather_predictions (
201- smiles_list
202- )
203- if len (predicted_classes ) == 0 :
204- print (
205- "Warning: No classes have been predicted for the given SMILES list."
224+ ordered_predictions , predicted_classes = self .gather_predictions (smiles_list )
225+ if len (predicted_classes ) == 0 :
226+ print ("Warning: No classes have been predicted for the given SMILES list." )
227+ predicted_classes = {cls : i for i , cls in enumerate (predicted_classes )}
228+
229+ classwise_weights = self .calculate_classwise_weights (predicted_classes )
230+ if return_intermediate_results :
231+ net_score , has_valid_predictions , intermediate_results_dict = (
232+ self .consolidate_predictions (
233+ ordered_predictions ,
234+ classwise_weights ,
235+ return_intermediate_results = return_intermediate_results ,
206236 )
207- # save predictions
208- if load_preds_if_possible :
209- torch .save (ordered_predictions , preds_file )
210- with open (predicted_classes_file , "w" ) as f :
211- for cls in predicted_classes :
212- f .write (f"{ cls } \n " )
213- predicted_classes = {cls : i for i , cls in enumerate (predicted_classes )}
237+ )
214238 else :
215- print (
216- f"Loading predictions from { preds_file } and label indexes from { predicted_classes_file } "
239+ net_score , has_valid_predictions = self . consolidate_predictions (
240+ ordered_predictions , classwise_weights
217241 )
218- ordered_predictions = torch .load (preds_file )
219- with open (predicted_classes_file , "r" ) as f :
220- predicted_classes = {
221- line .strip (): i for i , line in enumerate (f .readlines ())
222- }
223-
224- classwise_weights = self .calculate_classwise_weights (predicted_classes )
225- class_decisions , is_failure = self .consolidate_predictions (
226- ordered_predictions , classwise_weights , predicted_classes , ** kwargs
242+ class_decisions , is_failure = self .apply_inconsistency_resolution (
243+ net_score , list (predicted_classes .keys ()), has_valid_predictions
227244 )
228245
229246 class_names = list (predicted_classes .keys ())
@@ -239,6 +256,11 @@ def predict_smiles_list(
239256 )
240257 for i , failure in zip (class_decisions , is_failure )
241258 ]
259+ if return_intermediate_results :
260+ intermediate_results_dict ["predicted_classes" ] = predicted_classes
261+ intermediate_results_dict ["classwise_weights" ] = classwise_weights
262+ intermediate_results_dict ["net_score" ] = net_score
263+ return result , intermediate_results_dict
242264
243265 return result
244266
0 commit comments