Skip to content

Commit 540fa13

Browse files
committed
We now use the 2 bytes previously reserved for padding in the AMSMessage header to carry the message ID sent by AMSlib
Signed-off-by: Loic Pottier <[email protected]>
1 parent e965aa3 commit 540fa13

File tree

9 files changed

+166
-116
lines changed

9 files changed

+166
-116
lines changed

benchmarks/ams_bench_db.cpp

-4
Original file line numberDiff line numberDiff line change
@@ -294,10 +294,6 @@ int main(int argc, char **argv)
294294
return -1;
295295
}
296296

297-
if (dbType != AMSDBType::AMS_RMQ) {
298-
AMSConfigureFSDatabase(dbType, db_config);
299-
}
300-
301297
// -------------------------------------------------------------------------
302298
// AMS allocators setup
303299
// -------------------------------------------------------------------------

src/AMSWorkflow/ams/monitor.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def reset(cls):
206206

207207
def start_monitor(self, *args, **kwargs):
208208
self.start_time = time.time_ns()
209-
self.internal_ts = str(time.time_ns())
209+
self.internal_ts = time.time_ns()
210210

211211
def stop_monitor(self):
212212
end = time.time_ns()
@@ -233,7 +233,7 @@ def __call__(self, func: Callable):
233233
"""
234234

235235
def wrapper(*args, **kwargs):
236-
ts = str(time.time_ns())
236+
ts = time.time_ns()
237237
start = time.time_ns()
238238

239239
if self.use_arrays:
@@ -364,11 +364,11 @@ def _filter(self, data: dict, keys: List[str]) -> dict:
364364
return data
365365
return {k: v for k, v in data.items() if k in keys}
366366

367-
def _get_ts(self, class_name: str, func: str) -> str:
367+
def _get_ts(self, class_name: str, func: str) -> int:
368368
"""
369369
Return initial timestamp for a given monitored function.
370370
"""
371-
ts = str(time.time_ns())
371+
ts = time.time_ns()
372372
if class_name not in AMSMonitor._stats or func not in AMSMonitor._stats[class_name]:
373373
return ts
374374

src/AMSWorkflow/ams/rmq.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,17 @@ def __init__(self, body: str):
3434
self.domain_names = []
3535
self.input_dim = None
3636
self.output_dim = None
37+
self.message_id = None
3738

3839
def __str__(self):
3940
dt = "float" if self.dtype_byte == 4 else 8
4041
if not self.dtype_byte:
4142
dt = None
42-
return f"AMSMessage(domain={self.domain_names}, #mpi={self.mpi_rank}, num_elements={self.num_elements}, datatype={dt}, input_dim={self.input_dim}, output_dim={self.output_dim})"
43+
s = f"AMSMessage(mpi_rank={self.mpi_rank}, "
44+
s += f"message_id={self.message_id}, domain={self.domain_names}, "
45+
s += f"num_elements={self.num_elements}, datatype={dt}, "
46+
s += f"input_dim={self.input_dim}, output_dim={self.output_dim})"
47+
return s
4348

4449
def __repr__(self):
4550
return self.__str__()
@@ -55,9 +60,9 @@ def header_format(self) -> str:
5560
- 4 bytes are the number of elements in the message. Limit max: 2^32 - 1
5661
- 2 bytes are the input dimension. Limit max: 65535
5762
- 2 bytes are the output dimension. Limit max: 65535
58-
- 2 bytes are for aligning memory to 8
63+
- 2 bytes are the message ID given by AMSlib (local to each MPI rank). Limit max: 65535
5964
60-
|_Header_|_Datatype_|_Rank_|_DomainSize_|_#elems_|_InDim_|_OutDim_|_Pad_|_DomainName_|.Real_Data.|
65+
|_Header_|_Datatype_|_Rank_|_DomainSize_|_#elems_|_InDim_|_OutDim_|_MessageID_|_DomainName_|.Real_Data.|
6166
6267
Then the data starts at byte 16 with the domain name, then the real data and
6368
is structured as pairs of input/outputs. Let K be the total number of elements,
@@ -75,7 +80,7 @@ def endianness(self) -> str:
7580
"""
7681
return "="
7782

