Skip to content

Commit 321eb93

Browse files
committed
feat: add SlurmClusterManager support
1 parent fc32e76 commit 321eb93

File tree

3 files changed

+43
-26
lines changed

3 files changed

+43
-26
lines changed

pysr/julia_extensions.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,20 @@
22

33
from typing import Literal
44

5+
from .julia_helpers import KNOWN_CLUSTERMANAGER_BACKENDS
56
from .julia_import import Pkg, jl
67
from .julia_registry_helpers import try_with_registry_fallback
78
from .logger_specs import AbstractLoggerSpec, TensorBoardLoggerSpec
89

10+
# Julia package name -> UUID string for every optional Julia extension
# referenced in this module. Insertion order is the order in which
# `load_all_packages` iterates over the entries.
PACKAGE_UUIDS = {
    # Extensions selected by the `turbo`, `bumper` and `autodiff_backend` options.
    "LoopVectorization": "bdcacae8-1622-11e9-2a5c-532679323890",
    "Bumper": "8ce10254-0962-460f-a3d8-1f77fea1446e",
    "Zygote": "e88e6eb3-aa80-5325-afca-941959d7151f",
    # Cluster-manager packages ("slurm_native" vs. the ClusterManagers backends).
    "SlurmClusterManager": "c82cd089-7bf7-41d7-976b-6b5d413cbe0a",
    "ClusterManagers": "34f1f09b-3a8b-5176-ab39-66d58a4d544e",
    # Loaded when the logger spec is a TensorBoardLoggerSpec.
    "TensorBoardLogger": "899adc3e-224a-11e9-021f-63837185c80f",
}
18+
919

