|
| 1 | +import os |
| 2 | +from abc import ABC |
| 3 | +import torch |
| 4 | +import tqdm |
| 5 | +from rdkit import Chem |
| 6 | + |
| 7 | +from chebifier.prediction_models.base_predictor import BasePredictor |
| 8 | +from chebifier.prediction_models.chemlog_predictor import ChemLogPredictor |
| 9 | +from chebifier.prediction_models.electra_predictor import ElectraPredictor |
| 10 | +from chebifier.prediction_models.gnn_predictor import ResGatedPredictor |
| 11 | + |
# Registry mapping the "type" value of a model configuration to the predictor
# class that gets instantiated for it.
MODEL_TYPES = {
    "electra": ElectraPredictor,
    "resgated": ResGatedPredictor,
    "chemlog": ChemLogPredictor,
}
| 17 | + |
class BaseEnsemble(ABC):
    """Ensemble that merges the per-class predictions of several models.

    Every model is queried for the same SMILES list, the results are aligned
    in one tensor, and a (by default unweighted) confidence-based vote decides
    which classes are reported for each SMILES. Subclasses can override
    ``calculate_classwise_weights`` to weight models per class.
    """

    def __init__(self, model_configs: dict):
        """Instantiate one predictor per entry of ``model_configs``.

        Args:
            model_configs: Maps a model name to its configuration dict. Each
                configuration needs a ``"type"`` key that names an entry of
                ``MODEL_TYPES``; the complete configuration (including
                ``"type"``) is passed to the predictor class as keyword
                arguments.

        Raises:
            KeyError: If a configuration has no ``"type"`` key or the type is
                not registered in ``MODEL_TYPES``.
        """
        self.models = []
        # Values above this threshold count as a positive prediction.
        self.positive_prediction_threshold = 0.5
        for model_config in model_configs.values():
            model_cls = MODEL_TYPES[model_config["type"]]
            model_instance = model_cls(**model_config)
            assert isinstance(model_instance, BasePredictor)
            self.models.append(model_instance)

    def gather_predictions(self, smiles_list):
        """Collect the predictions of all models and align them in one tensor.

        Each model's ``predict_smiles_list`` result is iterated as one entry
        per SMILES, where an entry is either ``None`` (no prediction) or a
        mapping from class label to predicted value.

        Args:
            smiles_list: The SMILES strings to predict.

        Returns:
            A tuple ``(ordered_logits, predicted_classes)``:
            ``ordered_logits`` has shape
            ``(len(smiles_list), num_classes, num_models)`` and holds NaN
            wherever a model made no prediction; ``predicted_classes`` maps
            each class label to its index along dimension 1 (labels sorted).
        """
        model_predictions = []
        seen_classes = set()
        for model in self.models:
            predictions = model.predict_smiles_list(smiles_list)
            model_predictions.append(predictions)
            for logits_for_smiles in predictions:
                if logits_for_smiles is not None:
                    seen_classes.update(logits_for_smiles)

        print("Sorting predictions...")
        predicted_classes = {cls: i for i, cls in enumerate(sorted(seen_classes))}
        # NaN marks "this model made no prediction for this SMILES/class".
        ordered_logits = torch.full(
            (len(smiles_list), len(predicted_classes), len(self.models)),
            torch.nan,
        )
        for i, model_prediction in enumerate(model_predictions):
            for j, logits_for_smiles in tqdm.tqdm(
                enumerate(model_prediction),
                total=len(model_prediction),
                desc=f"Sorting predictions for {self.models[i].model_name}",
            ):
                if logits_for_smiles is not None:
                    for cls in logits_for_smiles:
                        ordered_logits[j, predicted_classes[cls], i] = logits_for_smiles[cls]

        return ordered_logits, predicted_classes

    def consolidate_predictions(self, predictions, predicted_classes, classwise_weights, **kwargs):
        """Aggregate the model predictions by weighted, confidence-based voting.

        A model votes "positive" for a class if its value is above
        ``self.positive_prediction_threshold`` and "negative" if it is below;
        NaN entries (and values exactly at the threshold) do not vote. Each
        vote is scaled by ``2 * |value - threshold|`` and by the model's
        class-specific weight. A class is reported for a SMILES if the summed
        positive votes exceed the summed negative votes.

        Args:
            predictions: Tensor of shape ``(num_smiles, num_classes, num_models)``
                with NaN for missing predictions.
            predicted_classes: Maps class label to its index along dimension 1.
            classwise_weights: Pair ``(positive_weights, negative_weights)``,
                each of shape ``(num_classes, num_models)``.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            One list of class labels per SMILES, ordered by class index.
        """
        threshold = self.positive_prediction_threshold
        index_to_class = {idx: cls for cls, idx in predicted_classes.items()}

        valid_predictions = ~torch.isnan(predictions)
        # A class can only be reported if at least one model predicted it.
        has_valid_predictions = valid_predictions.any(dim=2)

        # Use the configured threshold for the vote direction as well as for
        # the confidence, so that both stay consistent when it is changed.
        positive_mask = (predictions > threshold) & valid_predictions
        negative_mask = (predictions < threshold) & valid_predictions

        # NaNs are replaced before the arithmetic; the masks above already
        # guarantee that those entries contribute nothing to the sums.
        confidence = 2 * torch.abs(predictions.nan_to_num() - threshold)

        pos_weights = classwise_weights[0]  # (num_classes, num_models)
        neg_weights = classwise_weights[1]  # (num_classes, num_models)

        # Broadcast the weights over the SMILES dimension, then sum over models.
        positive_sum = (positive_mask.float() * confidence * pos_weights.unsqueeze(0)).sum(dim=2)
        negative_sum = (negative_mask.float() * confidence * neg_weights.unsqueeze(0)).sum(dim=2)

        class_decisions = ((positive_sum - negative_sum) > 0) & has_valid_predictions

        return [
            [index_to_class[idx] for idx in torch.nonzero(row, as_tuple=True)[0].tolist()]
            for row in class_decisions
        ]

    def calculate_classwise_weights(self, predicted_classes):
        """Return uniform weights, i.e. every model counts equally for every class.

        Args:
            predicted_classes: Maps class label to class index; only its
                length is used.

        Returns:
            Pair ``(positive_weights, negative_weights)`` of all-ones tensors
            with shape ``(num_classes, num_models)``.
        """
        shape = (len(predicted_classes), len(self.models))
        return torch.ones(shape), torch.ones(shape)

    def _cache_matches(self, ordered_predictions, predicted_classes, smiles_list) -> bool:
        """Check that cached predictions have the shape this request needs."""
        return (
            isinstance(ordered_predictions, torch.Tensor)
            and ordered_predictions.dim() == 3
            and tuple(ordered_predictions.shape)
            == (len(smiles_list), len(predicted_classes), len(self.models))
        )

    def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
        """Predict the classes for every SMILES in ``smiles_list``.

        The aligned per-model predictions are cached in the current working
        directory in two files whose names are derived from the model names.
        The cache is NOT keyed by the SMILES themselves; it is only reused if
        its dimensions match the current request, so pass
        ``load_preds_if_possible=False`` when predicting a different list of
        the same length.

        Args:
            smiles_list: The SMILES strings to predict.
            load_preds_if_possible: Reuse cached predictions when both cache
                files exist and match the request.

        Returns:
            One list of predicted class labels per SMILES.
        """
        model_names = "_".join(model.model_name for model in self.models)
        preds_file = f"predictions_by_model_{model_names}.pt"
        predicted_classes_file = f"predicted_classes_{model_names}.txt"

        ordered_predictions = predicted_classes = None
        if (
            load_preds_if_possible
            and os.path.isfile(preds_file)
            and os.path.isfile(predicted_classes_file)
        ):
            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
            ordered_predictions = torch.load(preds_file)
            with open(predicted_classes_file, "r") as f:
                predicted_classes = {line.strip(): i for i, line in enumerate(f)}
            if not self._cache_matches(ordered_predictions, predicted_classes, smiles_list):
                print("Cached predictions do not match the current request, recomputing...")
                ordered_predictions = predicted_classes = None

        if ordered_predictions is None:
            # gather_predictions returns a (tensor, dict) pair; unpack it
            # instead of binding the whole tuple to both names.
            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
            torch.save(ordered_predictions, preds_file)
            # Dict order equals index order, so the line number encodes the index.
            with open(predicted_classes_file, "w") as f:
                f.writelines(f"{cls}\n" for cls in predicted_classes)

        classwise_weights = self.calculate_classwise_weights(predicted_classes)
        return self.consolidate_predictions(ordered_predictions, predicted_classes, classwise_weights)
0 commit comments