Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AMS Monitoring and Benchmark #94

Closed
wants to merge 8 commits into from
2 changes: 1 addition & 1 deletion src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def _parse_data(self, body: str, header_info: dict) -> Tuple[str, np.array, np.a
# Return input, output
return (domain_name, data[:, :idim], data[:, idim:])

def _decode(self, body: str) -> Tuple[np.array]:
def _decode(self, body: str) -> Tuple[str, np.array, np.array]:
input = []
output = []
# Multiple AMS messages could be packed in one RMQ message
Expand Down
54 changes: 49 additions & 5 deletions src/AMSWorkflow/ams/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ def __init__(
self.orig_sig_handlers = {}
self.policy = policy

# Counter that gets incremented each time we receive a message
self.internal_msg_cnt = 0

# Signals can only be used within the main thread
if self.policy != "thread":
# We ignore SIGTERM, SIGUSR1, SIGINT by default so later
Expand Down Expand Up @@ -312,13 +315,17 @@ def callback_close(self):
print("Adding terminate message at queue:", self.o_queue)
self.o_queue.put(QueueMessage(MessageType.Terminate, None))


@AMSMonitor(array=["msgs"], record=["datasize", "total_time"])
def callback_message(self, ch, basic_deliver, properties, body):
"""
Callback that will be called each time a message is consumed from
the connection (or if a problem happened with the connection).
"""
start_time = time.time()
domain_name, input_data, output_data = AMSMessage(body).decode()
start_time = time.time_ns()
msg = AMSMessage(body)

domain_name, input_data, output_data = msg.decode()
row_size = input_data[0, :].nbytes + output_data[0, :].nbytes
rows_per_batch = int(np.ceil(BATCH_SIZE / row_size))
num_batches = int(np.ceil(input_data.shape[0] / rows_per_batch))
Expand All @@ -329,8 +336,27 @@ def callback_message(self, ch, basic_deliver, properties, body):

for j, (i, o) in enumerate(zip(input_batches, output_batches)):
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o, domain_name)))

self.total_time += time.time() - start_time
end_time = time.time_ns()

self.total_time += (end_time - start_time)
# TODO: Improve the code to manage potentially multiple messages per AMSMessage
# TODO: Right now the ID is not encoded in the AMSMessage by AMSlib
# If order of messages matters we might have to encode it
msg = {
"id": self.internal_msg_cnt,
"delivery_tag": basic_deliver.delivery_tag,
"mpi_rank": msg.mpi_rank,
"domain_name": domain_name,
"num_elements": msg.num_elements,
"input_dim": msg.input_dim,
"output_dim": msg.output_dim,
"size_bytes": input_data.nbytes + output_data.nbytes,
"ts_received": start_time,
"ts_processed": end_time
}
# Msgs is the array (list) we push to (managed by AMSMonitor)
msgs.append(msg)
self.internal_msg_cnt += 1

def signal_wrapper(self, name, pid):
def handler(signum, frame):
Expand All @@ -343,7 +369,6 @@ def stop(self):
self.rmq_consumer.stop()
print(f"Spend {self.total_time} at {self.__class__.__name__}")

@AMSMonitor(record=["datasize", "total_time"])
def __call__(self):
"""
Busy loop of consuming messages from RMQ queue
Expand Down Expand Up @@ -906,6 +931,25 @@ def __init__(
print("Received a data queue of", self._data_queue)
print("Received a model_update queue of", self._model_update_queue)

# FIXME: temporary solution to kill properly the stager when using srun
self.write_pid()

def write_pid(self):
    """
    Append the PID of the current process to a well-known file.

    The file is opened in append mode so that several concurrently
    running stagers can all record their PIDs in the same file.
    This exists so an operator (or a Slurm job script) can signal
    the stager directly, since `srun` wraps its target in another
    process and killing the `srun` PID does not deliver the signal
    to the stager itself.

    FIXME: this workaround is not very clean or elegant and can be
    improved. For example, the simulation side could post a message
    on a dedicated queue to announce that no more data will arrive.
    """
    pid_file = "ams-stagers.pid"
    with open(pid_file, "a") as handle:
        handle.write(f" {os.getpid()}")

Comment on lines +948 to +966
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this code. Do you need the pid to do kill -SIG <signal> PID ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, yes. I need the PID of the stager to send a signal when scheduling tasks with Slurm. That "fix" is only needed when using srun to start the stager because srun wraps its target process and creates another process with another PID. If you kill the srun PID the signal is not being captured properly. I have to write the internal PID somewhere if we want to exit cleanly

def get_load_task(self, o_queue, policy):
"""
Return a Task that loads data from the filesystem
Expand Down