78-
def encode(self, num_elem: int, domain_name: str, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes:
83+
def encode(self, num_elem: int, domain_name: str, input_dim: int, output_dim: int, message_id: int , dtype_byte: int = 4) -> bytes:
7984
"""
8085
For debugging and testing purposes, this function encodes a message identical to what AMS would send
8186
"""
@@ -87,8 +92,7 @@ def encode(self, num_elem: int, domain_name: str, input_dim: int, output_dim: in
8792
data = np.random.rand(num_elem * (input_dim + output_dim))
8893
domain_name_size = len(domain_name)
8994
domain_name = bytes(domain_name, "utf-8")
90-
padding = 0
91-
header_content = (hsize, dtype_byte, mpi_rank, domain_name_size, data.size, input_dim, output_dim, padding)
95+
header_content = (hsize, dtype_byte, mpi_rank, domain_name_size, data.size, input_dim, output_dim, message_id)
9296
# float or double
9397
msg_format = f"{header_format}{domain_name_size}s{data.size}{dt}"
9498
return struct.pack(msg_format, *header_content, domain_name, *data)
@@ -113,7 +117,7 @@ def _parse_header(self, body: str) -> dict:
113117
res["num_element"],
114118
res["input_dim"],
115119
res["output_dim"],
116-
res["padding"],
120+
res["message_id"],
117121
) = struct.unpack(fmt, body[:hsize])
118122
assert hsize == res["hsize"]
119123
assert res["datatype"] in [4, 8]
@@ -134,6 +138,7 @@ def _parse_header(self, body: str) -> dict:
134138
self.domain_name_size = int(res["domain_size"])
135139
self.input_dim = int(res["input_dim"])
136140
self.output_dim = int(res["output_dim"])
141+
self.message_id = int(res["message_id"])
137142

138143
return res
139144

src/AMSWorkflow/ams/stage.py

+24-10
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def __init__(self, o_queue, loader, pattern):
215215
self.loader = loader
216216
self.datasize = 0
217217

218-
@AMSMonitor(record=["datasize"])
218+
@AMSMonitor(array=["msgs"], record=["datasize"])
219219
def __call__(self):
220220
"""
221221
Busy loop of reading all files matching the pattern and creating
@@ -226,6 +226,7 @@ def __call__(self):
226226
start = time.time()
227227
files = list(glob.glob(self.pattern))
228228
for fn in files:
229+
start_time_fs = time.time_ns()
229230
with self.loader(fn) as fd:
230231
domain_name, input_data, output_data = fd.load()
231232
print("Domain Name is", domain_name)
@@ -237,14 +238,28 @@ def __call__(self):
237238
for j, (i, o) in enumerate(zip(input_batches, output_batches)):
238239
self.o_queue.put(QueueMessage(MessageType.Process, DataBlob(i, o, domain_name)))
239240
self.datasize += input_data.nbytes + output_data.nbytes
241+
242+
end_time_fs = time.time_ns()
243+
msg = {
244+
"file": fn,
245+
"domain_name": domain_name,
246+
"row_size": row_size,
247+
"batch_size": BATCH_SIZE,
248+
"rows_per_batch": rows_per_batch,
249+
"num_batches": num_batches,
250+
"size_bytes": input_data.nbytes + output_data.nbytes,
251+
"process_time_ns": end_time_fs - start_time_fs,
252+
}
253+
# Msgs is the array (list) we push to (managed by AMSMonitor)
254+
msgs.append(msg)
255+
240256
print(f"Sending Delete Message Type {self.__class__.__name__}")
241257
self.o_queue.put(QueueMessage(MessageType.Delete, fn))
242258
self.o_queue.put(QueueMessage(MessageType.Terminate, None))
243259

244260
end = time.time()
245261
print(f"Spend {end - start} at {self.__class__.__name__}")
246262

247-
248263
class RMQDomainDataLoaderTask(Task):
249264
"""
250265
A RMQDomainDataLoaderTask consumes 'AMSMessages' from RabbitMQ bundles the data of
@@ -283,8 +298,6 @@ def __init__(
283298
self.orig_sig_handlers = {}
284299
self.policy = policy
285300

286-
# Counter that get incremented when we receive a message
287-
self.internal_msg_cnt = 0
288301

289302
# Signals can only be used within the main thread
290303
if self.policy != "thread":
@@ -340,12 +353,10 @@ def callback_message(self, ch, basic_deliver, properties, body):
340353

341354
self.total_time += (end_time - start_time)
342355
# TODO: Improve the code to manage potentially multiple messages per AMSMessage
343-
# TODO: Right now the ID is not encoded in the AMSMessage by AMSlib
344-
# If order of messages matters we might have to encode it
345356
msg = {
346-
"id": self.internal_msg_cnt,
347357
"delivery_tag": basic_deliver.delivery_tag,
348358
"mpi_rank": msg.mpi_rank,
359+
"message_id": msg.message_id,
349360
"domain_name": domain_name,
350361
"num_elements": msg.num_elements,
351362
"input_dim": msg.input_dim,
@@ -356,7 +367,6 @@ def callback_message(self, ch, basic_deliver, properties, body):
356367
}
357368
# Msgs is the array (list) we push to (managed by AMSMonitor)
358369
msgs.append(msg)
359-
self.internal_msg_cnt += 1
360370

361371
def signal_wrapper(self, name, pid):
362372
def handler(signum, frame):
@@ -621,11 +631,15 @@ def handler(signum, frame):
621631
# This should only trigger RMQDomainDataLoaderTask
622632

623633
# TODO: I don't like this system to shutdown the pipeline on demand
624-
# It's extremely easy to mess thing up with signals.. and it's
634+
# It's extremely easy to mess things up with signals.. and it's
625635
# not a robust solution (if a task does not handle SIGINT correctly,
626636
# the pipeline can explode)
627637
for e in self._executors:
628-
os.kill(e.pid, signal.SIGINT)
638+
if e is not None:
639+
try:
640+
os.kill(e.pid, signal.SIGINT)
641+
except Exception as e:
642+
print(f"Error: {e}")
629643
self.release_signals()
630644
return handler
631645

src/AMSWorkflow/ams_wf/AMSDBStage.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def main():
9393
print(f"End to End time spend : {end - start}")
9494

9595
if args.output_json is not None:
96-
print(f"{AMSMonitor.info()}")
96+
# print(f"{AMSMonitor.info()}")
9797
# Output profiling output to JSON (just as an example)
9898
AMSMonitor.json(args.output_json)
9999

0 commit comments

Comments
 (0)