10
10
import shutil
11
11
import signal
12
12
import time
13
- from abc import ABC , abstractclassmethod , abstractmethod
13
+ from abc import ABC , abstractmethod
14
14
from enum import Enum
15
15
from multiprocessing import Process
16
16
from multiprocessing import Queue as mp_queue
17
17
from pathlib import Path
18
18
from queue import Queue as ser_queue
19
19
from threading import Thread
20
- from typing import Callable
21
- import warnings
22
20
23
21
import numpy as np
24
22
from ams .config import AMSInstance
@@ -131,7 +129,7 @@ def __init__(self, db_path, db_store, name, i_queue, o_queue, user_obj):
131
129
self .i_queue = i_queue
132
130
self .o_queue = o_queue
133
131
self .user_obj = user_obj
134
- self .datasize = 0
132
+ self .datasize_byte = 0
135
133
136
134
@property
137
135
def db_path (self ):
@@ -167,7 +165,7 @@ def _model_update_cb(self, db, msg):
167
165
_updated = self .user_obj .update_model_cb (domain , model )
168
166
print (f"Model update status: { _updated } " )
169
167
170
- @AMSMonitor (record = ["datasize " ])
168
+ @AMSMonitor (record = ["datasize_byte " ])
171
169
def __call__ (self ):
172
170
"""
173
171
A busy loop reading messages from the i_queue, acting on those messages and forwarding
@@ -186,7 +184,7 @@ def __call__(self):
186
184
data = item .data ()
187
185
inputs , outputs = self ._data_cb (data )
188
186
self .o_queue .put (QueueMessage (MessageType .Process , DataBlob (inputs , outputs , data .domain_name )))
189
- self .datasize += inputs .nbytes + outputs .nbytes
187
+ self .datasize_byte += inputs .nbytes + outputs .nbytes
190
188
elif item .is_new_model ():
191
189
data = item .data ()
192
190
self ._model_update_cb (db , data )
@@ -215,19 +213,21 @@ def __init__(self, o_queue, loader, pattern):
215
213
self .o_queue = o_queue
216
214
self .pattern = pattern
217
215
self .loader = loader
218
- self .datasize = 0
216
+ self .datasize_byte = 0
217
+ self .total_time_ns = 0
219
218
220
- @AMSMonitor (record = ["datasize " ])
219
+ @AMSMonitor (array = [ "msgs" ], record = ["datasize_byte" , "total_time_ns " ])
221
220
def __call__ (self ):
222
221
"""
223
222
Busy loop of reading all files matching the pattern and creating
224
223
'100' batches which will be pushed on the queue. Upon reading all files
225
224
the Task pushes a 'Terminate' message to the queue and returns.
226
225
"""
227
226
228
- start = time .time ()
227
+ start = time .time_ns ()
229
228
files = list (glob .glob (self .pattern ))
230
229
for fn in files :
230
+ start_time_fs = time .time_ns ()
231
231
with self .loader (fn ) as fd :
232
232
domain_name , input_data , output_data = fd .load ()
233
233
print ("Domain Name is" , domain_name )
@@ -238,13 +238,29 @@ def __call__(self):
238
238
output_batches = np .array_split (output_data , num_batches )
239
239
for j , (i , o ) in enumerate (zip (input_batches , output_batches )):
240
240
self .o_queue .put (QueueMessage (MessageType .Process , DataBlob (i , o , domain_name )))
241
- self .datasize += input_data .nbytes + output_data .nbytes
241
+ self .datasize_byte += input_data .nbytes + output_data .nbytes
242
+
243
+ end_time_fs = time .time_ns ()
244
+ msg = {
245
+ "file" : fn ,
246
+ "domain_name" : domain_name ,
247
+ "row_size" : row_size ,
248
+ "batch_size" : BATCH_SIZE ,
249
+ "rows_per_batch" : rows_per_batch ,
250
+ "num_batches" : num_batches ,
251
+ "size_bytes" : input_data .nbytes + output_data .nbytes ,
252
+ "process_time_ns" : end_time_fs - start_time_fs ,
253
+ }
254
+ # msgs is a list that is managed by AMSMonitor, we simply append to it
255
+ msgs .append (msg )
256
+
242
257
print (f"Sending Delete Message Type { self .__class__ .__name__ } " )
243
258
self .o_queue .put (QueueMessage (MessageType .Delete , fn ))
244
259
self .o_queue .put (QueueMessage (MessageType .Terminate , None ))
245
260
246
- end = time .time ()
247
- print (f"Spend { end - start } at { self .__class__ .__name__ } " )
261
+ end = time .time_ns ()
262
+ self .total_time_ns += (end - start )
263
+ print (f"Spend { (end - start )/ 1e9 } at { self .__class__ .__name__ } " )
248
264
249
265
250
266
class RMQDomainDataLoaderTask (Task ):
@@ -279,8 +295,8 @@ def __init__(
279
295
self .cert = cert
280
296
self .rmq_queue = rmq_queue
281
297
self .prefetch_count = prefetch_count
282
- self .datasize = 0
283
- self .total_time = 0
298
+ self .datasize_byte = 0
299
+ self .total_time_ns = 0
284
300
self .signals = signals
285
301
self .orig_sig_handlers = {}
286
302
self .policy = policy
@@ -314,25 +330,40 @@ def callback_close(self):
314
330
print ("Adding terminate message at queue:" , self .o_queue )
315
331
self .o_queue .put (QueueMessage (MessageType .Terminate , None ))
316
332
333
+ @AMSMonitor (array = ["msgs" ], record = ["datasize_byte" , "total_time_ns" ])
317
334
def callback_message (self , ch , basic_deliver , properties , body ):
318
335
"""
319
336
Callback that will be called each time a message will be consummed.
320
337
the connection (or if a problem happened with the connection).
321
338
"""
322
- start_time = time .time ()
323
- domain_name , input_data , output_data = AMSMessage (body ).decode ()
339
+ start_time = time .time_ns ()
340
+ msg = AMSMessage (body )
341
+ domain_name , input_data , output_data = msg .decode ()
324
342
row_size = input_data [0 , :].nbytes + output_data [0 , :].nbytes
325
343
rows_per_batch = int (np .ceil (BATCH_SIZE / row_size ))
326
344
num_batches = int (np .ceil (input_data .shape [0 ] / rows_per_batch ))
327
345
input_batches = np .array_split (input_data , num_batches )
328
346
output_batches = np .array_split (output_data , num_batches )
329
347
330
- self .datasize += input_data .nbytes + output_data .nbytes
348
+ self .datasize_byte += input_data .nbytes + output_data .nbytes
331
349
332
350
for j , (i , o ) in enumerate (zip (input_batches , output_batches )):
333
351
self .o_queue .put (QueueMessage (MessageType .Process , DataBlob (i , o , domain_name )))
334
-
335
- self .total_time += time .time () - start_time
352
+ end_time = time .time_ns ()
353
+ self .total_time_ns += (end_time - start_time )
354
+ # TODO: Improve the code to manage potentially multiple messages per AMSMessage
355
+ msg = {
356
+ "delivery_tag" : basic_deliver .delivery_tag ,
357
+ "mpi_rank" : msg .mpi_rank ,
358
+ "domain_name" : domain_name ,
359
+ "num_elements" : msg .num_elements ,
360
+ "input_dim" : msg .input_dim ,
361
+ "output_dim" : msg .output_dim ,
362
+ "size_bytes" : input_data .nbytes + output_data .nbytes ,
363
+ "ts_received" : start_time ,
364
+ "ts_processed" : end_time
365
+ }
366
+ msgs .append (msg )
336
367
337
368
def signal_wrapper (self , name , pid ):
338
369
def handler (signum , frame ):
@@ -343,9 +374,8 @@ def handler(signum, frame):
343
374
344
375
def stop (self ):
345
376
self .rmq_consumer .stop ()
346
- print (f"Spend { self .total_time } at { self .__class__ .__name__ } " )
377
+ print (f"Spend { self .total_time_ns / 1e9 } at { self .__class__ .__name__ } " )
347
378
348
- @AMSMonitor (record = ["datasize" , "total_time" ])
349
379
def __call__ (self ):
350
380
"""
351
381
Busy loop of consuming messages from RMQ queue
@@ -356,7 +386,7 @@ def __call__(self):
356
386
signal .signal (s , self .signal_wrapper (self .__class__ .__name__ , os .getpid ()))
357
387
print (f"{ self .__class__ .__name__ } PID is:" , os .getpid ())
358
388
self .rmq_consumer .run ()
359
- print ("Returning" )
389
+ print (f "Returning from { self . __class__ . __name__ } " )
360
390
361
391
362
392
class RMQControlMessageTask (RMQDomainDataLoaderTask ):
@@ -385,7 +415,7 @@ def callback_message(self, ch, basic_deliver, properties, body):
385
415
if data ["request_type" ] == "done-training" :
386
416
self .o_queue .put (QueueMessage (MessageType .NewModel , data ))
387
417
388
- self .total_time += time .time () - start_time
418
+ self .total_time_ns += time .time_ns () - start_time
389
419
390
420
391
421
class FSWriteTask (Task ):
@@ -410,7 +440,7 @@ def __init__(self, i_queue, o_queue, writer_cls, out_dir):
410
440
self .o_queue = o_queue
411
441
self .suffix = writer_cls .get_file_format_suffix ()
412
442
413
- @AMSMonitor (record = ["datasize " ])
443
+ @AMSMonitor (record = ["datasize_byte " ])
414
444
def __call__ (self ):
415
445
"""
416
446
A busy loop reading messages from the i_queue, writting the input,output data in a file
@@ -465,7 +495,7 @@ def __call__(self):
465
495
del data_files [data .domain_name ]
466
496
467
497
end = time .time ()
468
- self .datasize = total_bytes_written
498
+ self .datasize_byte = total_bytes_written
469
499
print (f"Spend { end - start } { total_bytes_written } at { self .__class__ .__name__ } " )
470
500
471
501
@@ -483,7 +513,7 @@ class PushToStore(Task):
483
513
484
514
def __init__ (self , i_queue , ams_config , db_path , store ):
485
515
"""
486
- Tnitializes the PushToStore Task. It reads files from i_queue, if the file
516
+ Initializes the PushToStore Task. It reads files from i_queue, if the file
487
517
is not under db_path, it copies the file to this location and if store defined
488
518
it makes the kosh-store aware about the existence of the file.
489
519
"""
@@ -779,7 +809,7 @@ def add_cli_args(parser):
779
809
parser .set_defaults (store = True )
780
810
return
781
811
782
- @abstractclassmethod
812
+ @abstractmethod
783
813
def from_cli (cls ):
784
814
pass
785
815
0 commit comments