|
| 1 | +"""Utilities for PyTorch""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | +import os |
| 5 | +from typing import Collection, TYPE_CHECKING |
| 6 | + |
| 7 | +from model_explorer import node_data_builder as ndb |
| 8 | +from onnxscript import ir |
| 9 | +import logging |
| 10 | + |
| 11 | +if TYPE_CHECKING: |
| 12 | + # TODO: Change the import when it is exposed to public |
| 13 | + from torch.onnx._internal.exporter._verification import VerificationInfo |
| 14 | + |
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
| 19 | +def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]: |
| 20 | + """Return a dictionary mapping names to values in the graph. |
| 21 | +
|
| 22 | + The mapping does not include values from subgraphs. |
| 23 | +
|
| 24 | + Args: |
| 25 | + graph: The graph to extract the mapping from. |
| 26 | +
|
| 27 | + Returns: |
| 28 | + A dictionary mapping names to values. |
| 29 | + """ |
| 30 | + values = {} |
| 31 | + values.update(graph.initializers) |
| 32 | + # The names of the values can be None or "", which we need to exclude |
| 33 | + for input in graph.inputs: |
| 34 | + if not input.name: |
| 35 | + continue |
| 36 | + values[input.name] = input |
| 37 | + for node in graph: |
| 38 | + for value in node.outputs: |
| 39 | + if not value.name: |
| 40 | + continue |
| 41 | + values[value.name] = value |
| 42 | + return values |
| 43 | + |
| 44 | + |
| 45 | +def save_node_data_from_verification_info( |
| 46 | + verification_infos: Collection[VerificationInfo], |
| 47 | + onnx_model: ir.Model, |
| 48 | + directory: str = "", |
| 49 | + model_name: str = "model", |
| 50 | +): |
| 51 | + """Saves the node data for model explorer. |
| 52 | +
|
| 53 | + Example:: |
| 54 | +
|
| 55 | + onnx_program = torch.onnx.export( |
| 56 | + model, |
| 57 | + args, |
| 58 | + dynamo=True |
| 59 | + ) |
| 60 | +
|
| 61 | + onnx_program.save("model.onnx") |
| 62 | +
|
| 63 | + from torch.onnx.verification import VerificationInterpreter |
| 64 | +
|
| 65 | + interpreter = VerificationInterpreter(onnx_program) |
| 66 | + interpreter.run(*args) |
| 67 | +
|
| 68 | + from model_explorer_onnx.torch_utils import save_node_data_from_verification_info |
| 69 | +
|
| 70 | + save_node_data_from_verification_info( |
| 71 | + interpreter.verification_infos, onnx_program.model, model_name="model" |
| 72 | + ) |
| 73 | +
|
| 74 | + You can then use Model Explorer to visualize the results by loading the generated node data files. |
| 75 | +
|
| 76 | + Args: |
| 77 | + verification_infos: The verification information objects. |
| 78 | + node_names: The names of the nodes each VerificationInfo corresponds to. |
| 79 | + model_name: The name of the model, used for constructing the file names. |
| 80 | + """ |
| 81 | + values = _create_value_mapping(onnx_model.graph) |
| 82 | + node_names = [] |
| 83 | + for info in verification_infos: |
| 84 | + print(info.name, info.max_abs_diff) |
| 85 | + if info.name in values: |
| 86 | + node_names.append(values[info.name].producer().name) |
| 87 | + else: |
| 88 | + node_names.append(info.name) |
| 89 | + logger.warning( |
| 90 | + "The name %s is not found in the graph. Please ensure the model provided matches the " |
| 91 | + "verification information.", |
| 92 | + info.name, |
| 93 | + ) |
| 94 | + for field in ("max_abs_diff", "max_rel_diff"): |
| 95 | + # Populate values for the main graph in a model. |
| 96 | + main_graph_results: dict[str, ndb.NodeDataResult] = {} |
| 97 | + for info, node_name in zip(verification_infos, node_names): |
| 98 | + if ( |
| 99 | + values.get(info.name) is not None |
| 100 | + and values[info.name].is_graph_output() |
| 101 | + ): |
| 102 | + main_graph_results[f"[value] {info.name}"] = ndb.NodeDataResult( |
| 103 | + value=getattr(info, field) |
| 104 | + ) |
| 105 | + else: |
| 106 | + main_graph_results[node_name] = ndb.NodeDataResult( |
| 107 | + value=getattr(info, field) |
| 108 | + ) |
| 109 | + |
| 110 | + thresholds: list[ndb.ThresholdItem] = [ |
| 111 | + ndb.ThresholdItem(value=0.00001, bgColor="#388e3c"), |
| 112 | + ndb.ThresholdItem(value=0.0001, bgColor="#8bc34a"), |
| 113 | + ndb.ThresholdItem(value=0.001, bgColor="#c8e6c9"), |
| 114 | + ndb.ThresholdItem(value=0.01, bgColor="#ffa000"), |
| 115 | + ndb.ThresholdItem(value=1, bgColor="#ff5722"), |
| 116 | + ndb.ThresholdItem(value=100, bgColor="#d32f2f"), |
| 117 | + ] |
| 118 | + |
| 119 | + # Construct the data for the main graph. |
| 120 | + main_graph_data = ndb.GraphNodeData( |
| 121 | + results=main_graph_results, thresholds=thresholds |
| 122 | + ) |
| 123 | + |
| 124 | + # Construct the data for the model. |
| 125 | + # "main_graph" is the default graph name produced by the exporter. |
| 126 | + model_data = ndb.ModelNodeData(graphsData={"main_graph": main_graph_data}) |
| 127 | + |
| 128 | + model_data.save_to_file(os.path.join(directory, f"{model_name}_{field}.json")) |
0 commit comments