|
| 1 | +import importlib |
1 | 2 | import os |
2 | 3 | import time |
| 4 | +from pathlib import Path |
| 5 | +from typing import Union |
3 | 6 |
|
4 | 7 | import torch |
5 | 8 | import tqdm |
| 9 | +import yaml |
6 | 10 |
|
7 | 11 | from chebifier.check_env import check_package_installed |
8 | 12 | from chebifier.hugging_face import download_model_files |
9 | 13 | from chebifier.inconsistency_resolution import PredictionSmoother |
10 | 14 | from chebifier.prediction_models.base_predictor import BasePredictor |
11 | | -from chebifier.utils import get_disjoint_files, load_chebi_graph |
| 15 | +from chebifier.utils import ( |
| 16 | + get_default_configs, |
| 17 | + get_disjoint_files, |
| 18 | + load_chebi_graph, |
| 19 | + process_config, |
| 20 | +) |
12 | 21 |
|
13 | 22 |
|
14 | 23 | class BaseEnsemble: |
15 | 24 | def __init__( |
16 | 25 | self, |
17 | | - model_configs: dict, |
| 26 | + model_configs: Union[str, Path, dict, None] = None, |
18 | 27 | chebi_version: int = 241, |
19 | 28 | resolve_inconsistencies: bool = True, |
20 | 29 | ): |
21 | 30 | # Deferred Import: To avoid circular import error |
22 | 31 | from chebifier.model_registry import MODEL_TYPES |
23 | 32 |
|
| 33 | + # Load configuration from YAML file |
| 34 | + if not model_configs: |
| 35 | + config = get_default_configs() |
| 36 | + elif isinstance(model_configs, dict): |
| 37 | + config = model_configs |
| 38 | + else: |
| 39 | + print(f"Loading ensemble configuration from {model_configs}") |
| 40 | + with open(model_configs, "r") as f: |
| 41 | + config = yaml.safe_load(f) |
| 42 | + |
| 43 | + with ( |
| 44 | + importlib.resources.files("chebifier") |
| 45 | + .joinpath("model_registry.yml") |
| 46 | + .open("r") as f |
| 47 | + ): |
| 48 | + model_registry = yaml.safe_load(f) |
| 49 | + |
| 50 | + processed_configs = process_config(config, model_registry) |
| 51 | + |
24 | 52 | self.chebi_graph = load_chebi_graph() |
25 | 53 | self.disjoint_files = get_disjoint_files() |
26 | 54 |
|
27 | 55 | self.models = [] |
28 | 56 | self.positive_prediction_threshold = 0.5 |
29 | | - for model_name, model_config in model_configs.items(): |
| 57 | + for model_name, model_config in processed_configs.items(): |
30 | 58 | model_cls = MODEL_TYPES[model_config["type"]] |
31 | 59 | if "hugging_face" in model_config: |
32 | 60 | hugging_face_kwargs = download_model_files(model_config["hugging_face"]) |
|
0 commit comments