@@ -1,7 +1,6 @@
 from __future__ import annotations

 import logging
-import os
 from typing import Any, Literal, Sequence

 import ml_dtypes
@@ -20,23 +19,23 @@
 def display_tensor(tensor: ir.TensorProtocol | None) -> str:
     if tensor is None:
         return "Data not available"
-    if tensor.size < _TENSOR_DISPLAY_LIMIT:
-        try:
-            array = tensor.numpy()
-            if tensor.dtype == ir.DataType.BFLOAT16:
-                array = array.view(ml_dtypes.bfloat16)
-            elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
-                array = array.view(ml_dtypes.float8_e4m3fn)
-            elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
-                array = array.view(ml_dtypes.float8_e4m3fnuz)
-            elif tensor.dtype == ir.DataType.FLOAT8E5M2:
-                array = array.view(ml_dtypes.float8_e5m2)
-            elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
-                array = array.view(ml_dtypes.float8_e5m2fnuz)
-            return np.array2string(array, separator=",")
-        except Exception as e:
-            logger.warning("Failed to display tensor: %s", e)
-            return str(tensor)
+    if tensor.size > _TENSOR_DISPLAY_LIMIT or isinstance(tensor, ir.ExternalTensor):
+        return str(tensor)
+    try:
+        array = tensor.numpy()
+        if tensor.dtype == ir.DataType.BFLOAT16:
+            array = array.view(ml_dtypes.bfloat16)
+        elif tensor.dtype == ir.DataType.FLOAT8E4M3FN:
+            array = array.view(ml_dtypes.float8_e4m3fn)
+        elif tensor.dtype == ir.DataType.FLOAT8E4M3FNUZ:
+            array = array.view(ml_dtypes.float8_e4m3fnuz)
+        elif tensor.dtype == ir.DataType.FLOAT8E5M2:
+            array = array.view(ml_dtypes.float8_e5m2)
+        elif tensor.dtype == ir.DataType.FLOAT8E5M2FNUZ:
+            array = array.view(ml_dtypes.float8_e5m2fnuz)
+        return np.array2string(array, separator=",")
+    except Exception as e:
+        logger.warning("Failed to display tensor: %s", e)
     return str(tensor)

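The `.view()` calls above reinterpret the raw buffer returned by `tensor.numpy()` in place, since NumPy has no native bfloat16/float8 dtypes. A minimal standalone sketch of the same trick (the literal values and the uint16-backed storage are illustrative assumptions, not taken from the adapter):

```python
import ml_dtypes
import numpy as np

# Raw two-byte payloads; 0x3F80 is 1.0 in bfloat16 (the upper 16 bits of float32 1.0).
raw = np.array([0x3F80, 0x4000, 0xC040], dtype=np.uint16)

# .view() reinterprets the bytes without copying, like the dtype branches above.
as_bf16 = raw.view(ml_dtypes.bfloat16)
print(np.array2string(as_bf16.astype(np.float32), separator=","))  # 1.0, 2.0, -3.0
```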
@@ -262,9 +261,13 @@ def create_node(
     """
     assert onnx_node.name, "Bug: Node name is required"

-    embedded_namespace = parse_namespace(onnx_node.name)
-    if embedded_namespace:
-        namespace = namespace + "/" + "/".join(embedded_namespace)
+    if onnx_node.op_type == "Constant":
+        # Move the constant closer to the user node's namespace
+        namespace = get_constant_namespace(onnx_node.outputs[0], namespace)
+    else:
+        embedded_namespace = parse_namespace(onnx_node.name)
+        if embedded_namespace:
+            namespace = namespace + "/" + "/".join(embedded_namespace)
     node = graph_builder.GraphNode(
         id=onnx_node.name,
         label=create_op_label(onnx_node.domain, onnx_node.op_type),
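`parse_namespace` itself is outside this diff. As a hedged sketch of the convention it appears to rely on (exporters embedding a slash-separated scope path in node names; the helper below is hypothetical, not the adapter's code):

```python
def parse_namespace_sketch(node_name: str) -> list[str]:
    # Hypothetical stand-in: treat every path segment before the final one
    # (the op itself) as a namespace component, e.g. "/model/layer0/MatMul".
    return [part for part in node_name.split("/")[:-1] if part]


print(parse_namespace_sketch("/model/layer0/MatMul"))  # ['model', 'layer0']
```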
@@ -312,8 +315,8 @@ def add_graph_io(
         all_nodes[node.id] = node


-def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str:
-    # If the initializer is used by a single node, move it to the same namespace as the node
+def get_constant_namespace(initializer: ir.Value, root_namespace: str) -> str:
+    """Move the constant/initializer closer to the user's namespace."""
     initializer_namespace = root_namespace
     # A single node can have multiple uses of the same value.
     # Here we only count the unique nodes that use the initializer to push the
@@ -323,6 +326,7 @@ def get_initializer_namespace(initializer: ir.Value, root_namespace: str) -> str
         # The initializer is not used by any node. Keep it in the root namespace.
         return initializer_namespace
     if len(user_nodes) == 1:
+        # If the initializer is used by a single node, move it to the same namespace as the node
         user_node = user_nodes[0]
         assert (
             user_node.name
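The "unique nodes" comment matters because a single node can consume the same initializer at several of its inputs. A sketch of how `user_nodes` could be collected; that `ir.Value.uses()` yields `(node, input_index)` pairs is an assumption about the IR API:

```python
def unique_user_nodes(value) -> list:  # value: ir.Value
    seen_ids: set[int] = set()
    users = []
    for node, _input_index in value.uses():  # assumed (node, index) pairs
        if id(node) not in seen_ids:  # deduplicate: same node, multiple inputs
            seen_ids.add(id(node))
            users.append(node)
    return users
```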
@@ -376,7 +380,7 @@ def add_initializers(
         node = graph_builder.GraphNode(
             id=initializer_node_name,
             label="Initializer",
-            namespace=get_initializer_namespace(initializer, namespace),
+            namespace=get_constant_namespace(initializer, namespace),
         )
         # Add metadata for the output tensor
         if initializer.const_value is None:
@@ -458,24 +462,25 @@ def convert(
 ) -> model_explorer.ModelExplorerGraphs:
     del settings  # Unused

+    # Do not load external data because the model file is copied to a temporary location
+    # and the external data paths are not valid anymore.
     onnx_model = onnx.load(model_path, load_external_data=False)
     try:
-        onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
+        onnx_model = onnx.shape_inference.infer_shapes(onnx_model, data_prop=True)
     except Exception as e:
         logger.warning(
             "Failed to infer shapes. Continue with the original model. Error: %s", e
         )

-    # Load external data after shape inference
-    model_filepath = os.path.abspath(model_path)
-    base_dir = os.path.dirname(model_filepath)
-    onnx.load_external_data_for_model(onnx_model, base_dir)
-
     # Convert to ONNX IR
     model = ir.serde.deserialize_model(onnx_model)
     all_function_ids = set(model.functions)
     graphs = []
-    opset_version = model.opset_imports.get("", _DEFAULT_OPSET_VERSION)
+    opset_version = model.opset_imports.get("")
+    if opset_version is None:
+        opset_version = model.opset_imports.get("ai.onnx")
+    if opset_version is None:
+        opset_version = _DEFAULT_OPSET_VERSION
     # TODO: Better support subgraphs in nodes
     main_graph = create_graph(
         model.graph, all_function_ids, opset_version=opset_version
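Both `load_external_data=False` and `data_prop=True` are standard `onnx` library options; a condensed sketch of this loading strategy in isolation (the model path is a placeholder):

```python
import onnx

model = onnx.load("model.onnx", load_external_data=False)  # weights stay on disk
try:
    # data_prop=True also propagates constant tensor values, which can resolve
    # shapes that depend on tensor contents (e.g. Reshape targets).
    model = onnx.shape_inference.infer_shapes(model, data_prop=True)
except Exception:
    pass  # keep the un-inferred model, mirroring convert() above
```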