
Assertion failed of TensorRT 10.9 when export nvidia-embed-v2 ONNX Model to TensorRT (polygraphy and trtexec) #4393

Open
ducknificient opened this issue Mar 20, 2025 · 1 comment


ducknificient commented Mar 20, 2025


name: Report a TensorRT issue
about: Failed to export ONNX Model (Transformer) to TensorRT
title: 'Assertion failed of TensorRT 10.9 when export ONNX Model to TensorRT (polygraphy and trtexec)'
labels: ''
assignees: ''


Description

I tried to convert the ONNX model to a TensorRT engine (with trtexec and Polygraphy), but parsing fails with this error:

[E] In node 70 with name: node_ReduceMean_70 and operator: ReduceMean (reduceTensor): UNSUPPORTED_NODE: Assertion failed: inputAxes.is_weights(): Axis input must be an initializer!

I also tried to run the model. It runs with ONNX Runtime (polygraphy run --onnxrt) but not with TensorRT (--trt).

Environment

Google Colab

TensorRT Version: TensorRT v100900
TensorRT PIP Version: 10.9.0.34
NVIDIA GPU: NVIDIA A100-SXM4-40GB
NVIDIA Driver Version: 550.54.15
Cuda compilation tools, release 12.5, V12.5.82
Build cuda_12.5.r12.5/compiler.34385749_0

Polygraphy | Version: 0.49.20
Collecting environment information...
PyTorch version: 2.6.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.31.6
Libc version: glibc-2.35

Python version: 3.11.11 (main, Dec 4 2024, 08:55:07) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-6.1.85+-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.5.82
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA L4
Nvidia driver version: 550.54.15
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.2.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) CPU @ 2.20GHz
CPU family: 6
Model: 85
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 7
BogoMIPS: 4400.48
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat avx512_vnni md_clear arch_capabilities
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 6 MiB (6 instances)
L3 cache: 38.5 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Vulnerable
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1: Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2: Vulnerable; IBPB: disabled; STIBP: disabled; PBRSB-eIBRS: Vulnerable; BHI: Vulnerable (Syscall hardening enabled)
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] nvtx==0.2.11
[pip3] onnx==1.17.0
[pip3] onnxruntime-gpu==1.21.0
[pip3] onnxscript==0.2.2
[pip3] optree==0.14.1
[pip3] pynvjitlink-cu12==0.5.2
[pip3] torch==2.6.0+cu124
[pip3] torchaudio==2.6.0+cu124
[pip3] torchsummary==1.5.1
[pip3] torchvision==0.21.0+cu124
[pip3] triton==3.2.0
[conda] Could not collect

Model link: https://huggingface.co/nvidia/NV-Embed-v2
Forked transformers: https://github.com/ducknificient/transformers (based on transformers 4.42.4, which provides the Mistral model)

Steps To Reproduce

Step 1 : Export to ONNX

!pip install datasets==3.3.2
!pip install onnx==1.17.0 onnxscript==0.2.2 torch==2.6.0 torchvision==0.21.0 onnxruntime-gpu==1.21.0

!git clone https://github.com/ducknificient/transformers.git
!pip install /content/transformers

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

model_path = "nvidia/nv-embed-v2"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True)
model.eval()

from torch.export import Dim

# Step 1: Define Input
batch_size = 4
dummy_input_ids = torch.randint(0, 32000, (batch_size, 128))  # Batch size 4, sequence length 128
dummy_attention_mask = torch.ones((batch_size, 128), dtype=torch.int64)
dummy_pool_mask = torch.ones((batch_size, 128), dtype=torch.int64)

# Step 2: Define Dynamic shapes
dynamic_shapes = {
    "input_ids": (Dim.DYNAMIC, Dim.DYNAMIC),
    "attention_mask": (Dim.DYNAMIC, Dim.DYNAMIC),
    "pool_mask": (Dim.DYNAMIC, Dim.DYNAMIC),
}

# Step 3: Define output path
output_path = "/content/gdrive/MyDrive/nv-embed-v2-onnx-13/model-onnx"

# Step 4: Export to ONNX
torch.onnx.export(
    model,                                      # PyTorch model
    (dummy_input_ids, dummy_attention_mask, dummy_pool_mask),
    output_path,                                # Output file
    export_params=True,                         # Store the trained weights
    opset_version=13,                           # ONNX opset version
    do_constant_folding=True,
    input_names=['input_ids', 'attention_mask','pool_mask'], # Input names
    output_names=['sentence_embeddings'], # Output names
    dynamic_shapes=dynamic_shapes,
    dynamo=True,
    verbose=True                               # Detailed output
)

print(f"Model exported to {output_path}")

Step 2: Run with ONNX Runtime (polygraphy --onnxrt)

This works:

import tensorrt

!polygraphy run /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx \
--model-type onnx \
--execution-providers=cuda \
--onnxrt \
--verbose 
[I] RUNNING | Command: /usr/local/bin/polygraphy run /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx --model-type onnx --execution-providers=cuda --onnxrt
[I] onnxrt-runner-N0-03/20/25-06:36:33  | Activating and starting inference
[I] Creating ONNX-Runtime Inference Session with providers: ['CUDAExecutionProvider']
2025-03-20 06:44:05.321069459 [W:onnxruntime:, transformer_memcpy.cc:83 ApplyImpl] 37 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message.
2025-03-20 06:44:05.354382116 [W:onnxruntime:, session_state.cc:1263 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.
2025-03-20 06:44:05.354408937 [W:onnxruntime:, session_state.cc:1265 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.
[W] Input tensor: input_ids [shape=BoundedShape(['4', '128'], min=None, max=None)] | Will generate data of shape: [1, 1].
    If this is incorrect, please provide a custom data loader.
[W] Input tensor: attention_mask [shape=BoundedShape(['4', '128'], min=None, max=None)] | Will generate data of shape: [1, 1].
    If this is incorrect, please provide a custom data loader.
[W] Input tensor: pool_mask [shape=BoundedShape(['4', '128'], min=None, max=None)] | Will generate data of shape: [1, 1].
    If this is incorrect, please provide a custom data loader.
[I] onnxrt-runner-N0-03/20/25-06:36:33 
    ---- Inference Input(s) ----
    {input_ids [dtype=int64, shape=(1, 1)],
     attention_mask [dtype=int64, shape=(1, 1)],
     pool_mask [dtype=int64, shape=(1, 1)]}
[I] onnxrt-runner-N0-03/20/25-06:36:33 
    ---- Inference Output(s) ----
    {sentence_embeddings [dtype=float32, shape=(1, 4096)]}
[I] onnxrt-runner-N0-03/20/25-06:36:33  | Completed 1 iteration(s) in 616.3 ms | Average inference time: 616.3 ms.
[I] PASSED | Runtime: 475.567s | Command: /usr/local/bin/polygraphy run /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx --model-type onnx --execution-providers=cuda --onnxrt
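The [W] warnings in this run show that Polygraphy could not resolve the symbolic input dimensions and generated inputs of shape [1, 1], so this ONNX Runtime check only exercised a single token. A Polygraphy data-loader script can supply the batch-4, length-128 shapes used at export; Polygraphy calls a function named load_data that yields feed dicts. The iteration count and the random token IDs below are assumptions for illustration.

```python
import numpy as np

# Shapes match the dummy inputs used for export in Step 1.
BATCH_SIZE = 4
SEQ_LEN = 128
VOCAB_SIZE = 32000  # same upper bound as torch.randint in the export script
NUM_ITERATIONS = 2  # hypothetical; any small number works


def load_data():
    """Yield input feed dicts for polygraphy run --data-loader-script."""
    rng = np.random.default_rng(0)
    for _ in range(NUM_ITERATIONS):
        yield {
            "input_ids": rng.integers(0, VOCAB_SIZE, size=(BATCH_SIZE, SEQ_LEN), dtype=np.int64),
            "attention_mask": np.ones((BATCH_SIZE, SEQ_LEN), dtype=np.int64),
            "pool_mask": np.ones((BATCH_SIZE, SEQ_LEN), dtype=np.int64),
        }
```

Saved as data_loader.py, it would be passed as polygraphy run <model> --onnxrt --data-loader-script data_loader.py. For a TensorRT build of this dynamic-shape model, Polygraphy also needs an optimization profile, for example via --trt-min-shapes, --trt-opt-shapes and --trt-max-shapes.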

Step 3: Building or running the TensorRT engine fails

With trtexec:

!/usr/src/tensorrt/bin/trtexec \
--onnx=/content/gdrive/MyDrive/nv-embed-v2-onnx-13/model-onnx \
--saveEngine=/content/gdrive/MyDrive/nv-embed-v2-trt/model-trt.trt \
--int8 \
--verbose
&&&& RUNNING TensorRT.trtexec [TensorRT v100900] [b34] # /usr/src/tensorrt/bin/trtexec --onnx=/content/gdrive/MyDrive/nv-embed-v2-onnx-13/model-onnx --saveEngine=/content/gdrive/MyDrive/nv-embed-v2-trt/model-trt.trt --int8 --verbose
[03/20/2025-06:20:36] [I] === Model Options ===
[03/20/2025-06:20:36] [I] Format: ONNX
[03/20/2025-06:20:36] [I] Model: /content/gdrive/MyDrive/nv-embed-v2-onnx-13/model-onnx
[03/20/2025-06:20:36] [I] Output:
[03/20/2025-06:20:36] [I] === Build Options ===
[03/20/2025-06:20:36] [I] Memory Pools: workspace: default, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default, tacticSharedMem: default
[03/20/2025-06:20:36] [I] avgTiming: 8
[03/20/2025-06:20:36] [I] Precision: FP32+INT8
[03/20/2025-06:20:36] [I] LayerPrecisions: 
[03/20/2025-06:20:36] [I] Layer Device Types: 
[03/20/2025-06:20:36] [I] Calibration: Dynamic
[03/20/2025-06:20:36] [I] Refit: Disabled
[03/20/2025-06:20:36] [I] Strip weights: Disabled
[03/20/2025-06:20:36] [I] Version Compatible: Disabled
[03/20/2025-06:20:36] [I] ONNX Plugin InstanceNorm: Disabled
[03/20/2025-06:20:36] [I] TensorRT runtime: full
[03/20/2025-06:20:36] [I] Lean DLL Path: 
[03/20/2025-06:20:36] [I] Tempfile Controls: { in_memory: allow, temporary: allow }
[03/20/2025-06:20:36] [I] Exclude Lean Runtime: Disabled
[03/20/2025-06:20:36] [I] Sparsity: Disabled
[03/20/2025-06:20:36] [I] Safe mode: Disabled
[03/20/2025-06:20:36] [I] Build DLA standalone loadable: Disabled
[03/20/2025-06:20:36] [I] Allow GPU fallback for DLA: Disabled
[03/20/2025-06:20:36] [I] DirectIO mode: Disabled
[03/20/2025-06:20:36] [I] Restricted mode: Disabled
[03/20/2025-06:20:36] [I] Skip inference: Disabled
[03/20/2025-06:20:36] [I] Save engine: /content/gdrive/MyDrive/nv-embed-v2-trt/model-trt.trt
[03/20/2025-06:20:36] [I] Load engine: 
[03/20/2025-06:20:36] [I] Profiling verbosity: 0
[03/20/2025-06:20:36] [I] Tactic sources: Using default tactic sources
[03/20/2025-06:20:36] [I] timingCacheMode: local
[03/20/2025-06:20:36] [I] timingCacheFile: 
[03/20/2025-06:20:36] [I] Enable Compilation Cache: Enabled
[03/20/2025-06:20:36] [I] Enable Monitor Memory: Disabled
[03/20/2025-06:20:36] [I] errorOnTimingCacheMiss: Disabled
[03/20/2025-06:20:36] [I] Preview Features: Use default preview flags.
[03/20/2025-06:20:36] [I] MaxAuxStreams: -1
[03/20/2025-06:20:36] [I] BuilderOptimizationLevel: -1
[03/20/2025-06:20:36] [I] MaxTactics: -1
[03/20/2025-06:20:36] [I] Calibration Profile Index: 0
[03/20/2025-06:20:36] [I] Weight Streaming: Disabled
[03/20/2025-06:20:36] [I] Runtime Platform: Same As Build
[03/20/2025-06:20:36] [I] Debug Tensors: 
[03/20/2025-06:20:36] [I] Input(s)s format: fp32:CHW
[03/20/2025-06:20:36] [I] Output(s)s format: fp32:CHW
[03/20/2025-06:20:36] [I] Input build shapes: model
[03/20/2025-06:20:36] [I] Input calibration shapes: model
[03/20/2025-06:20:36] [I] === System Options ===
[03/20/2025-06:20:36] [I] Device: 0
[03/20/2025-06:20:36] [I] DLACore: 
[03/20/2025-06:20:36] [I] Plugins:
[03/20/2025-06:20:36] [I] setPluginsToSerialize:
[03/20/2025-06:20:36] [I] dynamicPlugins:
[03/20/2025-06:20:36] [I] ignoreParsedPluginLibs: 0
[03/20/2025-06:20:36] [I] 
[03/20/2025-06:20:36] [I] === Inference Options ===
[03/20/2025-06:20:36] [I] Batch: Explicit
[03/20/2025-06:20:36] [I] Input inference shapes: model
[03/20/2025-06:20:36] [I] Iterations: 10
[03/20/2025-06:20:36] [I] Duration: 3s (+ 200ms warm up)
[03/20/2025-06:20:36] [I] Sleep time: 0ms
[03/20/2025-06:20:36] [I] Idle time: 0ms
[03/20/2025-06:20:36] [I] Inference Streams: 1
[03/20/2025-06:20:36] [I] ExposeDMA: Disabled
[03/20/2025-06:20:36] [I] Data transfers: Enabled
[03/20/2025-06:20:36] [I] Spin-wait: Disabled
[03/20/2025-06:20:36] [I] Multithreading: Disabled
[03/20/2025-06:20:36] [I] CUDA Graph: Disabled
[03/20/2025-06:20:36] [I] Separate profiling: Disabled
[03/20/2025-06:20:36] [I] Time Deserialize: Disabled
[03/20/2025-06:20:36] [I] Time Refit: Disabled
[03/20/2025-06:20:36] [I] NVTX verbosity: 0
[03/20/2025-06:20:36] [I] Persistent Cache Ratio: 0
[03/20/2025-06:20:36] [I] Optimization Profile Index: 0
[03/20/2025-06:20:36] [I] Weight Streaming Budget: 100.000000%
[03/20/2025-06:20:36] [I] Inputs:
[03/20/2025-06:20:36] [I] Debug Tensor Save Destinations:
[03/20/2025-06:20:36] [I] === Reporting Options ===
[03/20/2025-06:20:36] [I] Verbose: Enabled
[03/20/2025-06:20:36] [I] Averages: 10 inferences
[03/20/2025-06:20:36] [I] Percentiles: 90,95,99
[03/20/2025-06:20:36] [I] Dump refittable layers:Disabled
[03/20/2025-06:20:36] [I] Dump output: Disabled
[03/20/2025-06:20:36] [I] Profile: Disabled
[03/20/2025-06:20:36] [I] Export timing to JSON file: 
[03/20/2025-06:20:36] [I] Export output to JSON file: 
[03/20/2025-06:20:36] [I] Export profile to JSON file: 
[03/20/2025-06:20:36] [I] 
[03/20/2025-06:20:36] [I] === Device Information ===
[03/20/2025-06:20:36] [I] Available Devices: 
[03/20/2025-06:20:36] [I]   Device 0: "NVIDIA A100-SXM4-40GB" UUID: GPU-a61c6023-2edd-7390-7ef3-c82d1e2b47d5
[03/20/2025-06:20:36] [I] Selected Device: NVIDIA A100-SXM4-40GB
[03/20/2025-06:20:36] [I] Selected Device ID: 0
[03/20/2025-06:20:36] [I] Selected Device UUID: GPU-a61c6023-2edd-7390-7ef3-c82d1e2b47d5
[03/20/2025-06:20:36] [I] Compute Capability: 8.0
[03/20/2025-06:20:36] [I] SMs: 108
[03/20/2025-06:20:36] [I] Device Global Memory: 40506 MiB
[03/20/2025-06:20:36] [I] Shared Memory per SM: 164 KiB
[03/20/2025-06:20:36] [I] Memory Bus Width: 5120 bits (ECC enabled)
[03/20/2025-06:20:36] [I] Application Compute Clock Rate: 1.41 GHz
[03/20/2025-06:20:36] [I] Application Memory Clock Rate: 1.215 GHz
[03/20/2025-06:20:36] [I] 
[03/20/2025-06:20:36] [I] Note: The application clock rates do not reflect the actual clock rates that the GPU is currently running at.
[03/20/2025-06:20:36] [I] 
[03/20/2025-06:20:36] [I] TensorRT version: 10.9.0
[03/20/2025-06:20:36] [I] Loading standard plugins
...
[03/20/2025-06:27:16] [V] [TRT] Static check for parsing node: node_Constant_68 [Constant]
[03/20/2025-06:27:16] [V] [TRT] Parsing node: node_Constant_68 [Constant]
[03/20/2025-06:27:16] [V] [TRT] node_Constant_68 [Constant] inputs: 
[03/20/2025-06:27:16] [V] [TRT] node_Constant_68 [Constant] outputs: [val_51 -> (1)[INT64]], 
[03/20/2025-06:27:16] [V] [TRT] Static check for parsing node: node_Reshape_69 [Reshape]
[03/20/2025-06:27:16] [V] [TRT] Parsing node: node_Reshape_69 [Reshape]
[03/20/2025-06:27:16] [V] [TRT] Searching for input: val_8
[03/20/2025-06:27:16] [V] [TRT] Searching for input: val_51
[03/20/2025-06:27:16] [V] [TRT] node_Reshape_69 [Reshape] inputs: [val_8 -> (1)[INT64]], [val_51 -> (1)[INT64]], 
[03/20/2025-06:27:16] [V] [TRT] Registering layer: ONNXTRT_ShapeShuffle_118 required by ONNX-TRT
[03/20/2025-06:27:16] [V] [TRT] Registering layer: node_Reshape_69 for ONNX node: node_Reshape_69
[03/20/2025-06:27:16] [V] [TRT] Registering tensor: val_52 for ONNX tensor: val_52
[03/20/2025-06:27:16] [V] [TRT] node_Reshape_69 [Reshape] outputs: [val_52 -> (1)[INT64]], 
[03/20/2025-06:27:16] [V] [TRT] Static check for parsing node: node_ReduceMean_70 [ReduceMean]
[03/20/2025-06:27:16] [V] [TRT] Parsing node: node_ReduceMean_70 [ReduceMean]
[03/20/2025-06:27:16] [V] [TRT] Searching for input: pow_1
[03/20/2025-06:27:16] [V] [TRT] Searching for input: val_52
[03/20/2025-06:27:16] [V] [TRT] node_ReduceMean_70 [ReduceMean] inputs: [pow_1 -> (-1, -1, 4096)[FLOAT]], [val_52 -> (1)[INT64]], 
[03/20/2025-06:27:16] [E] [TRT] ModelImporter.cpp:961: While parsing node number 70 [ReduceMean -> "mean"]:
[03/20/2025-06:27:16] [E] [TRT] ModelImporter.cpp:964: --- Begin node ---
input: "pow_1"
input: "val_52"
output: "mean"
name: "node_ReduceMean_70"
op_type: "ReduceMean"
attribute {
  name: "keepdims"
  i: 1
  type: INT
}
attribute {
  name: "noop_with_empty_axes"
  i: 0
  type: INT
}
metadata_props {
  key: "namespace"
  value: ": transformers_modules.nvidia.nv-embed-v2.c50d55f43bde7e6a18e0eaa15a62fd63a930f1a1.modeling_nvembed.NVEmbedModel/mean: aten.mean.dim"
}
metadata_props {
  key: "pkg.torch.onnx.class_hierarchy"
  value: "[\'transformers_modules.nvidia.nv-embed-v2.c50d55f43bde7e6a18e0eaa15a62fd63a930f1a1.modeling_nvembed.NVEmbedModel\', \'aten.mean.dim\']"
}
metadata_props {
  key: "pkg.torch.onnx.fx_node"
  value: "%mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})"
}
metadata_props {
  key: "pkg.torch.onnx.name_scopes"
  value: "[\'\', \'mean\']"
}
metadata_props {
  key: "pkg.torch.onnx.stack_trace"
  value: "File \"<eval_with_key>.206\", line 20, in forward\n    mean = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None"
}

[03/20/2025-06:27:16] [E] [TRT] ModelImporter.cpp:965: --- End node ---
[03/20/2025-06:27:16] [E] [TRT] ModelImporter.cpp:967: ERROR: importerUtils.cpp:1598 In function reduceTensor:
[8] Assertion failed: inputAxes.is_weights(): Axis input must be an initializer!
[03/20/2025-06:27:16] [E] Failed to parse onnx file
[03/20/2025-06:27:18] [I] Finished parsing network model. Parse time: 388.921
[03/20/2025-06:27:18] [E] Parsing model failed
[03/20/2025-06:27:18] [E] Failed to create engine from model or file.
[03/20/2025-06:27:18] [E] Engine set up failed
&&&& FAILED TensorRT.trtexec [TensorRT v100900] [b34] # /usr/src/tensorrt/bin/trtexec --onnx=/content/gdrive/MyDrive/nv-embed-v2-onnx-13/model-onnx --saveEngine=/content/gdrive/MyDrive/nv-embed-v2-trt/model-trt.trt --int8 --verbose

Commands or scripts:

!polygraphy inspect \
model /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx \
--model-type onnx
[I] Loading model: /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx
[I] ==== ONNX Model ====
    Name: main_graph | ONNX Opset: 18 | Other Opsets: {'pkg.onnxscript.torch_lib.common': 1, 'pkg.onnxscript.torch_lib': 1}
    
    ---- 3 Graph Input(s) ----
    {input_ids [dtype=int64, shape=('4', '128')],
     attention_mask [dtype=int64, shape=('4', '128')],
     pool_mask [dtype=int64, shape=('4', '128')]}
    
    ---- 1 Graph Output(s) ----
    {sentence_embeddings [dtype=float32, shape=('4', 4096)]}
    
    ---- 336 Initializer(s) ----
    
    ---- 12114 Node(s) ----

polygraphy

import tensorrt

!polygraphy run /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx \
--model-type onnx \
--trt \
--verbose
[I] RUNNING | Command: /usr/local/bin/polygraphy run /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx --model-type onnx --trt
[I] TF32 is disabled by default. Turn on TF32 for better performance with minor accuracy differences.
[I] trt-runner-N0-03/20/25-06:15:01     | Activating and starting inference
[W] ModelImporter.cpp:459: Make sure input input_ids has Int64 binding.
[W] ModelImporter.cpp:459: Make sure input attention_mask has Int64 binding.
[W] ModelImporter.cpp:459: Make sure input pool_mask has Int64 binding.
[E] ModelImporter.cpp:961: While parsing node number 70 [ReduceMean -> "mean"]:
[E] ModelImporter.cpp:964: --- Begin node ---
    input: "pow_1"
    input: "val_52"
    output: "mean"
    name: "node_ReduceMean_70"
    op_type: "ReduceMean"
    attribute {
      name: "keepdims"
      i: 1
      type: INT
    }
    attribute {
      name: "noop_with_empty_axes"
      i: 0
      type: INT
    }
    metadata_props {
      key: "namespace"
      value: ": transformers_modules.nvidia.nv-embed-v2.c50d55f43bde7e6a18e0eaa15a62fd63a930f1a1.modeling_nvembed.NVEmbedModel/mean: aten.mean.dim"
    }
    metadata_props {
      key: "pkg.torch.onnx.class_hierarchy"
      value: "[\'transformers_modules.nvidia.nv-embed-v2.c50d55f43bde7e6a18e0eaa15a62fd63a930f1a1.modeling_nvembed.NVEmbedModel\', \'aten.mean.dim\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.fx_node"
      value: "%mean : [num_users=1] = call_function[target=torch.ops.aten.mean.dim](args = (%pow_1, [-1], True), kwargs = {})"
    }
    metadata_props {
      key: "pkg.torch.onnx.name_scopes"
      value: "[\'\', \'mean\']"
    }
    metadata_props {
      key: "pkg.torch.onnx.stack_trace"
      value: "File \"<eval_with_key>.779\", line 20, in forward\n    mean = torch.ops.aten.mean.dim(pow_1, [-1], True);  pow_1 = None"
    }
[E] ModelImporter.cpp:965: --- End node ---
[E] ModelImporter.cpp:967: ERROR: importerUtils.cpp:1598 In function reduceTensor:
    [8] Assertion failed: inputAxes.is_weights(): Axis input must be an initializer!
[E] In node 70 with name: node_ReduceMean_70 and operator: ReduceMean (reduceTensor): UNSUPPORTED_NODE: Assertion failed: inputAxes.is_weights(): Axis input must be an initializer!
[!] Could not parse ONNX correctly
[E] FAILED | Runtime: 85.222s | Command: /usr/local/bin/polygraphy run /content/gdrive/MyDrive/nv-embed-v2-onnx/model-onnx --model-type onnx --trt

Have you tried the latest release?: No
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt): Yes, the model runs with ONNX Runtime.

ducknificient changed the title from "Assertion failed of TensorRT 10.9 when export ONNX Model to TensorRT (polygraphy and trtexec)" to "Assertion failed of TensorRT 10.9 when export nvidia-embed-v2 ONNX Model to TensorRT (polygraphy and trtexec)" on Mar 20, 2025
lix19937 commented:

Assertion failed: inputAxes.is_weights(): Axis input must be an initializer.

Check node node_ReduceMean_70 (ReduceMean): its axes input must be an initializer.
