Commit 5fa1ea5

AMSDeploy tries to kill sleep jobs, this is untested
1 parent 5d5b3eb commit 5fa1ea5

2 files changed: +57 −17 lines changed

src/AMSWorkflow/ams/stage.py

+6 −6

@@ -525,7 +525,6 @@ def __call__(self):
         while True:
             # This is a blocking call
             item = self.i_queue.get(block=True)
-            print(f"{self.__class__.__name__} Received message {total_messages}")
             total_messages += 1
             if item.is_terminate():
                 for k, v in data_files.items():
@@ -552,7 +551,6 @@ def __call__(self):
                 data_files[data.domain_name][1] += bytes_written
                 total_bytes_written += data.inputs.size * data.inputs.itemsize
                 total_bytes_written += data.outputs.size * data.outputs.itemsize
-                print(f"Received Terminate {data.inputs.shape} {data.outputs.shape}")

                 if data_files[data.domain_name][1] >= 2 * 1024 * 1024 * 1024:
                     data_files[data.domain_name][0].close()
@@ -562,6 +560,10 @@ def __call__(self):
                         )
                     )
                     del data_files[data.domain_name]
+                if total_messages % 100 == 0:
+                    print(
+                        f"I have processed {total_messages} in total amounting to {total_bytes_written/(1024.0*1024.0)} MB"
+                    )

         end = time.time()
         self.datasize_byte = total_bytes_written
@@ -609,9 +611,7 @@ def __call__(self):

         with AMSMonitor(obj=self, tag="internal_loop", record=[]):
             while True:
-                print(f"{self.__class__.__name__} Receives messages at queue:", self.i_queue)
                 item = self.i_queue.get(block=True)
-                print(f"{self.__class__.__name__} Received messages at queue:", self.i_queue)
                 if item.is_terminate():
                     print(f"Received Terminate {self.__class__.__name__}")
                     break
@@ -840,12 +840,12 @@ def execute(self, policy):
                 f"Pipeline execute does not support policy: {policy}, please select from {Pipeline.supported_policies}"
             )

-        # self.init_signals()
+        self.init_signals()
         # Create a pipeline of actions and link them with appropriate queues
         self._link_pipeline(policy)
         # Execute them
         self._execute_tasks(policy)
-        # self.release_signals()
+        self.release_signals()

     @abstractmethod
     def requires_model_update(self):
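
Aside on the stage.py change: the noisy per-message prints are removed in favor of a progress line every 100 messages, and the previously commented-out signal setup/teardown in Pipeline.execute is re-enabled. Below is a minimal, standalone sketch of the same throttled-reporting pattern; the queue, None sentinel, and byte counting are illustrative stand-ins, not the actual AMSWorkflow classes:

    import queue

    def drain(i_queue, report_every=100):
        """Consume byte payloads until a None sentinel arrives, reporting progress periodically."""
        total_messages = 0
        total_bytes_written = 0
        while True:
            item = i_queue.get(block=True)  # blocking call, as in stage.py
            if item is None:                # stand-in for item.is_terminate()
                break
            total_messages += 1
            total_bytes_written += len(item)
            if total_messages % report_every == 0:
                print(f"Processed {total_messages} messages, {total_bytes_written / (1024.0 * 1024.0):.2f} MB")
        return total_bytes_written

    # Example: feed 250 fake 1 KB payloads plus a terminate sentinel.
    q = queue.Queue()
    for _ in range(250):
        q.put(b"x" * 1024)
    q.put(None)
    drain(q)  # prints progress at 100 and 200 messages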

src/AMSWorkflow/ams_wf/AMSDeploy.py

+51 −11

@@ -1,4 +1,5 @@
 import argparse
+import os
 from ams.ams_flux import AMSFluxExecutor
 import time
 from ams.ams_jobs import nested_instance_job_descr, get_echo_job
@@ -7,6 +8,7 @@
 import warnings
 from flux.job import FluxExecutor
 import flux
+from flux.job import kill as fkill


 def verify_arg(name, uri, nodes):
@@ -27,7 +29,7 @@ def get_partition_uri(root_executor, nnodes, cores_per_node, gpus_per_node, time
     uri = fut.uri()
     nested_instance = flux.Flux(uri)
     nested_instance.rpc("state-machine.wait").get()
-    return uri
+    return uri, fut


 def main():
@@ -57,6 +59,17 @@ def main():

     args = parser.parse_args()

+    sleep_time = int(args.sleep_time)
+    if sleep_time == 0:
+        start = int(os.getenv("SLURM_JOB_START_TIME", "0"))
+        end = int(os.getenv("SLURM_JOB_END_TIME", "0"))
+        sleep_time = end - start
+
+    if sleep_time == 0:
+        print("Cannot create background job with 0 time")
+        return
+    print(f"Partitions will be allocated for {sleep_time}")
+
     wf_manager = AMSWorkflowManager.from_descr(args.workflow_descr, args.credentials)
     print(wf_manager)

@@ -83,31 +96,58 @@ def main():
     # NOTE: We need a AMSFluxExecutor to easily get flux uri because FluxExecutor does not provide the respective API
     # We set track_uri to true to enable the executor to generate futures tracking the uri of submitted jobs
     start = time.time()
-    with AMSFluxExecutor(True, threads=6, handle_args=(args.root_uri,)) as root_executor:
+    with AMSFluxExecutor(True, threads=1, handle_args=(args.root_uri,)) as root_executor:
         print("Spawning Flux executor for root took", time.time() - start)
         start = time.time()
-        domain_uri = get_partition_uri(root_executor, num_domain_nodes, cores_per_node, gpus_per_node, args.sleep_time)
+        domain_uri, domain_future = get_partition_uri(
+            root_executor, num_domain_nodes, cores_per_node, gpus_per_node, str(sleep_time)
+        )
         print("Resolving domain uri took", time.time() - start, domain_uri)
         start = time.time()
+        ml_future = None
+        stage_future = None

         if ml_uri is None:
-            ml_uri = get_partition_uri(root_executor, num_ml_nodes, cores_per_node, gpus_per_node, args.sleep_time)
+            (
+                ml_uri,
+                ml_future,
+            ) = get_partition_uri(root_executor, num_ml_nodes, cores_per_node, gpus_per_node, str(sleep_time))
             print("Resolving ML uri took", time.time() - start, ml_uri)
             start = time.time()

         if stage_uri is None:
-            stage_uri = get_partition_uri(
-                root_executor, num_stage_nodes, cores_per_node, gpus_per_node, args.sleep_time
+            stage_uri, stage_future = get_partition_uri(
+                root_executor, num_stage_nodes, cores_per_node, gpus_per_node, str(sleep_time)
             )
             print("Resolving stage uri took", time.time() - start, stage_uri)

         # 1) We first schedule the ML training orchestrator.
-        print("Here")
         wf_manager.start(ml_uri, stage_uri, domain_uri)
-        print("Done")
-
-        return
+        # The root executor should not wait; the partitions have an "infinite" allocation time, so we forcefully shut them down.
+        print("All internal executors are done ... moving to stopping root job")
+        # TODO: When I get here I need to kill all the jobs of the partitions and exit.
+        print("Stopping domain partition...")
+        domain_handle = flux.Flux(domain_uri)
+        domain_handle.rpc("state-machine.wait").get()
+        print("Cancel ", domain_future.jobid())
+        # fkill(domain_handle, domain_future.jobid())
+
+        if ml_uri:
+            print("Stopping ml partition")
+            ml_handle = flux.Flux(ml_uri)
+            ml_handle.rpc("state-machine.wait").get()
+        if stage_uri:
+            print("Stopping stager partition")
+            stage_future.cancel()
+            stage_handle = flux.Flux(stage_uri)
+            stage_handle.rpc("state-machine.wait").get()
+        print("Shutting down root executor")
+        root_executor.shutdown(wait=False, cancel_futures=True)
+        print("All daemons are down")
+        print("Exiting")
+
+        return 0


 if __name__ == "__main__":
-    main()
+    sys.exit(main())
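
The teardown block at the end of main() is what the commit message flags as untested: the domain partition's kill (fkill) is still commented out, the ML partition is only waited on, and only the stage future is cancelled. Below is a hedged sketch of one way such a partition job could be torn down with the flux-core Python bindings; stop_partition is a hypothetical helper, and it assumes the kill request must go to the instance that owns the partition job (the root instance), not to the nested instance itself:

    import signal

    import flux
    import flux.job

    def stop_partition(root_handle, partition_uri, partition_jobid):
        """Quiesce a nested partition instance, then signal its enclosing job on the root instance."""
        nested = flux.Flux(partition_uri)        # connect to the nested instance
        nested.rpc("state-machine.wait").get()   # same wait AMSDeploy.py already performs
        # flux.job.kill is the call imported as fkill in AMSDeploy.py
        flux.job.kill(root_handle, partition_jobid, signum=signal.SIGTERM)

    # Hypothetical usage inside main(), after wf_manager.start(...) returns:
    #   root_handle = flux.Flux(args.root_uri)
    #   stop_partition(root_handle, domain_uri, domain_future.jobid())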
