 import flux
 from flux.job import JobspecV1
 import flux.job as fjob
-
+import signal
 from ams.store import CreateStore, AMSDataStore

 logger = logging.getLogger(__name__)
@@ -84,6 +84,91 @@ def __init__(self, stager_job_generator, config):
         self._stager = JobSpec("ams_stager", stager_job_generator(config), exclusive=True)


+class AMSConcurrentJobScheduler(AMSJobScheduler):
+    def __init__(self, config):
+        def create_rmq_stager_job_descr(user_descr):
+            config = dict()
+
+            # TODO: This is SUPER ugly and, not to mention,
+            # potentially buggy. We will need to clean this up
+            # once we have all pieces in place (including AMSlib json initialization).
+            with open("rmq_config.json", "w") as fd:
+                json.dump(user_descr["stager"]["rmq"], fd, indent=6)
+
+            rmq_config_path = Path("rmq_config.json").resolve()
+
+            config["executable"] = sys.executable
+            config["arguments"] = [
+                "-m",
+                "ams_wf.AMSDBStage",
+                "-db",
+                user_descr["db"]["path"],
+                "--policy",
+                "process",
+                "--dest",
+                str(Path(user_descr["db"]["path"]) / Path("candidates")),
+                "--db-type",
+                "dhdf5",
+                "--store",
+                "--mechanism",
+                "network",
+                "--class",
+                user_descr["stager"]["pruner_class"],
+                "--cert",
+                user_descr["stager"]["rmq"]["rabbitmq-cert"],
+                "--creds",
+                str(rmq_config_path),
+                "--queue",
+                user_descr["stager"]["rmq"]["rabbitmq-outbound-queue"],
+                "--load",
+                user_descr["stager"]["pruner_path"],
+            ] + user_descr["stager"]["pruner_args"]
+
+            config["resources"] = {
+                "num_nodes": 1,
+                "num_processes_per_node": 1,
+                "num_tasks": 1,
+                "cores_per_task": 5,
+                "gpus_per_task": 0,
+            }
+
+            return config
+
+        super().__init__(create_rmq_stager_job_descr, config)
+
+    def execute(self):
+        def execute_and_wait(job_descr, handle):
+            result = fjob.wait(handle, jobid=job_descr.start(handle))
+            if not result.success:
+                logger.critical(f"Unsuccessful Job Execution: {job_descr.name}")
+                logger.debug(f"Error code of failed job {result.jobid} is {result.errstr}")
+                logger.debug(f"stdout is redirected to: {job_descr.stdout}")
+                logger.debug(f"stderr is redirected to: {job_descr.stderr}")
+                return False
+            return True
+
+        # We start the stager first
+        logger.debug("Start stager")
+        stager_id = self._stager.start(self._flux_handle)
+        logger.debug(f"Stager job id is {stager_id}")
+
+        logger.debug("Start user app")
+        user_app_id = self._user_app.start(self._flux_handle)
+        logger.debug(f"User App job id is {user_app_id}")
+
+        # We actively wait for the main application to terminate
+        logger.debug("Wait for user application")
+        result = fjob.wait(self._flux_handle, jobid=user_app_id)
+
+        # The stager handles SIGTERM; signal it to shut down
+        kill_status = fjob.kill_async(self._flux_handle, jobid=stager_id, signum=signal.SIGTERM)
+        logger.debug("Waiting for the stager job to be killed")
+        print(kill_status.get())
+        fjob.wait(self._flux_handle, jobid=stager_id)
+
+        return True
+
+
 class AMSSequentialJobScheduler(AMSJobScheduler):
     def __init__(self, config):
         def create_fs_stager_job_descr(user_descr):
@@ -120,7 +205,7 @@ def create_fs_stager_job_descr(user_descr):

             return config

-        super().__init__(config, create_fs_stager_job_descr)
+        super().__init__(create_fs_stager_job_descr, config)

     def execute(self):
         def execute_and_wait(job_descr, handle):
@@ -152,11 +237,10 @@ def deploy(config):
     # the server is up and running
     logger.info(f"")
     if config["execution_mode"] == "concurrent":
-        # TODO Launch concurrent execution
-        pass
+        executor = AMSConcurrentJobScheduler(config)
     elif config["execution_mode"] == "sequential":
         executor = AMSSequentialJobScheduler(config)
-        return executor.execute()
+    return executor.execute()


 def bootstrap(cmd, scheduler, flux_log):
@@ -241,10 +325,6 @@ def validate_step_field(level, config):
         if config["stager"]["mode"] == "filesystem":
             logger.critical("Database is concurrent but the stager polls data from filesystem")
             return False
-        elif config["stager"]["mode"] == "rmq":
-            if "num_clients" not in config["stager"]:
-                logger.critical("When stager set in mode 'rmq' you need to define the number of rmq clients")
-                return False

     if config["stager"]["mode"] == "rmq":
         rmq_config = config["stager"]["rmq"]
@@ -432,8 +512,6 @@ def main():
     ret = not args.func(args)
     return ret

-sys.exit(main())
-

 if __name__ == "__main__":
     sys.exit(main())
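
Note on the shutdown path added in `AMSConcurrentJobScheduler.execute()`: once the user application finishes, the scheduler sends SIGTERM to the stager through `fjob.kill_async()` and then waits on the stager job, so the stager process is expected to catch SIGTERM, drain any outstanding work, and exit cleanly. The sketch below only illustrates that pattern under those assumptions; the names (`drain_loop`, `_request_shutdown`) are hypothetical, and the real handling lives in `ams_wf.AMSDBStage`, which is not part of this diff.

```python
import queue
import signal

# Minimal, self-contained illustration of a SIGTERM-aware drain loop.
shutdown_requested = False


def _request_shutdown(signum, frame):
    # Only record the request; the main loop decides when it is safe to stop.
    global shutdown_requested
    shutdown_requested = True


def drain_loop(work):
    signal.signal(signal.SIGTERM, _request_shutdown)
    while True:
        try:
            item = work.get(timeout=1.0)
        except queue.Empty:
            item = None
        if item is not None:
            print(f"staging {item}")  # stand-in for the real staging step
        elif shutdown_requested:
            # No pending work left and SIGTERM was received: exit cleanly so a
            # final flux.job.wait() on this job reports success.
            break


if __name__ == "__main__":
    pending = queue.Queue()
    for name in ("batch-0", "batch-1"):
        pending.put(name)
    # Runs until the process receives SIGTERM (e.g. `kill -TERM <pid>`).
    drain_loop(pending)
```

Deferring the actual exit to the main loop keeps the signal handler trivial and avoids interrupting a staging step halfway through.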