Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions CosmoAPI/__main__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,24 @@
import argparse
from typing import Dict, Any

from .api_io import load_yaml_file
from .not_implemented import not_implemented_message

def gen_datavec(config, verbose=False):
def gen_datavec(config: Dict[str, Any], verbose: bool = False) -> None:
# Functionality for generating data vector
if verbose:
print("Verbose mode enabled.")
print("Generating data vector with config:", config)

def gen_covariance(config):
def gen_covariance(config: Dict[str, Any]) -> None:
# Functionality for generating covariance
print(not_implemented_message)

def forecast(config):
def forecast(config: Dict[str, Any]) -> None:
# Functionality for forecast
print(not_implemented_message)

def main():
def main() -> None:
parser = argparse.ArgumentParser(
prog="CosmoAPI",
description="CosmoAPI: Cosmology Analysis Pipeline Interface"
Expand Down Expand Up @@ -83,4 +84,4 @@ def main():
forecast(config)

if __name__ == "__main__":
main()
main()
5 changes: 3 additions & 2 deletions CosmoAPI/api_io.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import yaml
import importlib
from typing import Any, Dict

def load_yaml_file(file_path):
def load_yaml_file(file_path: str) -> Dict[str, Any]:
"""Helper function to load a YAML file"""
with open(file_path, 'r') as file:
return yaml.safe_load(file)

def load_metadata_function_class(function_name):
def load_metadata_function_class(function_name: str) -> Any:
"""
Dynamically load a class based on the 'function' name specified in the YAML file.
FIXME: Change the docstrings
Expand Down
14 changes: 7 additions & 7 deletions CosmoAPI/two_point_functions/generate_theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
from firecrown.metadata_functions import make_all_photoz_bin_combinations
import firecrown.likelihood.two_point as tp
from firecrown.utils import base_model_from_yaml

from typing import Dict, Any, List, Tuple

from .nz_loader import load_all_nz
sys.path.append("..")
from not_implemented import not_implemented_message
from api_io import load_metadata_function_class


def generate_ell_theta_array_from_yaml(yaml_data, type_key, dtype=float):
def generate_ell_theta_array_from_yaml(yaml_data: Dict[str, Any], type_key: str, dtype: type = float) -> np.ndarray:
"""
Generate a linear or logarithmic array based on the configuration in the YAML data.

Expand All @@ -39,7 +39,7 @@ def generate_ell_theta_array_from_yaml(yaml_data, type_key, dtype=float):
else:
raise ValueError(f"Unknown array type: {array_type}")

def load_systematics_factory(probe_systematics):
def load_systematics_factory(probe_systematics: Dict[str, Any]) -> Any:
"""
Dynamically load a class based on the systematics 'type' specified in the YAML file.

Expand Down Expand Up @@ -87,7 +87,7 @@ def load_systematics_factory(probe_systematics):
except AttributeError as e:
raise AttributeError(f"Class '{systematics_type}' not found in module {module_path}: {e}")

def process_probes_load_2pt(yaml_data):
def process_probes_load_2pt(yaml_data: Dict[str, Any]) -> Tuple[Any, List[str]]:
"""
Process the probes from the YAML data, check if 'function'
is the same across probes with 'nz_type',
Expand Down Expand Up @@ -138,8 +138,8 @@ def process_probes_load_2pt(yaml_data):

return loaded_function, nz_type_probes

def generate_two_point_metadata(yaml_data, two_point_function, two_pt_probes,
two_point_bins):
def generate_two_point_metadata(yaml_data: Dict[str, Any], two_point_function: Any, two_pt_probes: List[str],
two_point_bins: List[Any]) -> List[Any]:
"""
Generate the metadata for the two-point functions based on the YAML data.

Expand Down Expand Up @@ -181,7 +181,7 @@ def generate_two_point_metadata(yaml_data, two_point_function, two_pt_probes,
raise ValueError("Unknown TwoPointFunction type")
return all_two_point_metadata

def prepare_2pt_functions(yaml_data):
def prepare_2pt_functions(yaml_data: Dict[str, Any]) -> Tuple[Any, List[Any]]:
# here we call this X because we do not know if it is ell_bins or theta_bins
two_point_function, two_pt_probes = process_probes_load_2pt(yaml_data)

Expand Down
9 changes: 5 additions & 4 deletions CosmoAPI/two_point_functions/nz_loader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import importlib
import sys
from typing import Dict, List, Any, Type
sys.path.append("..")
from not_implemented import not_implemented_message

_DESC_SCENARIOS = {"LSST_Y10_SOURCE_BIN_COLLECTION", "LSST_Y10_LENS_BIN_COLLECTION",
"LSST_Y1_LENS_BIN_COLLECTION", "LSST_Y1_SOURCE_BIN_COLLECTION",}

def _load_nz(yaml_data):
def _load_nz(yaml_data: Dict[str, Any]) -> List[Any]:
try:
nz_type = yaml_data["nz_type"]
except KeyError:
Expand All @@ -17,15 +18,15 @@ def _load_nz(yaml_data):
else:
raise NotImplementedError(not_implemented_message)

def load_all_nz(yaml_data):
def load_all_nz(yaml_data: Dict[str, Any]) -> List[Any]:
nzs = []
for probe, propr in yaml_data['probes'].items():
if 'nz_type' in propr:
#print(propr['nz_type'])
nzs += _load_nz(propr)
return nzs

def _load_nz_from_module(nz_type):
def _load_nz_from_module(nz_type: str) -> Type:
# Define the module path
module_path = "firecrown.generators.inferred_galaxy_zdist"

Expand All @@ -37,4 +38,4 @@ def _load_nz_from_module(nz_type):
except ImportError as e:
raise ImportError(f"Failed to import module {module_path}: {e}")
except AttributeError as e:
raise AttributeError(f"'{nz_type}' not found in module {module_path}: {e}")
raise AttributeError(f"'{nz_type}' not found in module {module_path}: {e}")