Skip to content

Commit 88a4384

Browse files
authored
Create utilities for PyTorch (#43)
Create save_node_data_from_verification_info for visualizing ndoe data
1 parent d959e08 commit 88a4384

File tree

4 files changed

+161
-0
lines changed

4 files changed

+161
-0
lines changed

Diff for: README.md

+33
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,39 @@ Graph input/output/initializers in ONNX are values (edges), not nodes. A node is
3535

3636
Get node color themes [here](./themes)
3737

38+
## Visualizing PyTorch ONNX exporter (`dynamo=True`) accuracy results
39+
40+
> [!NOTE]
41+
> `VerificationInterpreter` requires PyTorch 2.7 or newer
42+
43+
```py
44+
import torch
45+
from torch.onnx.verification import VerificationInterpreter
46+
47+
from model_explorer_onnx.torch_utils import save_node_data_from_verification_info
48+
49+
# Export the and save model
50+
onnx_program = torch.onnx.export(model, args, dynamo=True)
51+
onnx_program.save("model.onnx")
52+
53+
# Use the VerificationInterpreter to obtain accuracy results
54+
interpreter = VerificationInterpreter(onnx_program)
55+
interpreter.run(*args)
56+
57+
# Produce node data for Model Explorer for visualization
58+
save_node_data_from_verification_info(
59+
interpreter.verification_infos, onnx_program.model, model_name="model"
60+
)
61+
```
62+
63+
You can then use Model Explorer to visualize the results by loading the generated node data files:
64+
65+
```sh
66+
onnxvis model.onnx --node_data_paths=model_max_abs_diff.json,model_max_rel_diff.json
67+
```
68+
69+
![node_data](./screenshots/node_data.png)
70+
3871
## Screenshots
3972

4073
<img width="1294" alt="image" src="https://github.com/justinchuby/model-explorer-onnx/assets/11205048/ed7e1eee-a693-48bd-811d-b384f784ef9b">

Diff for: screenshots/node_data.png

279 KB
Loading

Diff for: src/model_explorer_onnx/__init__.py

Whitespace-only changes.

Diff for: src/model_explorer_onnx/torch_utils.py

+128
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""Utilities for PyTorch"""
2+
3+
from __future__ import annotations
4+
import os
5+
from typing import Collection, TYPE_CHECKING
6+
7+
from model_explorer import node_data_builder as ndb
8+
from onnxscript import ir
9+
import logging
10+
11+
if TYPE_CHECKING:
12+
# TODO: Change the import when it is exposed to public
13+
from torch.onnx._internal.exporter._verification import VerificationInfo
14+
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
def _create_value_mapping(graph: ir.Graph) -> dict[str, ir.Value]:
20+
"""Return a dictionary mapping names to values in the graph.
21+
22+
The mapping does not include values from subgraphs.
23+
24+
Args:
25+
graph: The graph to extract the mapping from.
26+
27+
Returns:
28+
A dictionary mapping names to values.
29+
"""
30+
values = {}
31+
values.update(graph.initializers)
32+
# The names of the values can be None or "", which we need to exclude
33+
for input in graph.inputs:
34+
if not input.name:
35+
continue
36+
values[input.name] = input
37+
for node in graph:
38+
for value in node.outputs:
39+
if not value.name:
40+
continue
41+
values[value.name] = value
42+
return values
43+
44+
45+
def save_node_data_from_verification_info(
46+
verification_infos: Collection[VerificationInfo],
47+
onnx_model: ir.Model,
48+
directory: str = "",
49+
model_name: str = "model",
50+
):
51+
"""Saves the node data for model explorer.
52+
53+
Example::
54+
55+
onnx_program = torch.onnx.export(
56+
model,
57+
args,
58+
dynamo=True
59+
)
60+
61+
onnx_program.save("model.onnx")
62+
63+
from torch.onnx.verification import VerificationInterpreter
64+
65+
interpreter = VerificationInterpreter(onnx_program)
66+
interpreter.run(*args)
67+
68+
from model_explorer_onnx.torch_utils import save_node_data_from_verification_info
69+
70+
save_node_data_from_verification_info(
71+
interpreter.verification_infos, onnx_program.model, model_name="model"
72+
)
73+
74+
You can then use Model Explorer to visualize the results by loading the generated node data files.
75+
76+
Args:
77+
verification_infos: The verification information objects.
78+
node_names: The names of the nodes each VerificationInfo corresponds to.
79+
model_name: The name of the model, used for constructing the file names.
80+
"""
81+
values = _create_value_mapping(onnx_model.graph)
82+
node_names = []
83+
for info in verification_infos:
84+
print(info.name, info.max_abs_diff)
85+
if info.name in values:
86+
node_names.append(values[info.name].producer().name)
87+
else:
88+
node_names.append(info.name)
89+
logger.warning(
90+
"The name %s is not found in the graph. Please ensure the model provided matches the "
91+
"verification information.",
92+
info.name,
93+
)
94+
for field in ("max_abs_diff", "max_rel_diff"):
95+
# Populate values for the main graph in a model.
96+
main_graph_results: dict[str, ndb.NodeDataResult] = {}
97+
for info, node_name in zip(verification_infos, node_names):
98+
if (
99+
values.get(info.name) is not None
100+
and values[info.name].is_graph_output()
101+
):
102+
main_graph_results[f"[value] {info.name}"] = ndb.NodeDataResult(
103+
value=getattr(info, field)
104+
)
105+
else:
106+
main_graph_results[node_name] = ndb.NodeDataResult(
107+
value=getattr(info, field)
108+
)
109+
110+
thresholds: list[ndb.ThresholdItem] = [
111+
ndb.ThresholdItem(value=0.00001, bgColor="#388e3c"),
112+
ndb.ThresholdItem(value=0.0001, bgColor="#8bc34a"),
113+
ndb.ThresholdItem(value=0.001, bgColor="#c8e6c9"),
114+
ndb.ThresholdItem(value=0.01, bgColor="#ffa000"),
115+
ndb.ThresholdItem(value=1, bgColor="#ff5722"),
116+
ndb.ThresholdItem(value=100, bgColor="#d32f2f"),
117+
]
118+
119+
# Construct the data for the main graph.
120+
main_graph_data = ndb.GraphNodeData(
121+
results=main_graph_results, thresholds=thresholds
122+
)
123+
124+
# Construct the data for the model.
125+
# "main_graph" is the default graph name produced by the exporter.
126+
model_data = ndb.ModelNodeData(graphsData={"main_graph": main_graph_data})
127+
128+
model_data.save_to_file(os.path.join(directory, f"{model_name}_{field}.json"))

0 commit comments

Comments
 (0)