 import argparse
+import os
+import sys  # assumed needed for sys.exit(main()) below; drop if already imported in elided lines
 from ams.ams_flux import AMSFluxExecutor
 import time
 from ams.ams_jobs import nested_instance_job_descr, get_echo_job
 import warnings
 from flux.job import FluxExecutor
 import flux
+from flux.job import kill as fkill


 def verify_arg(name, uri, nodes):
@@ -27,7 +29,7 @@ def get_partition_uri(root_executor, nnodes, cores_per_node, gpus_per_node, time
     uri = fut.uri()
     nested_instance = flux.Flux(uri)
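+    # The state-machine.wait RPC should return once the nested broker reaches its RUN state, i.e. the instance is ready for work.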
     nested_instance.rpc("state-machine.wait").get()
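+    # Return the submit future as well so callers can later cancel or kill the partition job.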
-    return uri
+    return uri, fut


 def main():
@@ -57,6 +59,17 @@ def main():

     args = parser.parse_args()

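+    # If --sleep-time is 0, fall back to the enclosing SLURM allocation window.
+    # (Assumes SLURM_JOB_START_TIME / SLURM_JOB_END_TIME are exported as epoch seconds, as recent Slurm versions do.)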
+    sleep_time = int(args.sleep_time)
+    if sleep_time == 0:
+        start = int(os.getenv("SLURM_JOB_START_TIME", "0"))
+        end = int(os.getenv("SLURM_JOB_END_TIME", "0"))
+        sleep_time = end - start
+
+    if sleep_time == 0:
+        print("Cannot create background job with 0 time")
+        return 1
+    print(f"Partitions will be allocated for {sleep_time} seconds")
+
     wf_manager = AMSWorkflowManager.from_descr(args.workflow_descr, args.credentials)
     print(wf_manager)

@@ -83,31 +96,58 @@ def main():
     # NOTE: We need an AMSFluxExecutor to easily get the flux uri because FluxExecutor does not provide the respective API
     # We set track_uri to true to enable the executor to generate futures tracking the uri of submitted jobs
     start = time.time()
-    with AMSFluxExecutor(True, threads=6, handle_args=(args.root_uri,)) as root_executor:
+    with AMSFluxExecutor(True, threads=1, handle_args=(args.root_uri,)) as root_executor:
         print("Spawning Flux executor for root took", time.time() - start)
         start = time.time()
-        domain_uri = get_partition_uri(root_executor, num_domain_nodes, cores_per_node, gpus_per_node, args.sleep_time)
+        domain_uri, domain_future = get_partition_uri(
+            root_executor, num_domain_nodes, cores_per_node, gpus_per_node, str(sleep_time)
+        )
         print("Resolving domain uri took", time.time() - start, domain_uri)
         start = time.time()
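+        # Futures for the partitions we launch ourselves; they stay None when a URI was supplied externally.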
+        ml_future = None
+        stage_future = None

         if ml_uri is None:
-            ml_uri = get_partition_uri(root_executor, num_ml_nodes, cores_per_node, gpus_per_node, args.sleep_time)
+            (
+                ml_uri,
+                ml_future,
+            ) = get_partition_uri(root_executor, num_ml_nodes, cores_per_node, gpus_per_node, str(sleep_time))
         print("Resolving ML uri took", time.time() - start, ml_uri)
         start = time.time()

         if stage_uri is None:
-            stage_uri = get_partition_uri(
-                root_executor, num_stage_nodes, cores_per_node, gpus_per_node, args.sleep_time
+            stage_uri, stage_future = get_partition_uri(
+                root_executor, num_stage_nodes, cores_per_node, gpus_per_node, str(sleep_time)
             )
         print("Resolving stage uri took", time.time() - start, stage_uri)

         # 1) We first schedule the ML training orchestrator.
-        print("Here")
         wf_manager.start(ml_uri, stage_uri, domain_uri)
-        print("Done")
-
-        return
+        # The root executor should not wait: the partitions have "infinite" allocation time, so we forcefully shut them down.
+        print("All internal executors are done ... moving to stopping root job")
+        # TODO: When I get here I need to kill all the jobs of the partitions and exit.
+        print("Stopping domain partition...")
+        domain_handle = flux.Flux(domain_uri)
+        domain_handle.rpc("state-machine.wait").get()
+        print("Cancel", domain_future.jobid())
+        # fkill(domain_handle, domain_future.jobid())
+
+        if ml_future is not None:
+            print("Stopping ML partition...")
+            ml_handle = flux.Flux(ml_uri)
+            ml_handle.rpc("state-machine.wait").get()
+        if stage_future is not None:
+            print("Stopping stager partition...")
+            stage_future.cancel()
+            stage_handle = flux.Flux(stage_uri)
+            stage_handle.rpc("state-machine.wait").get()
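+        # shutdown(wait=False, cancel_futures=True) returns immediately and cancels any still-pending
+        # futures rather than blocking on the long-lived partition jobs.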
+        print("Shutting down root executor")
+        root_executor.shutdown(wait=False, cancel_futures=True)
+        print("All daemons are down")
+        print("Exiting")
+
+        return 0


 if __name__ == "__main__":
-    main()
+    sys.exit(main())