Skip to content

Commit aadc4a7

Browse files
committed
Add AMSMonitor interface and unify both RMQ API (#32)
Signed-off-by: Loic Pottier <[email protected]>
1 parent 3e4b410 commit aadc4a7

File tree

7 files changed

+1002
-621
lines changed

7 files changed

+1002
-621
lines changed

pyproject.toml

+8-3
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,18 @@ exclude = [
7777
# E226: Missing white space around arithmetic operator
7878

7979
[tool.ruff]
80-
ignore = ["E501", "W503", "E226", "BLK100", "E203"]
80+
lint.ignore = ["E501", "E226", "E203"]
8181
show-fixes = true
82-
82+
exclude = [
83+
".git",
84+
"__pycache__",
85+
"*.egg-info",
86+
"build"
87+
]
8388
# change the default line length number or characters.
8489
line-length = 120
90+
lint.select = ['E', 'F', 'W', 'A', 'PLC', 'PLE', 'PLW', 'I', 'N', 'Q']
8591

8692
[tool.yapf]
8793
ignore = ["E501", "W503", "E226", "BLK100", "E203"]
8894
column_limit = 120
89-

src/AMSWorkflow/ams/monitor.py

+308
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other
3+
# AMSLib Project Developers
4+
#
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import datetime
8+
import json
9+
import logging
10+
import multiprocessing
11+
import threading
12+
import time
13+
from typing import Callable, List, Union
14+
15+
16+
class _ClassProperty:
    """
    Read-only property evaluated against the class instead of an instance.

    Used in place of stacking ``@classmethod`` on ``@property``, which was
    deprecated in Python 3.11 and removed in Python 3.13.
    """

    def __init__(self, fget: Callable):
        self.fget = fget
        self.__doc__ = fget.__doc__

    def __get__(self, instance, owner=None):
        return self.fget(type(instance) if owner is None else owner)


class AMSMonitor:
    """
    AMSMonitor can be used to decorate class methods and will
    automatically record the duration of the tasks in a hashmap
    keyed by timestamp. The decorator also records the values of
    the JSON-serializable attributes of the decorated object.

    class ExampleTask1(Task):
        def __init__(self):
            self.total_bytes = 0
            self.total_bytes2 = 0

        # Example: we do not want to record total_bytes
        # but just total_bytes2
        #
        # @AMSMonitor() would record all attributes
        # (total_bytes and total_bytes2)
        #
        # @AMSMonitor(accumulate=True) would keep a single
        # timestamp entry per method: durations are summed up
        # and attributes are refreshed with their latest value

        @AMSMonitor(record=["total_bytes2"])
        def __call__(self):
            i = 0
            with AMSMonitor(obj=self, tag="while_loop", record=["total_bytes2"]):
                while (i<=3):
                    self.total_bytes += 10
                    self.total_bytes2 = 1
                    i += 1

    Each time `ExampleTask1()` is called, AMSMonitor populates
    `_stats` with entries shaped as follows:
        {
            'ExampleTask1': {
                'while_loop': {
                    '02/29/2024-19:27:53': {
                        'total_bytes2': 1,
                        'amsmonitor_duration': 4.004607439041138
                    }
                },
                '__call__': {
                    '02/29/2024-19:29:24': {
                        'total_bytes2': 1,
                        'amsmonitor_duration': 4.10461138
                    }
                }
            }
        }

    NOTE: timestamps have a one-second resolution, so two non-accumulated
    records of the same class/tag taken within the same second share a key
    and the later one replaces the earlier one.

    Attributes:
        record: attributes to record; if None or empty ([]), all
            attributes will be recorded. Anything that is not a
            list is treated as None.
        accumulate: If True, AMSMonitor keeps a single entry (under the
            first recorded timestamp) for a given class/tag instead of
            creating a new timestamp for each call. On every subsequent
            call the duration is added to the stored duration, and the
            recorded attributes are replaced by their latest values.
        obj: Mandatory if using the `with` statement; the object whose
            attributes are recorded (i.e., self). The keyword `object`
            is accepted as an alias.
        tag: Mandatory if using the `with` statement; `tag` is the
            name that will appear in the record for that
            context manager statement. With the decorator syntax it
            overrides the decorated function name.
        logger: logger to use; defaults to `logging.getLogger(__name__)`.
    """

    _manager = multiprocessing.Manager()
    _stats = _manager.dict()
    _ts_format = "%m/%d/%Y-%H:%M:%S"
    _reserved_keys = ["amsmonitor_duration"]
    # NOTE(review): this is a threading.Lock, i.e. local to one process, while
    # _stats is a manager dict; confirm whether cross-process writers exist.
    _lock = threading.Lock()
    _count = 0

    def __init__(self, record=None, accumulate=False, obj=None, tag=None, logger: logging.Logger = None, **kwargs):
        # The logger must exist before _remove_reserved_keys() runs, as it logs.
        self.logger = logger if logger else logging.getLogger(__name__)
        self.accumulate = accumulate
        self.kwargs = kwargs
        self.record = record if isinstance(record, list) else None
        # We make sure we do not overwrite protected attributes managed by AMSMonitor
        if self.record:
            self.record = self._remove_reserved_keys(self.record)
        # Earlier documentation advertised `object=`; honor it as an alias of `obj`.
        if obj is None and "object" in kwargs:
            obj = kwargs["object"]
        self.object = obj
        self.start_time = 0
        # Holds the start timestamp string while a measurement is running, 0 otherwise.
        self.internal_ts = 0
        self.tag = tag
        AMSMonitor._count += 1

    def __str__(self) -> str:
        return AMSMonitor.info()

    def __repr__(self) -> str:
        return self.__str__()

    def lock(self):
        """Acquire the class-wide lock guarding updates of the records."""
        AMSMonitor._lock.acquire()

    def unlock(self):
        """Release the class-wide lock guarding updates of the records."""
        AMSMonitor._lock.release()

    def __enter__(self):
        """
        Start a measurement. Returns None (and logs an error) without
        starting anything when `obj` or `tag` is missing.
        """
        if self.object is None or not self.tag:
            self.logger.error('missing parameter "object" or "tag" when using context manager syntax')
            return None
        self.start_monitor()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Only record if __enter__ actually started a measurement, otherwise
        # there is no object/tag to attach the record to.
        if self.internal_ts:
            self.stop_monitor()

    @classmethod
    def info(cls) -> str:
        """
        Return a human-readable, indented dump of all records
        ("{}" when nothing has been recorded yet).
        """
        # A manager DictProxy never compares equal to {}; rely on its length.
        if not cls._stats:
            return "{}"
        lines = []
        for class_name, tags in cls._stats.items():
            lines.append(f"{class_name}\n")
            for tag, timestamps in tags.items():
                lines.append(f" {tag}\n")
                for ts, values in timestamps.items():
                    lines.append(f" {ts:<10}\n")
                    for key, value in values.items():
                        lines.append(f" {key:<30} => {value}\n")
        return "".join(lines).rstrip()

    @_ClassProperty
    def stats(cls):
        """The shared (manager) dictionary holding all records."""
        return AMSMonitor._stats

    @_ClassProperty
    def format_ts(cls):
        """The strftime/strptime format used for record timestamps."""
        return AMSMonitor._ts_format

    @classmethod
    def convert_ts(cls, ts: str) -> datetime.datetime:
        """
        Parse a record timestamp string back into a datetime object.

        Raises:
            ValueError: if `ts` does not match `format_ts`.
        """
        return datetime.datetime.strptime(ts, cls.format_ts)

    @classmethod
    def json(cls, json_output: str):
        """
        Write the collected metrics to a JSON file.
        """
        with open(json_output, "w") as fp:
            # we have to use .copy() as DictProxy is not serializable
            json.dump(cls._stats.copy(), fp, indent=4)
            # To avoid partial line at the end of the file
            fp.write("\n")

    def start_monitor(self, *args, **kwargs):
        """Start a measurement: store the start time and its timestamp string."""
        self.start_time = time.time()
        self.internal_ts = datetime.datetime.now().strftime(self._ts_format)

    def stop_monitor(self):
        """
        Stop the measurement started by `start_monitor()` and record the
        attributes of `self.object` plus the elapsed time under `self.tag`.
        """
        end = time.time()
        if hasattr(self.object, "__dict__"):
            self._record(self.object, self.tag, self.internal_ts, end - self.start_time)
        else:
            # vars() needs a __dict__; nothing can be recorded for such objects
            self.logger.warning(f"cannot record attributes of {type(self.object).__name__}, nothing recorded")

        # We reinitialize some variables
        self.start_time = 0
        self.internal_ts = 0

    def __call__(self, func: Callable):
        """
        The main decorator.

        The wrapped callable is timed; if its first positional argument has a
        `__dict__` (typically `self`), its attributes and the duration are
        recorded under `tag` (or the function name when no tag was given).
        The return value of `func` is always passed through unchanged.
        """
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ts = datetime.datetime.now().strftime(self._ts_format)
            start = time.time()
            value = func(*args, **kwargs)
            end = time.time()
            # No positional argument, or one without attributes: nothing to record
            if not args or not hasattr(args[0], "__dict__"):
                return value
            func_name = self.tag if self.tag else func.__name__
            self._record(args[0], func_name, ts, end - start)
            return value

        return wrapper

    def _record(self, obj, func_name: str, ts: str, duration: float):
        """
        Snapshot the attributes of `obj`, add the duration and store the result.
        """
        # Filter out multiprocessing which cannot be stored without causing RuntimeError
        new_data = self._filter_out_object(vars(obj))
        # We remove stuff we do not want (attribute of the calling class captured by vars())
        new_data = self._filter(new_data, self.record)
        # We inject some data we want to record
        new_data["amsmonitor_duration"] = duration
        self._update_db(new_data, obj.__class__.__name__, func_name, ts)

    def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
        """
        This function updates the hashmap containing all the records.
        """
        self.lock()
        try:
            # Reading from the manager dict hands back a plain copy: we edit the
            # copy and write it back as a whole at the end.
            class_stats = AMSMonitor._stats.get(class_name, {})
            func_stats = class_stats.setdefault(func_name, {})

            # We accumulate for each class with a different name
            if self.accumulate and func_stats:
                ts = self._get_ts(class_name, func_name)
                func_stats[ts] = self._acc(func_stats[ts], new_data)
            else:
                func_stats[ts] = dict(new_data)
            # This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory)
            AMSMonitor._stats[class_name] = class_stats
        finally:
            # Always release, otherwise a failure above would block every later record
            self.unlock()

    def _remove_reserved_keys(self, d: Union[dict, List]) -> Union[dict, List]:
        """
        Return a copy of `d` (list or dict) without the keys reserved by
        AMSMonitor; a warning is logged for each reserved key found.
        The input is left untouched.
        """
        found = [key for key in self._reserved_keys if key in d]
        for key in found:
            self.logger.warning(f"attribute {key} is protected and will be ignored ({d})")
        if isinstance(d, list):
            return [item for item in d if item not in found]
        if isinstance(d, dict):
            return {k: v for k, v in d.items() if k not in found}
        return d

    def _acc(self, original: dict, new_data: dict) -> dict:
        """
        Merge `new_data` into `original` (in place) and return it: values
        managed by AMSMonitor (reserved keys) are summed up, any other
        field is replaced by its new value.
        """
        for k, v in new_data.items():
            # We accumulate variables internally managed by AMSMonitor (duration etc)
            if k in AMSMonitor._reserved_keys:
                original[k] = float(original.get(k, 0)) + float(v)
            else:
                original[k] = v
        return original

    def _filter_out_object(self, data: dict) -> dict:
        """
        Return a new hashmap holding only the values that can be serialized to JSON.
        """

        def is_serializable(x):
            try:
                json.dumps(x)
                return True
            # ValueError is what json raises on circular references
            except (TypeError, OverflowError, ValueError):
                return False

        return {k: v for k, v in data.items() if is_serializable(v)}

    def _filter(self, data: dict, keys: List[str]) -> dict:
        """
        Return a hashmap restricted to the keys listed in `keys`;
        `data` is returned as is when `keys` is None or empty.
        """
        if not keys:
            return data
        return {k: v for k, v in data.items() if k in keys}

    def _get_ts(self, class_name: str, tag: str) -> str:
        """
        Return initial timestamp for a given monitored function
        (the current timestamp if nothing has been recorded yet).
        """
        ts = datetime.datetime.now().strftime(self._ts_format)
        class_stats = AMSMonitor._stats.get(class_name)
        if class_stats is None or tag not in class_stats:
            return ts

        init_ts = list(class_stats[tag].keys())
        if len(init_ts) > 1:
            self.logger.warning(f"more than 1 timestamp detected for {class_name} / {tag}")
        return init_ts[0] if init_ts else ts

0 commit comments

Comments
 (0)