Skip to content

Commit 41687c2

Browse files
authored
Add AMSMonitor interface and unify both RMQ API (#32) (#62)
Signed-off-by: Loic Pottier <[email protected]>
1 parent a102b31 commit 41687c2

File tree

7 files changed

+1006
-621
lines changed

7 files changed

+1006
-621
lines changed

pyproject.toml

+8-3
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,18 @@ exclude = [
7777
# E226: Missing white space around arithmetic operator
7878

7979
[tool.ruff]
80-
ignore = ["E501", "W503", "E226", "BLK100", "E203"]
80+
lint.ignore = ["E501", "E226", "E203"]
8181
show-fixes = true
82-
82+
exclude = [
83+
".git",
84+
"__pycache__",
85+
"*.egg-info",
86+
"build"
87+
]
8388
# change the default line length number or characters.
8489
line-length = 120
90+
lint.select = ['E', 'F', 'W', 'A', 'PLC', 'PLE', 'PLW', 'I', 'N', 'Q']
8591

8692
[tool.yapf]
8793
ignore = ["E501", "W503", "E226", "BLK100", "E203"]
8894
column_limit = 120
89-

src/AMSWorkflow/ams/monitor.py

+312
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
3+
# AMSLib Project Developers
4+
#
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import datetime
8+
import json
9+
import logging
10+
import multiprocessing
11+
import threading
12+
import time
13+
from typing import Callable, List, Union
14+
15+
16+
class AMSMonitor:
17+
"""
18+
AMSMonitor can be used to decorate class methods and will
19+
record automatically the duration of the tasks in a hashmap
20+
with timestamp. The decorator will also automatically
21+
record the values of all attributes of the class.
22+
23+
class ExampleTask1(Task):
24+
def __init__(self):
25+
self.total_bytes = 0
26+
self.total_bytes2 = 0
27+
28+
# @AMSMonitor() would record all attributes
29+
# (total_bytes and total_bytes2) and the duration
30+
# of the block under the name amsmonitor_duration.
31+
# Each time the same block (function in class or
32+
# predefined tag) is being monitored, AMSMonitor
33+
# create a new record with a timestamp (see below).
34+
#
35+
# @AMSMonitor(accumulate=True) records also all
36+
# attributes but does not create a new record each
37+
# time that block is being monitored, the first
38+
# timestamp is always being used and only
39+
# amsmonitor_duration is being accumulated.
40+
# The user-managed attributes (like total_bytes
41+
# and total_bytes2 ) are not being accumulated.
42+
# By default, accumulate=False.
43+
44+
# Example: we do not want to record total_bytes
45+
# but just total_bytes2
46+
@AMSMonitor(record=["total_bytes2"])
47+
def __call__(self):
48+
i = 0
49+
# Here we have to manually provide the current object being monitored
50+
with AMSMonitor(obj=self, tag="while_loop"):
51+
while (i<=3):
52+
self.total_bytes += 10
53+
self.total_bytes2 = 1
54+
i += 1
55+
56+
Each time `ExampleTask1()` is being called, AMSMonitor will
57+
populate `_stats` as follows (showed with two calls here):
58+
{
59+
"ExampleTask1": {
60+
"while_loop": {
61+
"02/29/2024-19:27:53": {
62+
"total_bytes2": 30,
63+
"amsmonitor_duration": 4.004607439041138
64+
}
65+
},
66+
"__call__": {
67+
"02/29/2024-19:29:24": {
68+
"total_bytes2": 30,
69+
"amsmonitor_duration": 4.10461138
70+
}
71+
}
72+
}
73+
}
74+
75+
Attributes:
76+
record: attributes to record, if None, all attributes
77+
will be recorded, except objects (e.g., multiprocessing.Queue)
78+
which can cause problem. if empty ([]), no attributes will
79+
be recorded, only amsmonitor_duration will be recorded.
80+
accumulate: If True, AMSMonitor will accumulate recorded
81+
data instead of recording a new timestamp for
82+
any subsequent call of AMSMonitor on the same method.
83+
We accumulate only records managed by AMSMonitor, like
84+
amsmonitor_duration. We do not accumulate records
85+
from the monitored class/function.
86+
obj: Mandatory if using `with` statement, `object` is
87+
the main object should be provided (i.e., self).
88+
tag: Mandatory if using `with` statement, `tag` is the
89+
name that will appear in the record for that
90+
context manager statement.
91+
"""
92+
93+
_manager = multiprocessing.Manager()
94+
_stats = _manager.dict()
95+
_ts_format = "%m/%d/%Y-%H:%M:%S"
96+
_reserved_keys = ["amsmonitor_duration"]
97+
_lock = threading.Lock()
98+
_count = 0
99+
100+
def __init__(self, record=None, accumulate=False, obj=None, tag=None, logger: logging.Logger = None, **kwargs):
101+
self.accumulate = accumulate
102+
self.kwargs = kwargs
103+
self.record = record
104+
if not isinstance(record, list):
105+
self.record = None
106+
# We make sure we do not overwrite protected attributes managed by AMSMonitor
107+
if self.record:
108+
self.record = self._remove_reserved_keys(self.record)
109+
self.object = obj
110+
self.start_time = 0
111+
self.internal_ts = 0
112+
self.tag = tag
113+
AMSMonitor._count += 1
114+
self.logger = logger if logger else logging.getLogger(__name__)
115+
116+
def __str__(self) -> str:
117+
return AMSMonitor.info() if AMSMonitor._stats != {} else "{}"
118+
119+
def __repr__(self) -> str:
120+
return self.__str__()
121+
122+
def lock(self):
123+
AMSMonitor._lock.acquire()
124+
125+
def unlock(self):
126+
AMSMonitor._lock.release()
127+
128+
def __enter__(self):
129+
if not self.object or not self.tag:
130+
self.logger.error('missing parameter "object" or "tag" when using context manager syntax')
131+
return
132+
self.start_monitor()
133+
return self
134+
135+
def __exit__(self, exc_type, exc_val, exc_tb):
136+
self.stop_monitor()
137+
138+
@classmethod
139+
def info(cls) -> str:
140+
s = ""
141+
if cls._stats == {}:
142+
return "{}"
143+
for k, v in cls._stats.items():
144+
s += f"{k}\n"
145+
for i, j in v.items():
146+
s += f" {i}\n"
147+
for p, z in j.items():
148+
s += f" {p:<10}\n"
149+
for r, q in z.items():
150+
s += f" {r:<30} => {q}\n"
151+
return s.rstrip()
152+
153+
@classmethod
154+
@property
155+
def stats(cls):
156+
return AMSMonitor._stats
157+
158+
@classmethod
159+
@property
160+
def format_ts(cls):
161+
return AMSMonitor._ts_format
162+
163+
@classmethod
164+
def convert_ts(cls, ts: str) -> datetime.datetime:
165+
return datetime.strptime(ts, cls.format_ts)
166+
167+
@classmethod
168+
def json(cls, json_output: str):
169+
"""
170+
Write the collected metrics to a JSON file.
171+
"""
172+
with open(json_output, "w") as fp:
173+
# we have to use .copy() as DictProxy is not serializable
174+
json.dump(cls._stats.copy(), fp, indent=4)
175+
# To avoid partial line at the end of the file
176+
fp.write("\n")
177+
178+
def start_monitor(self, *args, **kwargs):
179+
self.start_time = time.time()
180+
self.internal_ts = datetime.datetime.now().strftime(self._ts_format)
181+
182+
def stop_monitor(self):
183+
end = time.time()
184+
class_name = self.object.__class__.__name__
185+
func_name = self.tag
186+
187+
new_data = vars(self.object)
188+
# Filter out multiprocessing which cannot be stored without causing RuntimeError
189+
new_data = self._filter_out_object(new_data)
190+
# We remove stuff we do not want (attribute of the calling class captured by vars())
191+
if self.record != []:
192+
new_data = self._filter(new_data, self.record)
193+
# We inject some data we want to record
194+
new_data["amsmonitor_duration"] = end - self.start_time
195+
self._update_db(new_data, class_name, func_name, self.internal_ts)
196+
197+
# We reinitialize some variables
198+
self.start_time = 0
199+
self.internal_ts = 0
200+
201+
def __call__(self, func: Callable):
202+
"""
203+
The main decorator.
204+
"""
205+
206+
def wrapper(*args, **kwargs):
207+
ts = datetime.datetime.now().strftime(self._ts_format)
208+
start = time.time()
209+
value = func(*args, **kwargs)
210+
end = time.time()
211+
if not hasattr(args[0], "__dict__"):
212+
return value
213+
class_name = args[0].__class__.__name__
214+
func_name = self.tag if self.tag else func.__name__
215+
new_data = vars(args[0])
216+
217+
# Filter out multiprocessing which cannot be stored without causing RuntimeError
218+
new_data = self._filter_out_object(new_data)
219+
220+
# We remove stuff we do not want (attribute of the calling class captured by vars())
221+
new_data = self._filter(new_data, self.record)
222+
new_data["amsmonitor_duration"] = end - start
223+
self._update_db(new_data, class_name, func_name, ts)
224+
return value
225+
226+
return wrapper
227+
228+
def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
229+
"""
230+
This function update the hashmap containing all the records.
231+
"""
232+
self.lock()
233+
if class_name not in AMSMonitor._stats:
234+
AMSMonitor._stats[class_name] = {}
235+
236+
if func_name not in AMSMonitor._stats[class_name]:
237+
temp = AMSMonitor._stats[class_name]
238+
temp.update({func_name: {}})
239+
AMSMonitor._stats[class_name] = temp
240+
temp = AMSMonitor._stats[class_name]
241+
242+
# We accumulate for each class with a different name
243+
if self.accumulate and temp[func_name] != {}:
244+
ts = self._get_ts(class_name, func_name)
245+
temp[func_name][ts] = self._acc(temp[func_name][ts], new_data)
246+
else:
247+
temp[func_name][ts] = {}
248+
for k, v in new_data.items():
249+
temp[func_name][ts][k] = v
250+
# This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory)
251+
AMSMonitor._stats[class_name] = temp
252+
self.unlock()
253+
254+
def _remove_reserved_keys(self, d: Union[dict, List]) -> dict:
255+
for key in self._reserved_keys:
256+
if key in d:
257+
self.logger.warning(f"attribute {key} is protected and will be ignored ({d})")
258+
if isinstance(d, list):
259+
idx = d.index(key)
260+
d.pop(idx)
261+
elif isinstance(d, dict):
262+
del d[key]
263+
return d
264+
265+
def _acc(self, original: dict, new_data: dict) -> dict:
266+
"""
267+
Sum up element-wise two hashmaps (ignore fields that are not common)
268+
"""
269+
for k, v in new_data.items():
270+
# We accumalate variable internally managed by AMSMonitor (duration etc)
271+
if k in AMSMonitor._reserved_keys:
272+
original[k] = float(original[k]) + float(v)
273+
else:
274+
original[k] = v
275+
return original
276+
277+
def _filter_out_object(self, data: dict) -> dict:
278+
"""
279+
Filter out a hashmap to remove objects which can cause errors
280+
"""
281+
282+
def is_serializable(x):
283+
try:
284+
json.dumps(x)
285+
return True
286+
except (TypeError, OverflowError):
287+
return False
288+
289+
new_dict = {k: v for k, v in data.items() if is_serializable(v)}
290+
291+
return new_dict
292+
293+
def _filter(self, data: dict, keys: List[str]) -> dict:
294+
"""
295+
Filter out a hashmap to contains only keys listed by list of keys
296+
"""
297+
if not self.record:
298+
return data
299+
return {k: v for k, v in data.items() if k in keys}
300+
301+
def _get_ts(self, class_name: str, tag: str) -> str:
302+
"""
303+
Return initial timestamp for a given monitored function.
304+
"""
305+
ts = datetime.datetime.now().strftime(self._ts_format)
306+
if class_name not in AMSMonitor._stats or tag not in AMSMonitor._stats[class_name]:
307+
return ts
308+
309+
init_ts = list(AMSMonitor._stats[class_name][tag].keys())
310+
if len(init_ts) > 1:
311+
self.logger.warning(f"more than 1 timestamp detected for {class_name} / {tag}")
312+
return ts if init_ts == [] else init_ts[0]

0 commit comments

Comments
 (0)