
Commit f4e6fa9

Skeleton of Deploy tool
1 parent d51ebe1 commit f4e6fa9

File tree

4 files changed: +311 -4 lines changed

src/AMSWorkflow/ams/deploy_tools.py (+71)

@@ -0,0 +1,71 @@
import sys
import os
import select
import subprocess as sp
from enum import Enum


class RootSched(Enum):
    SLURM = 1
    LSF = 2


def _run_daemon(cmd, shell=False):
    print(f"Going to run {cmd}")
    proc = sp.Popen(cmd, shell=shell, stdout=sp.PIPE, stderr=sp.PIPE, bufsize=1, text=True)
    return proc


def _read_flux_uri(proc, timeout=5):
    """
    Reads the first line from the flux start command's stdout and returns it.

    :param proc: The process from which to read stdout.
    :param timeout: The maximum time (in seconds) we wait for the process to write to stdout.
    """

    # Time waited for I/O so far
    total_wait_time = 0
    poll_interval = 0.5  # Poll interval in seconds

    while total_wait_time < timeout:
        # Check if there is data to read from stdout
        ready_to_read = select.select([proc.stdout], [], [], poll_interval)[0]
        if ready_to_read:
            first_line = proc.stdout.readline()
            print("First line is", first_line)
            if "ssh" in first_line:
                return first_line
        total_wait_time += poll_interval
        print(f"Waited for {total_wait_time}")
    return None


def spawn_rmq_broker(flux_uri):
    # TODO: We need to implement this; the current specification is limited.
    # We probably need access to flux to spawn a daemon inside the flux allocation.
    raise NotImplementedError("spawn_rmq_broker is not implemented, spawn it manually and provide the credentials")
    return None, None


def start_flux(scheduler, nnodes=None):
    def bootstrap_with_slurm(nnodes):
        if nnodes is None:
            nnodes = os.environ.get("SLURM_NNODES", None)

        bootstrap_cmd = f"srun -N {nnodes} -n {nnodes} --pty --mpi=none --mpibind=off flux start"
        flux_get_uri_cmd = "flux uri --remote \\$FLUX_URI; sleep inf"

        daemon = _run_daemon(f'{bootstrap_cmd} "{flux_get_uri_cmd}"', shell=True)
        flux_uri = _read_flux_uri(daemon, timeout=10)
        print("Got flux uri: ", flux_uri)
        if flux_uri is None:
            print("Fatal error, cannot read the flux URI")
            daemon.terminate()
            raise RuntimeError("Cannot get FLUX URI")

        return daemon, flux_uri

    if scheduler == RootSched.SLURM:
        return bootstrap_with_slurm(nnodes)

    raise NotImplementedError("We only support bootstrapping through SLURM")
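
For orientation, a minimal sketch of how this skeleton is meant to be driven: bootstrap a Flux instance under SLURM and keep the returned daemon handle around for teardown. The two-node allocation and the surrounding driver script are assumptions for illustration, not part of this commit.

# Hypothetical driver (not part of this commit): bootstrap Flux under SLURM.
from ams.deploy_tools import RootSched, start_flux

daemon, flux_uri = start_flux(RootSched.SLURM, nnodes=2)  # nnodes=2 is an assumption
print(f"Flux is reachable at {flux_uri}")
# ... schedule jobs against flux_uri ...
daemon.terminate()  # stop the srun-hosted `flux start` process when done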

src/AMSWorkflow/ams/job_types.py (+163)

@@ -0,0 +1,163 @@
from dataclasses import dataclass, field
from pathlib import Path
import os
import sys
import shutil
from warnings import warn
from typing import List, Dict, Optional, ClassVar

from flux.job import JobspecV1
import flux.job as fjob

from ams.loader import load_class


@dataclass(kw_only=True)
class BaseJob:
    """
    Class modeling a job scheduled by AMS. There can be five types of jobs
    (Physics, Stagers, Training, RMQServer and TrainingDispatcher).
    """

    name: str
    executable: str
    nodes: int
    tasks_per_node: int
    args: List[str] = field(default_factory=list)
    exclusive: bool = True
    cores_per_task: int = 1
    environ: Dict[str, str] = field(default_factory=dict)
    orderId: ClassVar[int] = 0
    gpus_per_task: Optional[int] = None
    stdout: Optional[str] = None
    stderr: Optional[str] = None

    def _construct_command(self):
        command = [self.executable] + self.args
        return command

    def _construct_environ(self, forward_environ):
        environ = self.environ
        if forward_environ is not None:
            if not isinstance(forward_environ, type(os.environ)) and not isinstance(forward_environ, dict):
                raise TypeError(f"Unsupported forward_environ type ({type(forward_environ)})")
            for k, v in forward_environ.items():
                if k in environ:
                    warn(f"Key {k} already exists in environment ({environ[k]}), ignoring forwarded value ({v})")
                else:
                    environ[k] = v
        return environ

    def _construct_redirect_paths(self, redirectDir):
        stdDir = Path.cwd()
        if redirectDir is not None:
            stdDir = Path(redirectDir)

        if self.stdout is None:
            stdout = f"{stdDir}/{self.name}_{BaseJob.orderId}.out"
        else:
            stdout = f"{stdDir}/{self.stdout}_{BaseJob.orderId}.out"

        if self.stderr is None:
            stderr = f"{stdDir}/{self.name}_{BaseJob.orderId}.err"
        else:
            stderr = f"{stdDir}/{self.stderr}_{BaseJob.orderId}.err"

        BaseJob.orderId += 1

        return stdout, stderr

    def schedule(self, flux_handle, forward_environ=None, redirectDir=None, pre_signed=False, waitable=True):
        jobspec = JobspecV1.from_command(
            command=self._construct_command(),
            num_tasks=self.tasks_per_node * self.nodes,
            num_nodes=self.nodes,
            cores_per_task=self.cores_per_task,
            gpus_per_task=self.gpus_per_task,
            exclusive=self.exclusive,
        )

        stdout, stderr = self._construct_redirect_paths(redirectDir)
        environ = self._construct_environ(forward_environ)
        jobspec.environment = environ
        jobspec.stdout = stdout
        jobspec.stderr = stderr

        return jobspec, fjob.submit(flux_handle, jobspec, pre_signed=pre_signed, waitable=waitable)


@dataclass(kw_only=True)
class PhysicsJob(BaseJob):
    def _verify(self):
        is_executable = shutil.which(self.executable) is not None
        is_path = Path(self.executable).is_file()
        return is_executable or is_path

    def __post_init__(self):
        if not self._verify():
            raise RuntimeError(
                f"[PhysicsJob] executable '{self.executable}' is neither an executable file nor a system command"
            )


@dataclass(kw_only=True, init=False)
class Stager(BaseJob):
    def _get_stager_default_cores(self):
        """
        We need the following cores:
            1 RMQ client to receive messages
            1 process to store to the filesystem
            1 process to make data public to kosh
        """
        return 3

    def _verify(self, pruner_path, pruner_cls):
        assert Path(pruner_path).is_file(), "Path to Pruner class should exist"
        user_class = load_class(pruner_path, pruner_cls)
        print(f"Loaded Pruner Class {user_class.__name__}")

    def __init__(
        self,
        name: str,
        num_cores: int,
        db_path: str,
        pruner_cls: str,
        pruner_path: str,
        pruner_args: List[str],
        num_gpus: Optional[int],
        **kwargs,
    ):
        executable = sys.executable

        self._verify(pruner_path, pruner_cls)

        # TODO: Here we are accessing both the stager arguments and the pruner arguments. It is an opportunity to
        # emit an early error message. But this would require extending argparse or something else. Noting for
        # future reference.
        cli_arguments = [
            "-m",
            "ams_wf.AMSDBStage",
            "-db",
            db_path,
            "--policy",
            "process",
            "--dest",
            str(Path(db_path) / Path("candidates")),
            "--db-type",
            "dhdf5",
            "--store",
            "-m",
            "fs",
            "--class",
            pruner_cls,
        ]
        cli_arguments += pruner_args

        num_cores = self._get_stager_default_cores() + num_cores
        super().__init__(
            name=name,
            executable=executable,
            nodes=1,
            tasks_per_node=1,
            cores_per_task=num_cores,
            args=cli_arguments,
            gpus_per_task=num_gpus,
            **kwargs,
        )
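
As a usage illustration, the sketch below constructs a PhysicsJob and submits it through an open Flux handle. The executable and redirect directory are placeholders, and it assumes a Flux broker is reachable from the calling process; none of this is shipped by the commit.

# Hypothetical example (not part of this commit): /usr/bin/hostname stands in
# for a real physics binary so that PhysicsJob's executable check passes.
import flux

from ams.job_types import PhysicsJob

job = PhysicsJob(
    name="physics",
    executable="/usr/bin/hostname",
    nodes=1,
    tasks_per_node=2,
)

handle = flux.Flux()  # connect to the enclosing Flux instance
jobspec, jobid = job.schedule(handle, redirectDir="/tmp")
print(f"Submitted job {jobid}, output redirected under /tmp")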

src/AMSWorkflow/ams/rmq_async.py (+6 -4)

@@ -402,17 +402,19 @@ def stop(self):
         print("Already closed?")


-def broker_running(credentials, cacert):
+def broker_status(credentials, cacert):
+    print(credentials)
+    print(cacert)
     ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2)
     ssl_context.verify_mode = ssl.CERT_REQUIRED
     ssl_context.load_verify_locations(cacert)

     pika_credentials = pika.PlainCredentials(credentials["rabbitmq-user"], credentials["rabbitmq-password"])

     parameters = pika.ConnectionParameters(
-        host=credentials["server"],
-        port=credentials["port"],
-        virtual_host=credentials["virtual_host"],
+        host=credentials["service-host"],
+        port=credentials["service-port"],
+        virtual_host=credentials["rabbitmq-vhost"],
         credentials=pika_credentials,
         ssl_options=pika.SSLOptions(ssl_context),
     )
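
The renamed dictionary keys imply a credentials file of roughly the following shape. This is an inferred sketch with placeholder values, not a file shipped by this commit.

# Hypothetical credentials dict matching the keys broker_status now reads
# ("rabbitmq-user", "rabbitmq-password", "service-host", "service-port",
# "rabbitmq-vhost"); all values below are placeholders.
credentials = {
    "rabbitmq-user": "ams",
    "rabbitmq-password": "<secret>",
    "service-host": "rmq.example.org",
    "service-port": 5671,
    "rabbitmq-vhost": "/",
}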

src/AMSWorkflow/ams_wf/AMSDeploy.py (+71)

@@ -0,0 +1,71 @@
import argparse
import logging
import sys
import os
import json
from urllib import parse

from ams.deploy_tools import spawn_rmq_broker
from ams.deploy_tools import RootSched
from ams.deploy_tools import start_flux
from ams.rmq_async import broker_status

logger = logging.getLogger(__name__)


def get_rmq_credentials(flux_uri, rmq_creds, rmq_cert):
    if rmq_creds is None:
        # TODO: Over here we need to spawn our own server
        rmq_creds, rmq_cert = spawn_rmq_broker(flux_uri)
    with open(rmq_creds, "r") as fd:
        rmq_creds = json.load(fd)

    return rmq_creds, rmq_cert


def main():
    parser = argparse.ArgumentParser(description="AMS workflow deployment")

    parser.add_argument("--rmq-creds", help="Credentials file (JSON)")
    parser.add_argument("--rmq-cert", help="TLS certificate file")
    parser.add_argument("--flux-uri", help="Flux URI of an already existing allocation")
    parser.add_argument("--nnodes", help="Number of nodes to use for this AMS deployment")
    parser.add_argument(
        "--root-scheduler",
        dest="scheduler",
        choices=[e.name for e in RootSched],
        help="The provided scheduler of the cluster",
    )

    args = parser.parse_args()

    """
    Verify the system is in a "valid" state
    """

    if args.flux_uri is None and args.scheduler is None:
        print("Please provide either a flux URI handle to connect to or the base job scheduler")
        sys.exit()

    flux_process = None
    flux_uri = args.flux_uri
    if flux_uri is None:
        flux_process, flux_uri = start_flux(RootSched[args.scheduler], args.nnodes)

    rmq_creds, rmq_cert = get_rmq_credentials(flux_uri, args.rmq_creds, args.rmq_cert)

    if not broker_status(rmq_creds, rmq_cert):
        # If we created a subprocess in the background to run flux, we should terminate it
        if flux_process is not None:
            flux_process.terminate()
        print("RMQ broker is not connected, exiting ...")
        sys.exit()

    """
    We have a FLUX URI and we know that rmq_creds and rmq_cert are valid, so we can start
    scheduling jobs.
    """


if __name__ == "__main__":
    main()
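
Assuming the ams_wf package is installed and importable, the deploy tool would be invoked along these lines; the flag values below are placeholders.

python -m ams_wf.AMSDeploy --root-scheduler SLURM --nnodes 2 \
    --rmq-creds rmq-credentials.json --rmq-cert rmq-ca.crt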
