
Commit 27d3a21

Instrument stager with new AMSMonitoring capabilities
Signed-off-by: Loic Pottier <[email protected]>
1 parent 1f79f2f commit 27d3a21

File tree

1 file changed (+57 −27 lines changed)

src/AMSWorkflow/ams/stage.py

+57 −27
@@ -10,15 +10,13 @@
 import shutil
 import signal
 import time
-from abc import ABC, abstractclassmethod, abstractmethod
+from abc import ABC, abstractmethod
 from enum import Enum
 from multiprocessing import Process
 from multiprocessing import Queue as mp_queue
 from pathlib import Path
 from queue import Queue as ser_queue
 from threading import Thread
-from typing import Callable
-import warnings

 import numpy as np
 from ams.config import AMSInstance
@@ -131,7 +129,7 @@ def __init__(self, db_path, db_store, name, i_queue, o_queue, user_obj):
         self.i_queue = i_queue
         self.o_queue = o_queue
         self.user_obj = user_obj
-        self.datasize = 0
+        self.datasize_byte = 0

     @property
     def db_path(self):
@@ -167,7 +165,7 @@ def _model_update_cb(self, db, msg):
         _updated = self.user_obj.update_model_cb(domain, model)
         print(f"Model update status: {_updated}")

-    @AMSMonitor(record=["datasize"])
+    @AMSMonitor(record=["datasize_byte"])
     def __call__(self):
         """
         A busy loop reading messages from the i_queue, acting on those messages and forwarding
@@ -186,7 +184,7 @@ def __call__(self):
                 data = item.data()
                 inputs, outputs = self._data_cb(data)
                 self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(inputs, outputs, data.domain_name)))
-                self.datasize += inputs.nbytes + outputs.nbytes
+                self.datasize_byte += inputs.nbytes + outputs.nbytes
             elif item.is_new_model():
                 data = item.data()
                 self._model_update_cb(db, data)
@@ -215,19 +213,21 @@ def __init__(self, o_queue, loader, pattern):
         self.o_queue = o_queue
         self.pattern = pattern
         self.loader = loader
-        self.datasize = 0
+        self.datasize_byte = 0
+        self.total_time_ns = 0

-    @AMSMonitor(record=["datasize"])
+    @AMSMonitor(array=["msgs"], record=["datasize_byte", "total_time_ns"])
     def __call__(self):
         """
         Busy loop of reading all files matching the pattern and creating
         '100' batches which will be pushed on the queue. Upon reading all files
         the Task pushes a 'Terminate' message to the queue and returns.
         """

-        start = time.time()
+        start = time.time_ns()
         files = list(glob.glob(self.pattern))
         for fn in files:
+            start_time_fs = time.time_ns()
             with self.loader(fn) as fd:
                 domain_name, input_data, output_data = fd.load()
                 print("Domain Name is", domain_name)
@@ -238,13 +238,29 @@ def __call__(self):
                 output_batches = np.array_split(output_data, num_batches)
                 for j, (i, o) in enumerate(zip(input_batches, output_batches)):
                     self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o, domain_name)))
-                self.datasize += input_data.nbytes + output_data.nbytes
+                self.datasize_byte += input_data.nbytes + output_data.nbytes
+
+            end_time_fs = time.time_ns()
+            msg = {
+                "file": fn,
+                "domain_name": domain_name,
+                "row_size": row_size,
+                "batch_size": BATCH_SIZE,
+                "rows_per_batch": rows_per_batch,
+                "num_batches": num_batches,
+                "size_bytes": input_data.nbytes + output_data.nbytes,
+                "process_time_ns": end_time_fs - start_time_fs,
+            }
+            # msgs is a list that is managed by AMSMonitor, we simply append to it
+            msgs.append(msg)
+
             print(f"Sending Delete Message Type {self.__class__.__name__}")
             self.o_queue.put(QueueMessage(MessageType.Delete, fn))
         self.o_queue.put(QueueMessage(MessageType.Terminate, None))

-        end = time.time()
-        print(f"Spend {end - start} at {self.__class__.__name__}")
+        end = time.time_ns()
+        self.total_time_ns += (end - start)
+        print(f"Spend {(end - start)/1e9} at {self.__class__.__name__}")


 class RMQDomainDataLoaderTask(Task):
@@ -279,8 +295,8 @@ def __init__(
         self.cert = cert
         self.rmq_queue = rmq_queue
         self.prefetch_count = prefetch_count
-        self.datasize = 0
-        self.total_time = 0
+        self.datasize_byte = 0
+        self.total_time_ns = 0
         self.signals = signals
         self.orig_sig_handlers = {}
         self.policy = policy
@@ -314,25 +330,40 @@ def callback_close(self):
         print("Adding terminate message at queue:", self.o_queue)
         self.o_queue.put(QueueMessage(MessageType.Terminate, None))

+    @AMSMonitor(array=["msgs"], record=["datasize_byte", "total_time_ns"])
     def callback_message(self, ch, basic_deliver, properties, body):
         """
         Callback that will be called each time a message will be consummed.
         the connection (or if a problem happened with the connection).
         """
-        start_time = time.time()
-        domain_name, input_data, output_data = AMSMessage(body).decode()
+        start_time = time.time_ns()
+        msg = AMSMessage(body)
+        domain_name, input_data, output_data = msg.decode()
         row_size = input_data[0, :].nbytes + output_data[0, :].nbytes
         rows_per_batch = int(np.ceil(BATCH_SIZE / row_size))
         num_batches = int(np.ceil(input_data.shape[0] / rows_per_batch))
         input_batches = np.array_split(input_data, num_batches)
         output_batches = np.array_split(output_data, num_batches)

-        self.datasize += input_data.nbytes + output_data.nbytes
+        self.datasize_byte += input_data.nbytes + output_data.nbytes

         for j, (i, o) in enumerate(zip(input_batches, output_batches)):
             self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o, domain_name)))
-
-        self.total_time += time.time() - start_time
+        end_time = time.time_ns()
+        self.total_time_ns += (end_time - start_time)
+        # TODO: Improve the code to manage potentially multiple messages per AMSMessage
+        msg = {
+            "delivery_tag": basic_deliver.delivery_tag,
+            "mpi_rank": msg.mpi_rank,
+            "domain_name": domain_name,
+            "num_elements": msg.num_elements,
+            "input_dim": msg.input_dim,
+            "output_dim": msg.output_dim,
+            "size_bytes": input_data.nbytes + output_data.nbytes,
+            "ts_received": start_time,
+            "ts_processed": end_time
+        }
+        msgs.append(msg)

     def signal_wrapper(self, name, pid):
         def handler(signum, frame):
@@ -343,9 +374,8 @@ def handler(signum, frame):

     def stop(self):
         self.rmq_consumer.stop()
-        print(f"Spend {self.total_time} at {self.__class__.__name__}")
+        print(f"Spend {self.total_time_ns/1e9} at {self.__class__.__name__}")

-    @AMSMonitor(record=["datasize", "total_time"])
     def __call__(self):
         """
         Busy loop of consuming messages from RMQ queue
@@ -356,7 +386,7 @@ def __call__(self):
             signal.signal(s, self.signal_wrapper(self.__class__.__name__, os.getpid()))
         print(f"{self.__class__.__name__} PID is:", os.getpid())
         self.rmq_consumer.run()
-        print("Returning")
+        print(f"Returning from {self.__class__.__name__}")


 class RMQControlMessageTask(RMQDomainDataLoaderTask):
@@ -385,7 +415,7 @@ def callback_message(self, ch, basic_deliver, properties, body):
         if data["request_type"] == "done-training":
             self.o_queue.put(QueueMessage(MessageType.NewModel, data))

-        self.total_time += time.time() - start_time
+        self.total_time_ns += time.time_ns() - start_time


 class FSWriteTask(Task):
@@ -410,7 +440,7 @@ def __init__(self, i_queue, o_queue, writer_cls, out_dir):
         self.o_queue = o_queue
         self.suffix = writer_cls.get_file_format_suffix()

-    @AMSMonitor(record=["datasize"])
+    @AMSMonitor(record=["datasize_byte"])
     def __call__(self):
         """
         A busy loop reading messages from the i_queue, writting the input,output data in a file
@@ -465,7 +495,7 @@ def __call__(self):
                 del data_files[data.domain_name]

         end = time.time()
-        self.datasize = total_bytes_written
+        self.datasize_byte = total_bytes_written
         print(f"Spend {end - start} {total_bytes_written} at {self.__class__.__name__}")


@@ -483,7 +513,7 @@ class PushToStore(Task):

     def __init__(self, i_queue, ams_config, db_path, store):
         """
-        Tnitializes the PushToStore Task. It reads files from i_queue, if the file
+        Initializes the PushToStore Task. It reads files from i_queue, if the file
         is not under db_path, it copies the file to this location and if store defined
         it makes the kosh-store aware about the existence of the file.
         """
@@ -779,7 +809,7 @@ def add_cli_args(parser):
         parser.set_defaults(store=True)
         return

-    @abstractclassmethod
+    @abstractmethod
     def from_cli(cls):
         pass

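For readers unfamiliar with the decorator used throughout this change, the instrumentation pattern looks roughly like the sketch below. It only mirrors the usage visible in this diff: the import path of AMSMonitor and the mechanism by which it exposes the "msgs" list inside the decorated callable are assumptions, not something this commit shows.

# Minimal sketch of the AMSMonitor usage pattern applied in this commit.
# Assumptions: the import path below and the injection of a "msgs" list into
# the decorated callable are inferred from this diff, not confirmed by it.
import time

from ams.monitor import AMSMonitor  # assumed module path


class ExampleTask:
    def __init__(self):
        # Scalar counters named in record=[...] are captured by AMSMonitor
        # when the decorated call finishes; the suffixes make units explicit.
        self.datasize_byte = 0
        self.total_time_ns = 0

    @AMSMonitor(array=["msgs"], record=["datasize_byte", "total_time_ns"])
    def __call__(self):
        start = time.time_ns()
        payload = b"example payload"
        self.datasize_byte += len(payload)
        # Per-item records are appended to "msgs", a list managed by
        # AMSMonitor (mirrors the comment added in FSLoaderTask.__call__).
        msgs.append({"size_bytes": len(payload), "ts_received": start})
        self.total_time_ns += time.time_ns() - start

The renames from datasize to datasize_byte and from total_time to total_time_ns serve the same goal: the counters recorded by AMSMonitor carry their units in their names.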
