Skip to content

Commit 5055403

Browse files
committed
add alpha / beta parameters for weighting, verbose output, don't save results automatically to file
1 parent 6aed7a1 commit 5055403

File tree

2 files changed

+105
-47
lines changed

2 files changed

+105
-47
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 57 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import importlib
2-
import os
32
import time
43
from pathlib import Path
54
from typing import Union
@@ -24,8 +23,9 @@ class BaseEnsemble:
2423
def __init__(
2524
self,
2625
model_configs: Union[str, Path, dict, None] = None,
27-
chebi_version: int = 241,
2826
resolve_inconsistencies: bool = True,
27+
verbose_output: bool = False,
28+
use_confidence: bool = True,
2929
):
3030
# Deferred Import: To avoid circular import error
3131
from chebifier.model_registry import MODEL_TYPES
@@ -48,6 +48,8 @@ def __init__(
4848
model_registry = yaml.safe_load(f)
4949

5050
processed_configs = process_config(config, model_registry)
51+
self.verbose_output = verbose_output
52+
self.use_confidence = use_confidence
5153

5254
self.chebi_graph = load_chebi_graph()
5355
self.disjoint_files = get_disjoint_files()
@@ -92,7 +94,8 @@ def gather_predictions(self, smiles_list):
9294
if logits_for_smiles is not None:
9395
for cls in logits_for_smiles:
9496
predicted_classes.add(cls)
95-
print(f"Sorting predictions from {len(model_predictions)} models...")
97+
if self.verbose_output:
98+
print(f"Sorting predictions from {len(model_predictions)} models...")
9699
predicted_classes = sorted(list(predicted_classes))
97100
predicted_classes_dict = {cls: i for i, cls in enumerate(predicted_classes)}
98101
ordered_logits = (
@@ -114,7 +117,11 @@ def gather_predictions(self, smiles_list):
114117
return ordered_logits, predicted_classes
115118

116119
def consolidate_predictions(
117-
self, predictions, classwise_weights, predicted_classes, **kwargs
120+
self,
121+
predictions,
122+
classwise_weights,
123+
return_intermediate_results=False,
124+
**kwargs,
118125
):
119126
"""
120127
Aggregates predictions from multiple models using weighted majority voting.
@@ -137,7 +144,9 @@ def consolidate_predictions(
137144
predictions < self.positive_prediction_threshold
138145
) & valid_predictions
139146

140-
if "use_confidence" in kwargs and kwargs["use_confidence"]:
147+
# if use_confidence is passed in kwargs, it overrides the ensemble setting
148+
use_confidence = kwargs.get("use_confidence", self.use_confidence)
149+
if use_confidence:
141150
confidence = 2 * torch.abs(
142151
predictions.nan_to_num() - self.positive_prediction_threshold
143152
)
@@ -164,10 +173,27 @@ def consolidate_predictions(
164173

165174
# Determine which classes to include for each SMILES
166175
net_score = positive_sum - negative_sum # Shape: (num_smiles, num_classes)
176+
if return_intermediate_results:
177+
return (
178+
net_score,
179+
has_valid_predictions,
180+
{
181+
"positive_mask": positive_mask,
182+
"negative_mask": negative_mask,
183+
"confidence": confidence,
184+
"positive_sum": positive_sum,
185+
"negative_sum": negative_sum,
186+
},
187+
)
188+
189+
return net_score, has_valid_predictions
167190

191+
def apply_inconsistency_resolution(
192+
self, net_score, class_names, has_valid_predictions
193+
):
194+
# todo - this could be more elegant
168195
# Smooth predictions
169196
start_time = time.perf_counter()
170-
class_names = list(predicted_classes.keys())
171197
if self.smoother is not None:
172198
self.smoother.set_label_names(class_names)
173199
smooth_net_score = self.smoother(net_score)
@@ -179,7 +205,8 @@ def consolidate_predictions(
179205
net_score > 0
180206
) & has_valid_predictions # Shape: (num_smiles, num_classes)
181207
end_time = time.perf_counter()
182-
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
208+
if self.verbose_output:
209+
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
183210

184211
complete_failure = torch.all(~has_valid_predictions, dim=1)
185212
return class_decisions, complete_failure
@@ -192,38 +219,28 @@ def calculate_classwise_weights(self, predicted_classes):
192219
return positive_weights, negative_weights
193220

194221
def predict_smiles_list(
195-
self, smiles_list, load_preds_if_possible=False, **kwargs
222+
self, smiles_list, return_intermediate_results=False, **kwargs
196223
) -> list:
197-
preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
198-
predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
199-
if not load_preds_if_possible or not os.path.isfile(preds_file):
200-
ordered_predictions, predicted_classes = self.gather_predictions(
201-
smiles_list
202-
)
203-
if len(predicted_classes) == 0:
204-
print(
205-
"Warning: No classes have been predicted for the given SMILES list."
224+
ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
225+
if len(predicted_classes) == 0:
226+
print("Warning: No classes have been predicted for the given SMILES list.")
227+
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
228+
229+
classwise_weights = self.calculate_classwise_weights(predicted_classes)
230+
if return_intermediate_results:
231+
net_score, has_valid_predictions, intermediate_results_dict = (
232+
self.consolidate_predictions(
233+
ordered_predictions,
234+
classwise_weights,
235+
return_intermediate_results=return_intermediate_results,
206236
)
207-
# save predictions
208-
if load_preds_if_possible:
209-
torch.save(ordered_predictions, preds_file)
210-
with open(predicted_classes_file, "w") as f:
211-
for cls in predicted_classes:
212-
f.write(f"{cls}\n")
213-
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
237+
)
214238
else:
215-
print(
216-
f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
239+
net_score, has_valid_predictions = self.consolidate_predictions(
240+
ordered_predictions, classwise_weights
217241
)
218-
ordered_predictions = torch.load(preds_file)
219-
with open(predicted_classes_file, "r") as f:
220-
predicted_classes = {
221-
line.strip(): i for i, line in enumerate(f.readlines())
222-
}
223-
224-
classwise_weights = self.calculate_classwise_weights(predicted_classes)
225-
class_decisions, is_failure = self.consolidate_predictions(
226-
ordered_predictions, classwise_weights, predicted_classes, **kwargs
242+
class_decisions, is_failure = self.apply_inconsistency_resolution(
243+
net_score, list(predicted_classes.keys()), has_valid_predictions
227244
)
228245

229246
class_names = list(predicted_classes.keys())
@@ -239,6 +256,11 @@ def predict_smiles_list(
239256
)
240257
for i, failure in zip(class_decisions, is_failure)
241258
]
259+
if return_intermediate_results:
260+
intermediate_results_dict["predicted_classes"] = predicted_classes
261+
intermediate_results_dict["classwise_weights"] = classwise_weights
262+
intermediate_results_dict["net_score"] = net_score
263+
return result, intermediate_results_dict
242264

243265
return result
244266

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 48 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,21 @@
44

55

66
class WMVwithPPVNPVEnsemble(BaseEnsemble):
7+
8+
def __init__(
9+
self, config_path=None, weighting_strength=0.5, weighting_exponent=1.0, **kwargs
10+
):
11+
"""WMV ensemble that weights models based on their class-wise positive / negative predictive values. For each class, the weight is calculated as:
12+
weight = weighting_strength * PPV + (1 - weighting_strength)
13+
where PPV is the class-specific positive predictive value of the model on the validation set
14+
or (if the prediction is negative):
15+
weight = weighting_strength * NPV + (1 - weighting_strength)
16+
where NPV is the class-specific negative predictive value of the model on the validation set.
17+
"""
18+
super().__init__(config_path, **kwargs)
19+
self.weighting_strength = weighting_strength
20+
self.weighting_exponent = weighting_exponent
21+
722
def calculate_classwise_weights(self, predicted_classes):
823
"""
924
Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
@@ -18,21 +33,40 @@ def calculate_classwise_weights(self, predicted_classes):
1833
if model.classwise_weights is None:
1934
continue
2035
for cls, weights in model.classwise_weights.items():
21-
positive_weights[predicted_classes[cls], j] *= weights["PPV"]
22-
negative_weights[predicted_classes[cls], j] *= weights["NPV"]
36+
positive_weights[predicted_classes[cls], j] *= (
37+
weights["PPV"] * self.weighting_strength
38+
+ (1 - self.weighting_strength)
39+
) ** self.weighting_exponent
40+
negative_weights[predicted_classes[cls], j] *= (
41+
weights["NPV"] * self.weighting_strength
42+
+ (1 - self.weighting_strength)
43+
) ** self.weighting_exponent
2344

24-
print(
25-
"Calculated model weightings. The averages for positive / negative weights are:"
26-
)
27-
for i, model in enumerate(self.models):
45+
if self.verbose_output:
2846
print(
29-
f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}"
47+
"Calculated model weightings. The averages for positive / negative weights are:"
3048
)
49+
for i, model in enumerate(self.models):
50+
print(
51+
f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}"
52+
)
3153

