22
33from typing import Literal
44
5+ from .julia_helpers import KNOWN_CLUSTERMANAGER_BACKENDS
56from .julia_import import Pkg , jl
67from .julia_registry_helpers import try_with_registry_fallback
78from .logger_specs import AbstractLoggerSpec , TensorBoardLoggerSpec
89
10+ PACKAGE_UUIDS = {
11+ "LoopVectorization" : "bdcacae8-1622-11e9-2a5c-532679323890" ,
12+ "Bumper" : "8ce10254-0962-460f-a3d8-1f77fea1446e" ,
13+ "Zygote" : "e88e6eb3-aa80-5325-afca-941959d7151f" ,
14+ "SlurmClusterManager" : "c82cd089-7bf7-41d7-976b-6b5d413cbe0a" ,
15+ "ClusterManagers" : "34f1f09b-3a8b-5176-ab39-66d58a4d544e" ,
16+ "TensorBoardLogger" : "899adc3e-224a-11e9-021f-63837185c80f" ,
17+ }
18+
919
1020def load_required_packages (
1121 * ,
@@ -16,26 +26,24 @@ def load_required_packages(
1626 logger_spec : AbstractLoggerSpec | None = None ,
1727):
1828 if turbo :
19- load_package ("LoopVectorization" , "bdcacae8-1622-11e9-2a5c-532679323890" )
29+ load_package ("LoopVectorization" )
2030 if bumper :
21- load_package ("Bumper" , "8ce10254-0962-460f-a3d8-1f77fea1446e" )
31+ load_package ("Bumper" )
2232 if autodiff_backend is not None :
23- load_package ("Zygote" , "e88e6eb3-aa80-5325-afca-941959d7151f" )
33+ load_package ("Zygote" )
2434 if cluster_manager is not None :
25- load_package ("ClusterManagers" , "34f1f09b-3a8b-5176-ab39-66d58a4d544e" )
35+ if cluster_manager == "slurm_native" :
36+ load_package ("SlurmClusterManager" )
37+ elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS :
38+ load_package ("ClusterManagers" )
2639 if isinstance (logger_spec , TensorBoardLoggerSpec ):
27- load_package ("TensorBoardLogger" , "899adc3e-224a-11e9-021f-63837185c80f" )
40+ load_package ("TensorBoardLogger" )
2841
2942
3043def load_all_packages ():
3144 """Install and load all Julia extensions available to PySR."""
32- load_required_packages (
33- turbo = True ,
34- bumper = True ,
35- autodiff_backend = "Zygote" ,
36- cluster_manager = "slurm" ,
37- logger_spec = TensorBoardLoggerSpec (log_dir = "logs" ),
38- )
45+ for package_name , uuid_s in PACKAGE_UUIDS .items ():
46+ load_package (package_name , uuid_s )
3947
4048
4149# TODO: Refactor this file so we can install all packages at once using `juliapkg`,
@@ -46,7 +54,8 @@ def isinstalled(uuid_s: str):
4654 return jl .haskey (Pkg .dependencies (), jl .Base .UUID (uuid_s ))
4755
4856
49- def load_package (package_name : str , uuid_s : str ) -> None :
57+ def load_package (package_name : str , uuid_s : str | None = None ) -> None :
58+ uuid_s = uuid_s or PACKAGE_UUIDS [package_name ]
5059 if not isinstalled (uuid_s ):
5160
5261 def _add_package ():
0 commit comments