|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other |
| 3 | +# AMSLib Project Developers |
| 4 | +# |
| 5 | +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | + |
| 7 | +import datetime |
| 8 | +import json |
| 9 | +import logging |
| 10 | +import multiprocessing |
| 11 | +import threading |
| 12 | +import time |
| 13 | +from typing import Callable, List, Union |
| 14 | + |
| 15 | + |
class AMSMonitor:
    """
    Record the duration of tasks, and the state of the monitored object,
    in a process-shared hashmap indexed by class name, tag and timestamp.

    AMSMonitor can be used either as a decorator on methods or as a
    context manager around a block of code. After the monitored code
    has run, every JSON-serializable instance attribute of the
    monitored object (or only those listed in `record`) is stored
    together with the elapsed wall-clock time (`amsmonitor_duration`).

        class ExampleTask1(Task):
            def __init__(self):
                self.total_bytes = 0
                self.total_bytes2 = 0

            # Example: we do not want to record total_bytes
            # but just total_bytes2
            #
            # @AMSMonitor() would record all attributes
            # (total_bytes and total_bytes2)
            #
            # @AMSMonitor(accumulate=True) would keep a single
            # entry per method instead of one entry per invocation

            @AMSMonitor(record=["total_bytes2"])
            def __call__(self):
                i = 0
                with AMSMonitor(obj=self, tag="while_loop", record=["total_bytes2"]):
                    while (i<=3):
                        self.total_bytes += 10
                        self.total_bytes2 = 1
                        i += 1

    Each time `ExampleTask1()` is being called, AMSMonitor will
    populate `_stats` as follows:

        {'ExampleTask1':
            {'while_loop':
                {'02/29/2024-19:27:53':
                    {
                        'total_bytes2': 1,
                        'amsmonitor_duration': 4.004607439041138
                    }
                },
             '__call__':
                {'02/29/2024-19:29:24':
                    {
                        'total_bytes2': 1,
                        'amsmonitor_duration': 4.10461138
                    }
                }
            }
        }

    Attributes:
        record: list of attribute names to record. If it is not a list,
            or is empty, all serializable attributes are recorded.
            Names reserved by AMSMonitor (`amsmonitor_duration`) are
            dropped from the list with a warning.
        accumulate: If True, all invocations for a given class/tag are
            merged into the entry created by the first invocation
            (under its timestamp): `amsmonitor_duration` is summed and
            every other recorded attribute is replaced by its most
            recent value. If False, each invocation is stored under its
            own timestamp.
        obj: Mandatory if using `with` statement, the object whose
            attributes are recorded (i.e., self). For compatibility
            with the previously documented spelling, the keyword
            `object` is accepted as an alias when `obj` is not given.
        tag: Mandatory if using `with` statement, `tag` is the
            name that will appear in the record for that
            context manager statement. With the decorator syntax
            it overrides the decorated function name.
        logger: logger used for warnings and errors; defaults to the
            module logger.
    """

    class _ClassProperty:
        """
        Minimal read-only class-level property.

        Stacking `@classmethod` on `@property` was deprecated in
        Python 3.11 and no longer works from Python 3.13; this
        descriptor keeps `AMSMonitor.stats` / `AMSMonitor.format_ts`
        usable as plain attributes on every version.
        """

        def __init__(self, fget):
            self.fget = fget
            self.__doc__ = getattr(fget, "__doc__", None)

        def __get__(self, instance, owner=None):
            if owner is None:
                owner = type(instance)
            return self.fget(owner)

    _manager = multiprocessing.Manager()
    _stats = _manager.dict()
    _ts_format = "%m/%d/%Y-%H:%M:%S"
    _reserved_keys = ["amsmonitor_duration"]
    _lock = threading.Lock()
    _count = 0

    def __init__(
        self,
        record=None,
        accumulate=False,
        obj=None,
        tag=None,
        logger: Union[logging.Logger, None] = None,
        **kwargs,
    ):
        self.accumulate = accumulate
        self.kwargs = kwargs
        # The logger must exist before _remove_reserved_keys() runs,
        # because that helper emits a warning through it.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        # Anything that is not a list means "record everything".
        # We make sure we do not overwrite protected attributes managed by AMSMonitor
        self.record = self._remove_reserved_keys(record) if isinstance(record, list) else None
        # `object=...` ends up in **kwargs; honour it as an alias of `obj`.
        if obj is None:
            obj = kwargs.get("object")
        self.object = obj
        self.start_time = 0
        self.internal_ts = 0
        self.tag = tag
        AMSMonitor._count += 1

    def __str__(self) -> str:
        return AMSMonitor.info()

    def __repr__(self) -> str:
        return self.__str__()

    def lock(self):
        """Acquire the class-wide lock protecting updates of the records."""
        AMSMonitor._lock.acquire()

    def unlock(self):
        """Release the class-wide lock acquired with `lock()`."""
        AMSMonitor._lock.release()

    def __enter__(self):
        """
        Start monitoring a `with` block.

        Returns the monitor, or None (after logging an error) when
        the object or the tag is missing; in that case nothing is recorded.
        """
        if self.object is None or not self.tag:
            self.logger.error('missing parameter "object" or "tag" when using context manager syntax')
            return None
        self.start_monitor()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # internal_ts is only set by start_monitor(); if __enter__ refused
        # to start there is nothing to stop (and nothing valid to record).
        if self.internal_ts != 0:
            self.stop_monitor()

    @classmethod
    def info(cls) -> str:
        """
        Return a human-readable, multi-line dump of all the records,
        or "{}" when nothing has been recorded yet.
        """
        # DictProxy does not forward __eq__, so comparing it with {} never
        # matches; take a plain-dict snapshot and test its truthiness instead.
        snapshot = cls._stats.copy()
        if not snapshot:
            return "{}"
        # NOTE(review): the leading whitespace inside the format strings
        # below is reproduced as found; confirm the intended indentation
        # of each nesting level.
        chunks = []
        for class_name, tags in snapshot.items():
            chunks.append(f"{class_name}\n")
            for tag, entries in tags.items():
                chunks.append(f" {tag}\n")
                for ts, values in entries.items():
                    chunks.append(f" {ts:<10}\n")
                    for name, value in values.items():
                        chunks.append(f" {name:<30} => {value}\n")
        return "".join(chunks).rstrip()

    @_ClassProperty
    def stats(cls):
        """The shared hashmap (manager dict) holding all the records."""
        return AMSMonitor._stats

    @_ClassProperty
    def format_ts(cls):
        """The `strftime` format used for the timestamps of the records."""
        return AMSMonitor._ts_format

    @classmethod
    def convert_ts(cls, ts: str) -> datetime.datetime:
        """
        Parse a timestamp string produced by AMSMonitor.

        Raises ValueError if `ts` does not match the timestamp format.
        """
        # `datetime` is the module here, the parser lives on the class.
        return datetime.datetime.strptime(ts, cls._ts_format)

    @classmethod
    def json(cls, json_output: str):
        """
        Write the collected metrics to a JSON file.
        """
        with open(json_output, "w") as fp:
            # we have to use .copy() as DictProxy is not serializable
            json.dump(cls._stats.copy(), fp, indent=4)
            # To avoid partial line at the end of the file
            fp.write("\n")

    def start_monitor(self, *args, **kwargs):
        """Start the timer and remember the timestamp of this measurement."""
        self.start_time = time.time()
        self.internal_ts = datetime.datetime.now().strftime(self._ts_format)

    def stop_monitor(self):
        """
        Stop the timer and store the attributes of the monitored object
        together with the elapsed time under `self.tag`.
        """
        elapsed = time.time() - self.start_time
        if not self._record(self.object, self.tag, self.internal_ts, elapsed):
            self.logger.warning(f"nothing recorded for tag {self.tag}: monitored object has no attributes")

        # We reinitialize some variables
        self.start_time = 0
        self.internal_ts = 0

    def __call__(self, func: Callable):
        """
        The main decorator.

        The wrapped function is timed; after it returns, the attributes
        of its first positional argument (expected to be `self`) are
        recorded under `tag`, or under the function name if no tag was
        given. Nothing is recorded when there is no positional argument,
        when it has no `__dict__`, or when `func` raises.
        """
        import functools

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ts = datetime.datetime.now().strftime(self._ts_format)
            start = time.time()
            value = func(*args, **kwargs)
            elapsed = time.time() - start
            if args:
                func_name = self.tag if self.tag else func.__name__
                self._record(args[0], func_name, ts, elapsed)
            return value

        return wrapper

    def _record(self, target, func_name: str, ts: str, duration: float) -> bool:
        """
        Snapshot the attributes of `target` and store them with `duration`.

        Returns False (recording nothing) if `target` has no `__dict__`.
        """
        if not hasattr(target, "__dict__"):
            return False
        # Filter out multiprocessing which cannot be stored without causing RuntimeError
        new_data = self._filter_out_object(vars(target))
        # We remove stuff we do not want (attribute of the calling class captured by vars())
        new_data = self._filter(new_data, self.record)
        # We inject some data we want to record
        new_data["amsmonitor_duration"] = duration
        self._update_db(new_data, target.__class__.__name__, func_name, ts)
        return True

    def _update_db(self, new_data: dict, class_name: str, func_name: str, ts: str):
        """
        This function update the hashmap containing all the records.
        """
        # `with` guarantees the lock is released even if an update fails.
        with AMSMonitor._lock:
            # Reading a value from the manager dict returns a copy, so we
            # modify the copy and write it back as a whole at the end.
            class_stats = AMSMonitor._stats.get(class_name, {})
            entries = class_stats.setdefault(func_name, {})
            AMSMonitor._stats[class_name] = class_stats

            # We accumulate for each class with a different name
            if self.accumulate and entries:
                ts = self._get_ts(class_name, func_name)
                entries[ts] = self._acc(entries[ts], new_data)
            else:
                entries[ts] = dict(new_data)
            # This trick is needed because AMSMonitor._stats is a manager.dict (not shared memory)
            AMSMonitor._stats[class_name] = class_stats

    def _remove_reserved_keys(self, d: Union[dict, List]) -> Union[dict, List]:
        """
        Return a copy of `d` (list or dict) without the keys reserved by
        AMSMonitor; a warning is logged for each reserved key found.
        The argument itself is left untouched; other types are returned as is.
        """
        reserved = [key for key in self._reserved_keys if key in d]
        for key in reserved:
            self.logger.warning(f"attribute {key} is protected and will be ignored ({d})")
        if isinstance(d, list):
            return [item for item in d if item not in reserved]
        if isinstance(d, dict):
            return {k: v for k, v in d.items() if k not in reserved}
        return d

    def _acc(self, original: dict, new_data: dict) -> dict:
        """
        Merge `new_data` into `original` (in place) and return it:
        values managed by AMSMonitor (reserved keys) are summed,
        any other value replaces the previous one.
        """
        for k, v in new_data.items():
            # We accumulate variables internally managed by AMSMonitor (duration etc)
            if k in AMSMonitor._reserved_keys:
                original[k] = float(original.get(k, 0)) + float(v)
            else:
                original[k] = v
        return original

    def _filter_out_object(self, data: dict) -> dict:
        """
        Filter out a hashmap to remove objects which can cause errors
        (i.e., keep only the values that can be serialized to JSON).
        """

        def is_serializable(x):
            try:
                json.dumps(x)
                return True
            # ValueError is what json raises on circular references
            except (TypeError, OverflowError, ValueError):
                return False

        return {k: v for k, v in data.items() if is_serializable(v)}

    def _filter(self, data: dict, keys: List[str]) -> dict:
        """
        Filter out a hashmap to contains only keys listed by list of keys.
        If `keys` is None or empty, `data` is returned unchanged.
        """
        if not keys:
            return data
        return {k: v for k, v in data.items() if k in keys}

    def _get_ts(self, class_name: str, tag: str) -> str:
        """
        Return initial timestamp for a given monitored function,
        or the current timestamp if nothing was recorded for it yet.
        """
        now = datetime.datetime.now().strftime(self._ts_format)
        class_stats = AMSMonitor._stats.get(class_name)
        if class_stats is None or tag not in class_stats:
            return now

        known_ts = list(class_stats[tag])
        if len(known_ts) > 1:
            self.logger.warning(f"more than 1 timestamp detected for {class_name} / {tag}")
        return known_ts[0] if known_ts else now
0 commit comments