1020
def load_required_packages(
1121
*,
@@ -16,26 +26,24 @@ def load_required_packages(
1626
logger_spec: AbstractLoggerSpec | None = None,
1727
):
1828
if turbo:
19-
load_package("LoopVectorization", "bdcacae8-1622-11e9-2a5c-532679323890")
29+
load_package("LoopVectorization")
2030
if bumper:
21-
load_package("Bumper", "8ce10254-0962-460f-a3d8-1f77fea1446e")
31+
load_package("Bumper")
2232
if autodiff_backend is not None:
23-
load_package("Zygote", "e88e6eb3-aa80-5325-afca-941959d7151f")
33+
load_package("Zygote")
2434
if cluster_manager is not None:
25-
load_package("ClusterManagers", "34f1f09b-3a8b-5176-ab39-66d58a4d544e")
35+
if cluster_manager == "slurm_native":
36+
load_package("SlurmClusterManager")
37+
elif cluster_manager in KNOWN_CLUSTERMANAGER_BACKENDS:
38+
load_package("ClusterManagers")
2639
if isinstance(logger_spec, TensorBoardLoggerSpec):
27-
load_package("TensorBoardLogger", "899adc3e-224a-11e9-021f-63837185c80f")
40+
load_package("TensorBoardLogger")
2841

2942

3043
def load_all_packages():
3144
"""Install and load all Julia extensions available to PySR."""
32-
load_required_packages(
33-
turbo=True,
34-
bumper=True,
35-
autodiff_backend="Zygote",
36-
cluster_manager="slurm",
37-
logger_spec=TensorBoardLoggerSpec(log_dir="logs"),
38-
)
45+
for package_name, uuid_s in PACKAGE_UUIDS.items():
46+
load_package(package_name, uuid_s)
3947

4048

4149
# TODO: Refactor this file so we can install all packages at once using `juliapkg`,
@@ -46,7 +54,8 @@ def isinstalled(uuid_s: str):
4654
return jl.haskey(Pkg.dependencies(), jl.Base.UUID(uuid_s))
4755

4856

49-
def load_package(package_name: str, uuid_s: str) -> None:
57+
def load_package(package_name: str, uuid_s: str | None = None) -> None:
58+
uuid_s = uuid_s or PACKAGE_UUIDS[package_name]
5059
if not isinstalled(uuid_s):
5160

5261
def _add_package():

pysr/julia_helpers.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,19 @@ def _escape_filename(filename):
2929
return str_repr
3030

3131

32-
def _load_cluster_manager(cluster_manager: str):
33-
jl.seval(f"using ClusterManagers: addprocs_{cluster_manager}")
34-
return jl.seval(f"addprocs_{cluster_manager}")
32+
# Backend names for which an `addprocs_<name>` symbol is imported from the
# Julia package ClusterManagers.
KNOWN_CLUSTERMANAGER_BACKENDS = [
    "slurm",
    "pbs",
    "lsf",
    "sge",
    "qrsh",
    "scyld",
    "htc",
]


def load_cluster_manager(cluster_manager: str) -> AnyValue:
    """Resolve a cluster-manager specifier to a Julia object.

    Parameters
    ----------
    cluster_manager : str
        One of:

        - ``"slurm_native"``: imports ``SlurmManager`` from
          ``SlurmClusterManager`` and returns it.
        - a name listed in ``KNOWN_CLUSTERMANAGER_BACKENDS``: imports
          ``addprocs_<name>`` from ``ClusterManagers`` and returns it.
        - any other string: evaluated as-is in Julia and the result is
          returned (the string is assumed to name/define a function).

    Returns
    -------
    AnyValue
        Whatever ``jl.seval`` returns for the selected expression.
    """
    if cluster_manager == "slurm_native":
        jl.seval("using SlurmClusterManager: SlurmManager")
        return jl.seval("SlurmManager")

    if cluster_manager not in KNOWN_CLUSTERMANAGER_BACKENDS:
        # Assume it's a function.
        # NOTE(review): the string is handed to `jl.seval` unmodified, i.e. it
        # is executed as Julia code; it should only come from trusted input.
        return jl.seval(cluster_manager)

    # Known ClusterManagers backend: bring its `addprocs_*` symbol into scope
    # and hand it back.
    addprocs_name = f"addprocs_{cluster_manager}"
    jl.seval(f"using ClusterManagers: {addprocs_name}")
    return jl.seval(addprocs_name)
3545

3646

3747
def jl_array(x, dtype=None):

pysr/sr.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,11 @@
4141
from .julia_extensions import load_required_packages
4242
from .julia_helpers import (
4343
_escape_filename,
44-
_load_cluster_manager,
4544
jl_array,
4645
jl_deserialize,
4746
jl_is_function,
4847
jl_serialize,
48+
load_cluster_manager,
4949
)
5050
from .julia_import import AnyValue, SymbolicRegression, VectorValue, jl
5151
from .logger_specs import AbstractLoggerSpec
@@ -549,8 +549,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
549549
Default is `None`.
550550
cluster_manager : str
551551
For distributed computing, this sets the job queue system. Set
552-
to one of "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", or
553-
"htc". If set to one of these, PySR will run in distributed
552+
to one of "slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld",
553+
or "htc". If set to one of these, PySR will run in distributed
554554
mode, and use `procs` to figure out how many processes to launch.
555555
Default is `None`.
556556
heap_size_hint_in_bytes : int
@@ -849,13 +849,11 @@ def __init__(
849849
probability_negate_constant: float = 0.00743,
850850
tournament_selection_n: int = 15,
851851
tournament_selection_p: float = 0.982,
852-
parallelism: (
853-
Literal["serial", "multithreading", "multiprocessing"] | None
854-
) = None,
852+
# fmt: off
853+
parallelism: Literal["serial", "multithreading", "multiprocessing"] | None = None,
855854
procs: int | None = None,
856-
cluster_manager: (
857-
Literal["slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | None
858-
) = None,
855+
cluster_manager: Literal["slurm_native", "slurm", "pbs", "lsf", "sge", "qrsh", "scyld", "htc"] | str | None = None,
856+
# fmt: on
859857
heap_size_hint_in_bytes: int | None = None,
860858
batching: bool = False,
861859
batch_size: int = 50,
@@ -1842,7 +1840,7 @@ def _run(
18421840
raise ValueError(
18431841
"To use cluster managers, you must set `parallelism='multiprocessing'`."
18441842
)
1845-
cluster_manager = _load_cluster_manager(cluster_manager)
1843+
cluster_manager = load_cluster_manager(cluster_manager)
18461844

18471845
# TODO(mcranmer): These functions should be part of this class.
18481846
binary_operators, unary_operators = _maybe_create_inline_operators(

0 commit comments

Comments
 (0)