3254
return positive_weights, negative_weights
3355

3456

3557
class WMVwithF1Ensemble(BaseEnsemble):
58+
59+
def __init__(
60+
self, config_path=None, weighting_strength=0.5, weighting_exponent=1.0, **kwargs
61+
):
62+
"""WMV ensemble that weights models based on their class-wise F1 scores. For each class, the weight is calculated as:
63+
weight = model_weight * (weighting_strength * F1 + (1 - weighting_strength))
64+
where F1 is the class-specific F1 score ("trust") of the model on the validation set.
65+
"""
66+
super().__init__(config_path, **kwargs)
67+
self.weighting_strength = weighting_strength
68+
self.weighting_exponent = weighting_exponent
69+
3670
def calculate_classwise_weights(self, predicted_classes):
3771
"""
3872
Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
@@ -52,10 +86,12 @@ def calculate_classwise_weights(self, predicted_classes):
5286
* weights["TP"]
5387
/ (2 * weights["TP"] + weights["FP"] + weights["FN"])
5488
)
55-
weights_by_cls[predicted_classes[cls], j] *= 1 + f1
56-
57-
print("Calculated model weightings. The average weights are:")
58-
for i, model in enumerate(self.models):
59-
print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")
89+
weights_by_cls[predicted_classes[cls], j] *= (
90+
self.weighting_strength * f1 + 1 - self.weighting_strength
91+
) ** self.weighting_exponent
92+
if self.verbose_output:
93+
print("Calculated model weightings. The average weights are:")
94+
for i, model in enumerate(self.models):
95+
print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")
6096

6197
return weights_by_cls, weights_by_cls

0 commit comments

Comments (0)