Skip to content

Commit b2458c2

Browse files
authored
Skip loading external tensors; move constants closer to their users (#21)
<img width="1269" alt="image" src="https://github.com/justinchuby/model-explorer-onnx/assets/11205048/7a761e95-aa8d-4caa-87c7-461516ddaa8f">
1 parent 164d553 commit b2458c2

File tree

1 file changed

+36
-31
lines changed

1 file changed

+36
-31
lines changed

Diff for: src/model_explorer_onnx/main.py

+36 -31
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import logging
4-
import os
54
from typing import Any, Literal, Sequence
65

76
import ml_dtypes
@@ -20,23 +19,23 @@
2019
def display_tensor(tensor: ir.TensorProtocol | None) -> str:
2120
if tensor is None:
2221
return "Data not available"
23-
if tensor.size < _TENSOR_DISPLAY_LIMIT:
24-
try:
25-
array = tensor.numpy()
26-
if tensor.dtype == ir.DataType.BFLOAT16:
27-
array = array.view(ml_dtypes.bfloat16)
28-
elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
29-
array = array.view(ml_dtypes.float8_e4m3fn)
30-
elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
31-
array = array.view(ml_dtypes.float8_e4m3fnuz)
32-
elif tensor.dtype == ir.DataType.FLOAT8E5M2:
33-
array = array.view(ml_dtypes.float8_e5m2)
34-
elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
35-
array = array.view(ml_dtypes.float8_e5m2fnuz)
36-
return np.array2string(array, separator=",")
37-
except Exception as e:
38-
logger.warning("Failed to display tensor: %s", e)
39-
return str(tensor)
22+
if tensor.size > _TENSOR_DISPLAY_LIMIT or isinstance(tensor, ir.ExternalTensor):
23+
return str(tensor)
24+
try:
25+
array = tensor.numpy()
26+
if tensor.dtype == ir.DataType.BFLOAT16:
27+
array = array.view(ml_dtypes.bfloat16)
28+
elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
29+
array = array.view(ml_dtypes.float8_e4m3fn)
30+
elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
31+
array = array.view(ml_dtypes.float8_e4m3fnuz)
32+
elif tensor.dtype == ir.DataType.FLOAT8E5M2:
33+
array = array.view(ml_dtypes.float8_e5m2)
34+
elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
35+
array = array.view(ml_dtypes.float8_e5m2fnuz)
36+
return np.array2string(array, separator=",")
37+
except Exception as e:
38+
logger.warning("Failed to display tensor: %s", e)
4039
return str(tensor)
4140

4241

@@ -262,9 +261,13 @@ def create_node(
262261
"""
263262
assert onnx_node.name, "Bug: Node name is required"
264263

265-
embedded_namespace = parse_namespace(onnx_node.name)
266-
if embedded_namespace:
267-
namespace = namespace + "/" + "/".join(embedded_namespace)
264+
if onnx_node.op_type == "Constant":
265+
# Move the constant closer to the user node's namespace
266+
namespace = get_constant_namespace(onnx_node.outputs[0], namespace)
267+
else:
268+
embedded_namespace = parse_namespace(onnx_node.name)
269+
if embedded_namespace:
270+
namespace = namespace + "/" + "/".join(embedded_namespace)
268271
node = graph_builder.GraphNode(
269272
id=onnx_node.name,
270273
label=create_op_label(onnx_node.domain, onnx_node.op_type),
@@ -312,8 +315,8 @@ def add_graph_io(
312315
all_nodes[node.id] = node
313316

314317

315-
def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str:
316-
# If the initializer is used by a single node, move it to the same namespace as the node
318+
def get_constant_namespace(initializer: ir.Value, root_namespace: str) -> str:
319+
"""Move the constant/initializer closer to the user's namespace."""
317320
initializer_namespace = root_namespace
318321
# A single node can have multiple uses of the same value.
319322
# Here we only count the unique nodes that use the initializer to push the
@@ -323,6 +326,7 @@ def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str
323326
# The initializer is not used by any node. Keep it in the root namespace.
324327
return initializer_namespace
325328
if len(user_nodes) == 1:
329+
# If the initializer is used by a single node, move it to the same namespace as the node
326330
user_node = user_nodes[0]
327331
assert (
328332
user_node.name
@@ -376,7 +380,7 @@ def add_initializers(
376380
node = graph_builder.GraphNode(
377381
id=initializer_node_name,
378382
label="Initializer",
379-
namespace=get_initializer_namespace(initializer, namespace),
383+
namespace=get_constant_namespace(initializer, namespace),
380384
)
381385
# Add metadata for the output tensor
382386
if initializer.const_value is None:
@@ -458,24 +462,25 @@ def convert(
458462
) -> model_explorer.ModelExplorerGraphs:
459463
del settings # Unused
460464

465+
# Do not load external data because the model file is copied to a temporary location
466+
# and the external data paths are not valid anymore.
461467
onnx_model = onnx.load(model_path, load_external_data=False)
462468
try:
463-
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
469+
onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True)
464470
except Exception as e:
465471
logger.warning(
466472
"Failed to infer shapes. Continue with the original model. Error: %s", e
467473
)
468474

469-
# Load external data after shape inference
470-
model_filepath = os.path.abspath(model_path)
471-
base_dir = os.path.dirname(model_filepath)
472-
onnx.load_external_data_for_model(onnx_model, base_dir)
473-
474475
# Convert to ONNX IR
475476
model = ir.serde.deserialize_model(onnx_model)
476477
all_function_ids = set(model.functions)
477478
graphs = []
478-
opset_version = model.opset_imports.get("", _DEFAULT_OPSET_VERSION)
479+
opset_version = model.opset_imports.get("")
480+
if opset_version is None:
481+
opset_version = model.opset_imports.get("ai.onnx")
482+
if opset_version is None:
483+
opset_version = _DEFAULT_OPSET_VERSION
479484
# TODO: Better support subgraphs in nodes
480485
main_graph = create_graph(
481486
model.graph, all_function_ids, opset_version=opset_version

0 commit comments

Comments
 (0)