Features/stager monitoring #97

Merged: 4 commits, Feb 27, 2025
170 changes: 123 additions & 47 deletions src/AMSWorkflow/ams/monitor.py
@@ -10,6 +10,7 @@
import multiprocessing
import threading
import time
from collections import deque
from typing import Callable, List, Union


@@ -27,7 +28,7 @@ def __init__(self):

# @AMSMonitor() would record all attributes
# (total_bytes and total_bytes2) and the duration
# of the block under the name amsmonitor_duration.
# of the block under the name amsmonitor_duration_ms.
# Each time the same block (function in class or
# predefined tag) is being monitored, AMSMonitor
creates a new record with a timestamp (see below).
@@ -36,7 +37,7 @@ def __init__(self):
# attributes but does not create a new record each
# time that block is being monitored, the first
# timestamp is always being used and only
# amsmonitor_duration is being accumulated.
# amsmonitor_duration_ms is being accumulated.
# The user-managed attributes (like total_bytes
# and total_bytes2 ) are not being accumulated.
# By default, accumulate=False.
@@ -53,21 +54,36 @@ def __call__(self):
self.total_bytes2 = 1
i += 1

# Example: We can also collect data at a finer grain
@AMSMonitor(array=["myarray"])
def f(self):
i = 0
while i < 2:
    myarray.append({"i": i})
    i += 1

Each time `ExampleTask1()` is called, AMSMonitor will
populate `_stats` as follows (shown with two calls here):
{
"ExampleTask1": {
"while_loop": {
"02/29/2024-19:27:53": {
"total_bytes2": 30,
"amsmonitor_duration": 4.004607439041138
"amsmonitor_duration_ms": 4.004607439041138
}
},
"__call__": {
"02/29/2024-19:29:24": {
"total_bytes2": 30,
"amsmonitor_duration": 4.10461138
"amsmonitor_duration_ms": 4.10461138
}
},
"f": {
"myarray": [
{
"i": 0,
},
{
"i": 1,
}
}
}
}
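A minimal sketch of how these records might be inspected afterwards (hedged: it assumes the decorated code above has already run; `AMSMonitor.info()` and `AMSMonitor._stats` are defined further down in this file):

    import json

    # Human-readable dump of everything AMSMonitor has collected so far.
    print(AMSMonitor.info())

    # Or take a plain-dict snapshot of the manager-backed mapping, e.g. for JSON export.
    snapshot = {cls_name: dict(entries) for cls_name, entries in AMSMonitor._stats.items()}
    print(json.dumps(snapshot, indent=2, default=str))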
@@ -76,12 +92,16 @@ def __call__(self):
record: attributes to record; if None, all attributes
will be recorded, except objects (e.g., multiprocessing.Queue)
which can cause problems. If empty ([]), no attributes will
be recorded, only amsmonitor_duration will be recorded.
be recorded, only amsmonitor_duration_ms will be recorded.
array: The user can provide variable names in which data can be accumulated over
function calls. For example, `@AMSMonitor(array=["msg"])`
makes the list `msg` available within the decorated
function to accumulate data.
accumulate: If True, AMSMonitor will accumulate recorded
data instead of recording a new timestamp for
any subsequent call of AMSMonitor on the same method.
We accumulate only records managed by AMSMonitor, like
amsmonitor_duration. We do not accumulate records
amsmonitor_duration_ms. We do not accumulate records
from the monitored class/function.
obj: Mandatory when using the `with` statement; the
main object (i.e., self) must be provided.
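A minimal sketch of the `with` form mentioned above (hedged: it assumes AMSMonitor's context-manager protocol delegates to start_monitor/stop_monitor defined below; the tag name is arbitrary):

    class ExampleTask2:
        def __init__(self):
            self.total_bytes = 0

        def __call__(self):
            # obj=self is mandatory for the `with` form; tag names the monitored block.
            with AMSMonitor(obj=self, tag="loop", record=["total_bytes"]):
                for chunk in range(3):
                    self.total_bytes += chunk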
@@ -93,11 +113,11 @@ def __call__(self):
_manager = multiprocessing.Manager()
_stats = _manager.dict()
_ts_format = "%m/%d/%Y-%H:%M:%S"
_reserved_keys = ["amsmonitor_duration"]
_reserved_keys = ["amsmonitor_duration_ms"]
_lock = threading.Lock()
_count = 0

def __init__(self, record=None, accumulate=False, obj=None, tag=None, logger: logging.Logger = None, **kwargs):
def __init__(self, record=None, array=[], accumulate=False, obj=None, tag=None, logger: logging.Logger = None, **kwargs):
self.accumulate = accumulate
self.kwargs = kwargs
self.record = record
@@ -113,6 +133,20 @@ def __init__(self, record=None, accumulate=False, obj=None, tag=None, logger: lo
AMSMonitor._count += 1
self.logger = logger if logger else logging.getLogger(__name__)

# Section to manage JSON array
if not isinstance(array, list):
array = [array]
self.array_names = array
self.variables_list = []
self.array_context = {}

for var in self.array_names:
self.variables_list.append(deque())
self.array_context[var] = self.variables_list[-1]

# Convenience flag: whether we need to support append operations to arrays
self.use_arrays = self.array_names != []

def __str__(self) -> str:
return AMSMonitor.info() if AMSMonitor._stats != {} else "{}"

@@ -142,14 +176,16 @@ def info(cls) -> str:
s = ""
if cls._stats == {}:
return "{}"
for k, v in cls._stats.items():
s += f"{k}\n"
for i, j in v.items():
s += f" {i}\n"
for p, z in j.items():
s += f" {p:<10}\n"
for r, q in z.items():
s += f" {r:<30} => {q}\n"
for class_name, func_calls in cls._stats.items():
s += f"{class_name}\n"
for func, categories in func_calls.items():
s += f" {func}\n"
for cat_name, elems in categories.items():
s += f" {cat_name:<10}\n"
for elem in elems:
for key, value in elem.items():
s += f" {key:<30} => {value}\n"
s += "\n"
return s.rstrip()

@classmethod
@@ -184,11 +220,11 @@ def reset(cls):
cls.unlock()

def start_monitor(self, *args, **kwargs):
self.start_time = time.time()
self.internal_ts = datetime.datetime.now().strftime(self._ts_format)
self.start_time = time.time_ns()
self.internal_ts = time.time_ns()

def stop_monitor(self):
end = time.time()
end = time.time_ns()
class_name = self.object.__class__.__name__
func_name = self.tag

@@ -199,7 +235,7 @@ def stop_monitor(self):
if self.record != []:
new_data = self._filter(new_data, self.record)
# We inject some data we want to record
new_data["amsmonitor_duration"] = end - self.start_time
new_data["amsmonitor_duration_ms"] = (end - self.start_time) / 1e6
self._update_db(new_data, class_name, func_name, self.internal_ts)

# We reinitialize some variables
@@ -212,22 +248,36 @@ def __call__(self, func: Callable):
"""

def wrapper(*args, **kwargs):
ts = datetime.datetime.now().strftime(self._ts_format)
start = time.time()
value = func(*args, **kwargs)
end = time.time()
ts = time.time_ns()
start = time.time_ns()

if self.use_arrays:
# Save copy of any global values that will be replaced.
saved_values = {key: func.__globals__[key] for key in self.array_context if key in func.__globals__}
func.__globals__.update(self.array_context)

try:
value = func(*args, **kwargs)
finally:
if self.use_arrays:
func.__globals__.update(saved_values) # Restore any replaced globals.

end = time.time_ns()
if not hasattr(args[0], "__dict__"):
return value
class_name = args[0].__class__.__name__
func_name = self.tag if self.tag else func.__name__
new_data = vars(args[0])

# Filter out multiprocessing which cannot be stored without causing RuntimeError
new_data = self._filter_out_object(new_data)

# We remove stuff we do not want (attribute of the calling class captured by vars())
new_data = self._filter(new_data, self.record)
new_data["amsmonitor_duration"] = end - start
if not self.use_arrays:
# new_data is a dict of values from vars(); it contains all the class attributes, etc.
new_data = vars(args[0])
# Filter out multiprocessing which cannot be stored without causing RuntimeError
new_data = self._filter_out_object(new_data)

# We remove stuff we do not want (attribute of the calling class captured by vars())
new_data = self._filter(new_data, self.record)
new_data["amsmonitor_duration_ms"] = (end - start) / 1e6
else:
new_data = self.array_context
self._update_db(new_data, class_name, func_name, ts)
return value
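For context, the array support above works by temporarily publishing each named deque into the decorated function's __globals__, so the function body can append to a name such as `myarray` without defining it. A self-contained sketch of the same pattern (hedged: the helper and names below are illustrative, not part of AMSWorkflow):

    import time
    from collections import deque

    def inject_array(func, name, buffer):
        # Run func with `name` temporarily bound to `buffer` in its globals.
        saved = {name: func.__globals__[name]} if name in func.__globals__ else {}
        func.__globals__[name] = buffer
        try:
            return func()
        finally:
            func.__globals__.pop(name, None)
            func.__globals__.update(saved)

    def work():
        for i in range(3):
            myarray.append({"i": i, "t": time.time_ns()})  # resolved via the injected global

    buf = deque()
    inject_array(work, "myarray", buf)
    print(list(buf))  # three {"i": ..., "t": ...} entries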

@@ -237,29 +287,52 @@ def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
"""
This function updates the hashmap containing all the records.
"""
if new_data == {}:
    return
AMSMonitor.lock()
if class_name not in AMSMonitor._stats:
AMSMonitor._stats[class_name] = {}

if func_name not in AMSMonitor._stats[class_name]:
temp = AMSMonitor._stats[class_name]
temp.update({func_name: {}})
temp.update({func_name: {"records" : []}})
AMSMonitor._stats[class_name] = temp
temp = AMSMonitor._stats[class_name]

# We accumulate for each class with a different name
if self.accumulate and temp[func_name] != {}:
ts = self._get_ts(class_name, func_name)
temp[func_name][ts] = self._acc(temp[func_name][ts], new_data)

# If we have to deal with arrays (i.e., array != []).
# Note that when we record arrays for this class/function
# we do not populate "records".
if self.use_arrays:
for tag in self.array_context:
# Each tag has a list of elems
while len(self.array_context[tag]) > 0:
# we remove the first elem to write it in the DB
elem = self.array_context[tag].popleft()
if tag not in temp[func_name]:
temp[func_name][tag] = []
temp[func_name][tag].append(elem)
else:
temp[func_name][ts] = {}
for k, v in new_data.items():
temp[func_name][ts][k] = v
# This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory)
# We accumulate into the same record per class instead of appending a new one
if self.accumulate and temp[func_name]["records"] != []:
    # Merge into the existing first record so its original timestamp is preserved
    temp[func_name]["records"][0] = self._acc(temp[func_name]["records"][0], new_data)
else:
item = {'timestamp': ts}
for k, v in new_data.items():
item[k] = v
temp[func_name]["records"].append(item)

# This step is needed because AMSMonitor._stats is a manager.dict (not shared memory)
# by reassigning the dictionary, the manager.dict is notified of the change
AMSMonitor._stats[class_name] = temp
AMSMonitor.unlock()
# The context deques have already been drained above (popleft), so they are ready for the next chunk.
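To make the new layout concrete, after this change a monitored class ends up shaped roughly as follows (a sketch based on the code above; the values are illustrative):

    {
        "ExampleTask1": {
            "__call__": {
                "records": [
                    {"timestamp": 1740000000000000000, "total_bytes2": 30, "amsmonitor_duration_ms": 4.1}
                ]
            },
            "f": {
                "myarray": [
                    {"i": 0},
                    {"i": 1}
                ]
            }
        }
    }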

def _remove_reserved_keys(self, d: Union[dict, List]) -> dict:
"""
Remove all the reserved keys from the dict given as input.
"""
for key in self._reserved_keys:
if key in d:
self.logger.warning(f"attribute {key} is protected and will be ignored ({d})")
@@ -306,15 +379,18 @@ def _filter(self, data: dict, keys: List[str]) -> dict:
return data
return {k: v for k, v in data.items() if k in keys}

def _get_ts(self, class_name: str, tag: str) -> str:
def _get_ts(self, class_name: str, func: str) -> int:
"""
Return initial timestamp for a given monitored function.
"""
ts = datetime.datetime.now().strftime(self._ts_format)
if class_name not in AMSMonitor._stats or tag not in AMSMonitor._stats[class_name]:
ts = time.time_ns()
if class_name not in AMSMonitor._stats or func not in AMSMonitor._stats[class_name]:
return ts

init_ts = list(AMSMonitor._stats[class_name][tag].keys())
init_ts = AMSMonitor._stats[class_name][func].get("records", [])
if len(init_ts) == 0:
return ts
if len(init_ts) > 1:
self.logger.warning(f"more than 1 timestamp detected for {class_name} / {tag}")
return ts if init_ts == [] else init_ts[0]
self.logger.warning(f"more than 1 timestamp detected for {class_name} / {func}")
return ts if init_ts == [] else init_ts[0]["timestamp"]
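Since timestamps are now raw time.time_ns() integers rather than pre-formatted strings, they can be converted back for display with the class's _ts_format, which is still defined above. A minimal sketch (assuming AMSMonitor is in scope):

    import datetime
    import time

    ts_ns = time.time_ns()
    readable = datetime.datetime.fromtimestamp(ts_ns / 1e9).strftime(AMSMonitor._ts_format)
    print(readable)  # e.g. "02/27/2025-19:27:53"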
12 changes: 6 additions & 6 deletions src/AMSWorkflow/ams/orchestrator.py
@@ -350,15 +350,15 @@ def callback_message(self, ch, basic_deliver, properties, body):
"""
Callback to be called each time the RMQ client consumes a message.
"""
start_time = time.time()
start_time = time.time_ns()
data = json.loads(body)

self.o_queue.put(QueueMessage(MessageType.Process, data))

self.num_messages += 1
self.total_time += time.time() - start_time
self.total_time_ns += time.time_ns() - start_time

@AMSMonitor(record=["total_time", "num_messages"])
@AMSMonitor(record=["total_time_ns", "num_messages"])
def __call__(self):
"""
Busy loop of consuming messages from RMQ queue
@@ -648,12 +648,12 @@ def callback_message(self, ch, basic_deliver, properties, body):
"""
Callback to be called each time the RMQ client consumes a message.
"""
start_time = time.time()
start_time = time.time_ns()
data = json.loads(body)
self.num_messages += 1
self.total_time += time.time() - start_time
self.total_time_ns += time.time_ns() - start_time

@AMSMonitor(record=["total_time", "num_messages"])
@AMSMonitor(record=["total_time_ns", "num_messages"])
def __call__(self):
"""
Busy loop of consuming messages from RMQ queue
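A minimal sketch of the pattern these callbacks follow (hedged: the import path, class name, and message bodies are illustrative, and the RMQ plumbing is elided):

    import time
    from ams.monitor import AMSMonitor  # assumed import path

    class ExampleConsumer:
        def __init__(self):
            self.num_messages = 0
            self.total_time_ns = 0

        def callback_message(self, body):
            start_time = time.time_ns()
            # ... enqueue / process the message body ...
            self.num_messages += 1
            self.total_time_ns += time.time_ns() - start_time

        @AMSMonitor(record=["total_time_ns", "num_messages"])
        def __call__(self):
            # Only the two listed attributes (plus amsmonitor_duration_ms)
            # end up in AMSMonitor._stats for this method.
            for body in (b"{}", b"{}"):
                self.callback_message(body)

    consumer = ExampleConsumer()
    consumer()
    print(AMSMonitor.info())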
13 changes: 8 additions & 5 deletions src/AMSWorkflow/ams/rmq.py
@@ -893,10 +893,11 @@ class AMSRMQConfiguration:
rabbitmq_user: str
rabbitmq_vhost: str
rabbitmq_cert: str
rabbitmq_inbound_queue: str
rabbitmq_outbound_queue: str
rabbitmq_ml_submit_queue: str
rabbitmq_ml_status_queue: str
rabbitmq_ml_submit_queue: str = ""
rabbitmq_ml_status_queue: str = ""
rabbitmq_exchange: str = "not-used"
rabbitmq_routing_key: str = ""

def __post_init__(self):
if not Path(self.rabbitmq_cert).exists():
@@ -926,7 +927,9 @@ def to_dict(self, AMSlib=False):
"rabbitmq-vhost": self.rabbitmq_vhost,
"rabbitmq-cert": self.rabbitmq_cert,
"rabbitmq-outbound-queue": self.rabbitmq_outbound_queue,
"rabbitmq-exchange": "not-used",
"rabbitmq-routing-key": "",
"rabbitmq-exchange": self.rabbitmq_exchange,
"rabbitmq-routing-key": self.rabbitmq_routing_key,
"rabbitmq-ml-submit-queue": self.rabbitmq_ml_submit_queue,
"rabbitmq-ml-status-queue": self.rabbitmq_ml_status_queue,
}
raise
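A hedged sketch of how the new defaults behave (illustrative values only; fields elided from this diff, such as the host/port/password settings, would also have to be supplied, and __post_init__ requires rabbitmq_cert to point at an existing file):

    cfg = AMSRMQConfiguration(
        # ... connection fields not shown in this diff ...
        rabbitmq_user="ams",
        rabbitmq_vhost="ams-vhost",
        rabbitmq_cert="/path/to/rmq.pem",
        rabbitmq_outbound_queue="ams-outbound",
        # rabbitmq_ml_submit_queue / rabbitmq_ml_status_queue now default to ""
        # rabbitmq_exchange defaults to "not-used", rabbitmq_routing_key to ""
    )
    print(cfg.to_dict(AMSlib=True)["rabbitmq-exchange"])  # "not-used" unless overridden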