|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| - |
15 |
| -from typing import Dict, List, Optional, Union |
| 14 | +import logging |
| 15 | +import textwrap |
| 16 | +from typing import Dict, List, Optional, Tuple, Union |
16 | 17 |
|
17 | 18 | from pydantic import BaseModel, Field
|
18 | 19 |
|
|
28 | 29 | "ParameterComponent",
|
29 | 30 | ]
|
30 | 31 |
|
| 32 | +_LOGGER = logging.getLogger(__name__) |
| 33 | + |
31 | 34 |
|
32 | 35 | class PropertyBaseModel(BaseModel):
|
33 | 36 | """
|
@@ -124,12 +127,16 @@ class ZeroNonZeroParams(PropertyBaseModel):
|
124 | 127 |
|
125 | 128 | @property
|
126 | 129 | def sparsity(self):
|
127 |
| - total_values = self.non_zero + self.zero |
| 130 | + total_values = self.total |
128 | 131 | if total_values > 0:
|
129 | 132 | return self.zero / total_values
|
130 | 133 | else:
|
131 | 134 | return 0
|
132 | 135 |
|
| 136 | + @property |
| 137 | + def total(self): |
| 138 | + return self.non_zero + self.zero |
| 139 | + |
133 | 140 |
|
134 | 141 | class DenseSparseOps(PropertyBaseModel):
|
135 | 142 | """
|
@@ -221,3 +228,287 @@ class ParameterComponent(BaseModel):
|
221 | 228 | description="A summary of the parameter"
|
222 | 229 | )
|
223 | 230 | dtype: str = Field(description="The data type of the parameter")
|
| 231 | + |
| 232 | + |
| 233 | +class Entry(BaseModel): |
| 234 | + """ |
| 235 | + A BaseModel with subtraction and pretty_print support |
| 236 | + """ |
| 237 | + |
| 238 | + _print_order: List[str] = [] |
| 239 | + |
| 240 | + def __sub__(self, other): |
| 241 | + """ |
| 242 | + Allows base functionality for all inheriting classes to be subtract-able, |
| 243 | + subtracts the fields of self with other while providing some additional |
| 244 | + support for string and unrolling list type fields |
| 245 | + """ |
| 246 | + my_fields = self.__fields__ |
| 247 | + other_fields = other.__fields__ |
| 248 | + |
| 249 | + assert list(my_fields) == list(other_fields) |
| 250 | + new_fields = {} |
| 251 | + for field in my_fields: |
| 252 | + if field.startswith("_"): |
| 253 | + # ignore private fields |
| 254 | + continue |
| 255 | + my_value = getattr(self, field) |
| 256 | + other_value = getattr(other, field) |
| 257 | + |
| 258 | + assert type(my_value) == type(other_value) |
| 259 | + if field == "section_name": |
| 260 | + new_fields[field] = my_value |
| 261 | + elif isinstance(my_value, str): |
| 262 | + new_fields[field] = ( |
| 263 | + my_value |
| 264 | + if my_value == other_value |
| 265 | + else f"{my_value} - {other_value}" |
| 266 | + ) |
| 267 | + elif isinstance(my_value, list): |
| 268 | + new_fields[field] = [ |
| 269 | + item_a - item_b for item_a, item_b in zip(my_value, other_value) |
| 270 | + ] |
| 271 | + else: |
| 272 | + new_fields[field] = my_value - other_value |
| 273 | + |
| 274 | + return self.__class__(**new_fields) |
| 275 | + |
| 276 | + def pretty_print(self, headers: bool = False, column_width=30): |
| 277 | + """ |
| 278 | + pretty print current Entry object with all it's fields |
| 279 | + """ |
| 280 | + field_names = self._print_order |
| 281 | + field_values = [] |
| 282 | + for field_name in field_names: |
| 283 | + field_value = getattr(self, field_name) |
| 284 | + if isinstance(field_value, float): |
| 285 | + field_value = f"{field_value:.2f}" |
| 286 | + field_values.append(field_value) |
| 287 | + |
| 288 | + if headers: |
| 289 | + print( |
| 290 | + multiline_pretty_print( |
| 291 | + row=[field_name.upper() for field_name in field_names], |
| 292 | + column_width=column_width, |
| 293 | + ) |
| 294 | + ) |
| 295 | + print(multiline_pretty_print(row=field_values, column_width=column_width)) |
| 296 | + |
| 297 | + |
| 298 | +class BaseEntry(Entry): |
| 299 | + """ |
| 300 | + The BaseModel representing a row entry |
| 301 | +
|
| 302 | + :param sparsity: A float between 0-100 representing sparsity percentage |
| 303 | + :param quantized: A float between 0-100 representing quantized percentage |
| 304 | + """ |
| 305 | + |
| 306 | + sparsity: float |
| 307 | + quantized: float |
| 308 | + |
| 309 | + _print_order = ["sparsity", "quantized"] |
| 310 | + |
| 311 | + |
| 312 | +class NamedEntry(BaseEntry): |
| 313 | + """ |
| 314 | + BaseEntry with additional info like name, total and size |
| 315 | + """ |
| 316 | + |
| 317 | + name: str |
| 318 | + total: float |
| 319 | + size: int |
| 320 | + |
| 321 | + _print_order = ["name", "total", "size"] + BaseEntry._print_order |
| 322 | + |
| 323 | + |
| 324 | +class TypedEntry(BaseEntry): |
| 325 | + """ |
| 326 | + BaseEntry with additional info like type and size |
| 327 | + """ |
| 328 | + |
| 329 | + type: str |
| 330 | + size: int |
| 331 | + |
| 332 | + _print_order = ["type", "size"] + BaseEntry._print_order |
| 333 | + |
| 334 | + |
| 335 | +class ModelEntry(BaseEntry): |
| 336 | + """ |
| 337 | + BaseEntry which includes name of the model |
| 338 | + """ |
| 339 | + |
| 340 | + model: str |
| 341 | + _print_order = ["model"] + BaseEntry._print_order |
| 342 | + |
| 343 | + |
| 344 | +class SizedModelEntry(ModelEntry): |
| 345 | + """ |
| 346 | + A ModelEntry with additional info like count and size |
| 347 | + """ |
| 348 | + |
| 349 | + count: int |
| 350 | + size: int |
| 351 | + _print_order = ModelEntry._print_order + ["count", "size"] |
| 352 | + |
| 353 | + |
| 354 | +class PerformanceEntry(BaseEntry): |
| 355 | + """ |
| 356 | + A BaseEntry with additional performance info |
| 357 | + """ |
| 358 | + |
| 359 | + model: str |
| 360 | + latency: float |
| 361 | + throughput: float |
| 362 | + supported_graph: float |
| 363 | + |
| 364 | + _print_order = [ |
| 365 | + "model", |
| 366 | + "latency", |
| 367 | + "throughput", |
| 368 | + "supported_graph", |
| 369 | + ] + BaseEntry._print_order |
| 370 | + |
| 371 | + |
| 372 | +class NodeTimingEntry(Entry): |
| 373 | + """ |
| 374 | + A BaseEntry with additional performance info |
| 375 | + """ |
| 376 | + |
| 377 | + node_name: str |
| 378 | + avg_runtime: float |
| 379 | + |
| 380 | + _print_order = [ |
| 381 | + "node_name", |
| 382 | + "avg_runtime", |
| 383 | + ] + Entry._print_order |
| 384 | + |
| 385 | + |
| 386 | +class Section(Entry): |
| 387 | + """ |
| 388 | + Represents a list of Entries with an optional name |
| 389 | + """ |
| 390 | + |
| 391 | + entries: List[ |
| 392 | + Union[ |
| 393 | + NodeTimingEntry, |
| 394 | + PerformanceEntry, |
| 395 | + NamedEntry, |
| 396 | + TypedEntry, |
| 397 | + SizedModelEntry, |
| 398 | + ModelEntry, |
| 399 | + BaseEntry, |
| 400 | + ] |
| 401 | + ] |
| 402 | + |
| 403 | + section_name: str = "" |
| 404 | + |
| 405 | + def pretty_print(self): |
| 406 | + """ |
| 407 | + pretty print current section, with its entries |
| 408 | + """ |
| 409 | + if self.section_name: |
| 410 | + if not self.entries: |
| 411 | + print(f"No entries found in: {self.section_name}") |
| 412 | + else: |
| 413 | + print(f"{self.section_name}:") |
| 414 | + |
| 415 | + for idx, entry in enumerate(self.entries): |
| 416 | + if idx == 0: |
| 417 | + entry.pretty_print(headers=True) |
| 418 | + else: |
| 419 | + entry.pretty_print(headers=False) |
| 420 | + print() |
| 421 | + |
| 422 | + def __sub__(self, other: "Section"): |
| 423 | + """ |
| 424 | + A method that allows us to subtract two Section objects, |
| 425 | + If the section includes `NamedEntry` or `TypedEntry` then we only compare |
| 426 | + the entries which have the same name or type (and others will be ignored), |
| 427 | + Subtraction of other Entry types is delegated to their own implementation |
| 428 | + This function also assumes that a Section has entries of the same type |
| 429 | + """ |
| 430 | + |
| 431 | + if not isinstance(other, Section): |
| 432 | + raise TypeError( |
| 433 | + f"unsupported operand type(s) for -: {type(self)} and {type(other)}" |
| 434 | + ) |
| 435 | + |
| 436 | + section_name = self.section_name or "" |
| 437 | + self_entries, other_entries = self.get_comparable_entries(other) |
| 438 | + |
| 439 | + compared_entries = [ |
| 440 | + self_entry - other_entry |
| 441 | + for self_entry, other_entry in zip(self_entries, other_entries) |
| 442 | + ] |
| 443 | + |
| 444 | + return Section( |
| 445 | + section_name=section_name, |
| 446 | + entries=compared_entries, |
| 447 | + ) |
| 448 | + |
| 449 | + def get_comparable_entries(self, other: "Section") -> Tuple[List[Entry], ...]: |
| 450 | + """ |
| 451 | + Get comparable entries by same name or type if they belong to |
| 452 | + `NamedEntry`, `TypedEntry`, or `NodeTimingEntry`, else return all entries |
| 453 | +
|
| 454 | + :return: A tuple composed of two lists, containing comparable entries |
| 455 | + in correct order from current and other Section objects |
| 456 | + """ |
| 457 | + assert self.entries |
| 458 | + entry_type_to_extractor = { |
| 459 | + "NamedEntry": lambda entry: entry.name, |
| 460 | + "TypedEntry": lambda entry: entry.type, |
| 461 | + "NodeTimingEntry": lambda entry: entry.node_name, |
| 462 | + } |
| 463 | + entry_type = self.entries[0].__class__.__name__ |
| 464 | + |
| 465 | + if entry_type not in entry_type_to_extractor: |
| 466 | + return self.entries, other.entries |
| 467 | + |
| 468 | + key_extractor = entry_type_to_extractor[entry_type] |
| 469 | + self_entry_dict = {key_extractor(entry): entry for entry in self.entries} |
| 470 | + other_entry_dict = {key_extractor(entry): entry for entry in other.entries} |
| 471 | + |
| 472 | + self_comparable_entries = [] |
| 473 | + other_comparable_entries = [] |
| 474 | + |
| 475 | + for key, value in self_entry_dict.items(): |
| 476 | + if key in other_entry_dict: |
| 477 | + self_comparable_entries.append(value) |
| 478 | + other_comparable_entries.append(other_entry_dict[key]) |
| 479 | + |
| 480 | + if len(self_comparable_entries) != len(self_entry_dict): |
| 481 | + _LOGGER.info( |
| 482 | + "Found mismatching entries, these will be ignored during " |
| 483 | + f"comparison in Section: {self.section_name}" |
| 484 | + ) |
| 485 | + return self_comparable_entries, other_comparable_entries |
| 486 | + |
| 487 | + |
| 488 | +def multiline_pretty_print(row: List[str], column_width=20) -> str: |
| 489 | + """ |
| 490 | + Formats the contents of the specified row into a multiline string which |
| 491 | + each column is wrapped into a multiline string if its length is greater |
| 492 | + than the specified column_width |
| 493 | +
|
| 494 | + :param row: A list of strings to be formatted into a multiline row |
| 495 | + :param column_width: The max width of each column for formatting, default is 20 |
| 496 | + :returns: A multiline formatted string representing the row, |
| 497 | + """ |
| 498 | + row = [str(column) for column in row] |
| 499 | + result_string = "" |
| 500 | + col_delim = " " |
| 501 | + wrapped_row = [textwrap.wrap(col, column_width) for col in row] |
| 502 | + max_lines_needed = max(len(col) for col in wrapped_row) |
| 503 | + |
| 504 | + for line_idx in range(max_lines_needed): |
| 505 | + result_string += col_delim |
| 506 | + for column in wrapped_row: |
| 507 | + if line_idx < len(column): |
| 508 | + result_string += column[line_idx].ljust(column_width) |
| 509 | + else: |
| 510 | + result_string += " " * column_width |
| 511 | + result_string += col_delim |
| 512 | + if line_idx < max_lines_needed - 1: |
| 513 | + result_string += "\n" |
| 514 | + return result_string |
0 commit comments