Skip to content

Commit 4cde00b

Browse files
authored
Update: Analyze api, add DataStructures for summary (#288)
* Add:ModuleAnalysisSummary schema for easily displaying comparing summaries from ModelAnalysis Add: Tests for yaml serialization of the same * Move: yaml de-serialization methods to YAMLSerializableBaseModel * Style * Add: Tests for serialization, pretty printing, and subtraction * Get path programmatically * Add from_analysis method * Move relevant schemas to models.py Move ModelAnalysisSummary to analysis.py Moved analysis tests to a separate dir Remove extraneous test Simplify logic Move Summary generation to a separate class Delete old code * Fix failing tests * Update CLI * Update CLI * Add by_types info to analyze api (#292) * Add by types * Remove unintended changes * Style * Add kwargs * Fix failing tests * Fix failing test + `by-layers` * Add compare functionality * Add deprecation warning for `ModelAnalysis.pretty_print_summary(...)` Move pandas import within a pretty_print_summary(...) function that will be deprecated in a future version * Add `by-layers` support to analyze api(s) (#301) * Remove Not Implemented Error * Add: `by-layers` analysis Add: total property to `ZeroNonZeroParams` * Propagate `by-layers` to comparison summary * Rename: LINEAR_OP_TYPES --> TARGETED_LINEAR_OP_TYPES * BugFix: int32 was ignored from dense ops during analysis by types * Add support to compare by types and layers (#302) * Feature: Add compare across types and layers * Renames: * sparsity --> sparsity_percent * quantized --> quantized_percent Update: Size calculation to include sparsity Add: __name__ to avoid using root logger Added: slight improvements to print messaging, and logs Style Updates * Rename: * _get_entries_to_compare --> _get_comparable_entries * Add: docstrings * Add: Multiline row printing (#303) * Feature: Add compare across types and layers * Renames: * sparsity --> sparsity_percent * quantized --> quantized_percent Update: Size calculation to include sparsity Add: __name__ to avoid using root logger Added: slight improvements to print messaging, and logs Style Updates * Rename: * _get_entries_to_compare --> _get_comparable_entries * Add: docstrings * Add: Support to print multiline rows * Connect Deepsparse.Analyze (#304) * Feature: Add compare across types and layers * Renames: * sparsity --> sparsity_percent * quantized --> quantized_percent Update: Size calculation to include sparsity Add: __name__ to avoid using root logger Added: slight improvements to print messaging, and logs Style Updates * Rename: * _get_entries_to_compare --> _get_comparable_entries * Add: docstrings * Add: Support to print multiline rows * Connect deepsparse.analyze to sparsezoo.analyze Add: PerformanceEntry to entry types, else it will be converted to a ModelEntry Fix: model_name while instantiating PerformanceEntry * Add: supported_graph_percentage to BenchmarkResult * Add node level timings * Style
1 parent 86ce9ca commit 4cde00b

File tree

8 files changed

+888
-237
lines changed

8 files changed

+888
-237
lines changed

src/sparsezoo/analyze/analysis.py

+420-223
Large diffs are not rendered by default.

src/sparsezoo/analyze/cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def analyze_options(command: click.Command):
8484
"boolean string, generates and records the results by operator type",
8585
)
8686
@click.option(
87-
"--by-layer",
87+
"--by-layers",
8888
default=None,
8989
type=str,
9090
help="A flag to enable analysis results by layer type. If set or "

src/sparsezoo/analyze/utils/models.py

+294-3
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
from typing import Dict, List, Optional, Union
14+
import logging
15+
import textwrap
16+
from typing import Dict, List, Optional, Tuple, Union
1617

1718
from pydantic import BaseModel, Field
1819

@@ -28,6 +29,8 @@
2829
"ParameterComponent",
2930
]
3031

32+
_LOGGER = logging.getLogger(__name__)
33+
3134

3235
class PropertyBaseModel(BaseModel):
3336
"""
@@ -124,12 +127,16 @@ class ZeroNonZeroParams(PropertyBaseModel):
124127

125128
@property
126129
def sparsity(self):
127-
total_values = self.non_zero + self.zero
130+
total_values = self.total
128131
if total_values > 0:
129132
return self.zero / total_values
130133
else:
131134
return 0
132135

136+
@property
137+
def total(self):
138+
return self.non_zero + self.zero
139+
133140

134141
class DenseSparseOps(PropertyBaseModel):
135142
"""
@@ -221,3 +228,287 @@ class ParameterComponent(BaseModel):
221228
description="A summary of the parameter"
222229
)
223230
dtype: str = Field(description="The data type of the parameter")
231+
232+
233+
class Entry(BaseModel):
234+
"""
235+
A BaseModel with subtraction and pretty_print support
236+
"""
237+
238+
_print_order: List[str] = []
239+
240+
def __sub__(self, other):
241+
"""
242+
Allows base functionality for all inheriting classes to be subtract-able,
243+
subtracts the fields of self with other while providing some additional
244+
support for string and unrolling list type fields
245+
"""
246+
my_fields = self.__fields__
247+
other_fields = other.__fields__
248+
249+
assert list(my_fields) == list(other_fields)
250+
new_fields = {}
251+
for field in my_fields:
252+
if field.startswith("_"):
253+
# ignore private fields
254+
continue
255+
my_value = getattr(self, field)
256+
other_value = getattr(other, field)
257+
258+
assert type(my_value) == type(other_value)
259+
if field == "section_name":
260+
new_fields[field] = my_value
261+
elif isinstance(my_value, str):
262+
new_fields[field] = (
263+
my_value
264+
if my_value == other_value
265+
else f"{my_value} - {other_value}"
266+
)
267+
elif isinstance(my_value, list):
268+
new_fields[field] = [
269+
item_a - item_b for item_a, item_b in zip(my_value, other_value)
270+
]
271+
else:
272+
new_fields[field] = my_value - other_value
273+
274+
return self.__class__(**new_fields)
275+
276+
def pretty_print(self, headers: bool = False, column_width=30):
277+
"""
278+
pretty print current Entry object with all it's fields
279+
"""
280+
field_names = self._print_order
281+
field_values = []
282+
for field_name in field_names:
283+
field_value = getattr(self, field_name)
284+
if isinstance(field_value, float):
285+
field_value = f"{field_value:.2f}"
286+
field_values.append(field_value)
287+
288+
if headers:
289+
print(
290+
multiline_pretty_print(
291+
row=[field_name.upper() for field_name in field_names],
292+
column_width=column_width,
293+
)
294+
)
295+
print(multiline_pretty_print(row=field_values, column_width=column_width))
296+
297+
298+
class BaseEntry(Entry):
299+
"""
300+
The BaseModel representing a row entry
301+
302+
:param sparsity: A float between 0-100 representing sparsity percentage
303+
:param quantized: A float between 0-100 representing quantized percentage
304+
"""
305+
306+
sparsity: float
307+
quantized: float
308+
309+
_print_order = ["sparsity", "quantized"]
310+
311+
312+
class NamedEntry(BaseEntry):
313+
"""
314+
BaseEntry with additional info like name, total and size
315+
"""
316+
317+
name: str
318+
total: float
319+
size: int
320+
321+
_print_order = ["name", "total", "size"] + BaseEntry._print_order
322+
323+
324+
class TypedEntry(BaseEntry):
325+
"""
326+
BaseEntry with additional info like type and size
327+
"""
328+
329+
type: str
330+
size: int
331+
332+
_print_order = ["type", "size"] + BaseEntry._print_order
333+
334+
335+
class ModelEntry(BaseEntry):
336+
"""
337+
BaseEntry which includes name of the model
338+
"""
339+
340+
model: str
341+
_print_order = ["model"] + BaseEntry._print_order
342+
343+
344+
class SizedModelEntry(ModelEntry):
345+
"""
346+
A ModelEntry with additional info like count and size
347+
"""
348+
349+
count: int
350+
size: int
351+
_print_order = ModelEntry._print_order + ["count", "size"]
352+
353+
354+
class PerformanceEntry(BaseEntry):
355+
"""
356+
A BaseEntry with additional performance info
357+
"""
358+
359+
model: str
360+
latency: float
361+
throughput: float
362+
supported_graph: float
363+
364+
_print_order = [
365+
"model",
366+
"latency",
367+
"throughput",
368+
"supported_graph",
369+
] + BaseEntry._print_order
370+
371+
372+
class NodeTimingEntry(Entry):
373+
"""
374+
A BaseEntry with additional performance info
375+
"""
376+
377+
node_name: str
378+
avg_runtime: float
379+
380+
_print_order = [
381+
"node_name",
382+
"avg_runtime",
383+
] + Entry._print_order
384+
385+
386+
class Section(Entry):
387+
"""
388+
Represents a list of Entries with an optional name
389+
"""
390+
391+
entries: List[
392+
Union[
393+
NodeTimingEntry,
394+
PerformanceEntry,
395+
NamedEntry,
396+
TypedEntry,
397+
SizedModelEntry,
398+
ModelEntry,
399+
BaseEntry,
400+
]
401+
]
402+
403+
section_name: str = ""
404+
405+
def pretty_print(self):
406+
"""
407+
pretty print current section, with its entries
408+
"""
409+
if self.section_name:
410+
if not self.entries:
411+
print(f"No entries found in: {self.section_name}")
412+
else:
413+
print(f"{self.section_name}:")
414+
415+
for idx, entry in enumerate(self.entries):
416+
if idx == 0:
417+
entry.pretty_print(headers=True)
418+
else:
419+
entry.pretty_print(headers=False)
420+
print()
421+
422+
def __sub__(self, other: "Section"):
423+
"""
424+
A method that allows us to subtract two Section objects,
425+
If the section includes `NamedEntry` or `TypedEntry` then we only compare
426+
the entries which have the same name or type (and others will be ignored),
427+
Subtraction of other Entry types is delegated to their own implementation
428+
This function also assumes that a Section has entries of the same type
429+
"""
430+
431+
if not isinstance(other, Section):
432+
raise TypeError(
433+
f"unsupported operand type(s) for -: {type(self)} and {type(other)}"
434+
)
435+
436+
section_name = self.section_name or ""
437+
self_entries, other_entries = self.get_comparable_entries(other)
438+
439+
compared_entries = [
440+
self_entry - other_entry
441+
for self_entry, other_entry in zip(self_entries, other_entries)
442+
]
443+
444+
return Section(
445+
section_name=section_name,
446+
entries=compared_entries,
447+
)
448+
449+
def get_comparable_entries(self, other: "Section") -> Tuple[List[Entry], ...]:
450+
"""
451+
Get comparable entries by same name or type if they belong to
452+
`NamedEntry`, `TypedEntry`, or `NodeTimingEntry`, else return all entries
453+
454+
:return: A tuple composed of two lists, containing comparable entries
455+
in correct order from current and other Section objects
456+
"""
457+
assert self.entries
458+
entry_type_to_extractor = {
459+
"NamedEntry": lambda entry: entry.name,
460+
"TypedEntry": lambda entry: entry.type,
461+
"NodeTimingEntry": lambda entry: entry.node_name,
462+
}
463+
entry_type = self.entries[0].__class__.__name__
464+
465+
if entry_type not in entry_type_to_extractor:
466+
return self.entries, other.entries
467+
468+
key_extractor = entry_type_to_extractor[entry_type]
469+
self_entry_dict = {key_extractor(entry): entry for entry in self.entries}
470+
other_entry_dict = {key_extractor(entry): entry for entry in other.entries}
471+
472+
self_comparable_entries = []
473+
other_comparable_entries = []
474+
475+
for key, value in self_entry_dict.items():
476+
if key in other_entry_dict:
477+
self_comparable_entries.append(value)
478+
other_comparable_entries.append(other_entry_dict[key])
479+
480+
if len(self_comparable_entries) != len(self_entry_dict):
481+
_LOGGER.info(
482+
"Found mismatching entries, these will be ignored during "
483+
f"comparison in Section: {self.section_name}"
484+
)
485+
return self_comparable_entries, other_comparable_entries
486+
487+
488+
def multiline_pretty_print(row: List[str], column_width=20) -> str:
489+
"""
490+
Formats the contents of the specified row into a multiline string which
491+
each column is wrapped into a multiline string if its length is greater
492+
than the specified column_width
493+
494+
:param row: A list of strings to be formatted into a multiline row
495+
:param column_width: The max width of each column for formatting, default is 20
496+
:returns: A multiline formatted string representing the row,
497+
"""
498+
row = [str(column) for column in row]
499+
result_string = ""
500+
col_delim = " "
501+
wrapped_row = [textwrap.wrap(col, column_width) for col in row]
502+
max_lines_needed = max(len(col) for col in wrapped_row)
503+
504+
for line_idx in range(max_lines_needed):
505+
result_string += col_delim
506+
for column in wrapped_row:
507+
if line_idx < len(column):
508+
result_string += column[line_idx].ljust(column_width)
509+
else:
510+
result_string += " " * column_width
511+
result_string += col_delim
512+
if line_idx < max_lines_needed - 1:
513+
result_string += "\n"
514+
return result_string

0 commit comments

Comments
 (0)