Open
Description
Description
I used the following commands to convert an ONNX model to a TRT engine, where the input.onnx file is the original model:
polygraphy surgeon sanitize --fold-constants ./input.onnx -o output.onnx
trtexec --onnx=output.onnx --saveEngine=model.plan --minShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
--optShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
--maxShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 --verbose=true
Then I tried to perform inference using TensorRT, but encountered a “Segmentation fault (core dumped)” error. Below are my model information and code:
TRT Engine
[I] ==== TensorRT Engine ====
Name: Unnamed Network 0 | Explicit Batch Engine
---- 5 Engine Input(s) ----
{audio [dtype=float32, shape=(1, 256, 128)],
trunc_start [dtype=int64, shape=(1,)],
offset [dtype=int64, shape=(1,)],
att_cache [dtype=float32, shape=(32, 20, 128, 128)],
frame_cache [dtype=float32, shape=(1, 4, 128)]}
---- 2 Engine Output(s) ----
{output [dtype=float32, shape=(64, 6400)],
att_cache_out [dtype=float32, shape=(32, 20, 256, 128)]}
---- Memory ----
Device Memory: 13281280 bytes
---- 1 Profile(s) (7 Tensor(s) Each) ----
- Profile: 0
Tensor: audio (Input), Index: 0 | Shapes: min=(1, 256, 128), opt=(1, 256, 128), max=(1, 256, 128)
Tensor: trunc_start (Input), Index: 1 | Shapes: min=(1,), opt=(1,), max=(1,)
Tensor: offset (Input), Index: 2 | Shapes: min=(1,), opt=(1,), max=(1,)
Tensor: att_cache (Input), Index: 3 | Shapes: min=(32, 20, 128, 128), opt=(32, 20, 128, 128), max=(32, 20, 128, 128)
Tensor: frame_cache (Input), Index: 4 | Shapes: min=(1, 4, 128), opt=(1, 4, 128), max=(1, 4, 128)
Tensor: output (Output), Index: 5 | Shape: (64, 6400)
Tensor: att_cache_out (Output), Index: 6 | Shape: (32, 20, 256, 128)
---- 505 Layer(s) ----
My Code
from typing import Optional, List, Union
import ctypes
import os
import sys
import argparse
import numpy as np
import tensorrt as trt
import pdb
import torch
import onnxruntime as ort
class TensorRTInfer:
"""
Implements inference for the Model TensorRT engine.
"""
def __init__(self, engine_path):
"""
:param engine_path: The path to the serialized engine to load from disk.
"""
# Load TRT engine
self.logger = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(self.logger, namespace="")
with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
assert runtime
self.engine = runtime.deserialize_cuda_engine(f.read())
assert self.engine
self.context = self.engine.create_execution_context()
assert self.context
def infer(self):
self.stream = torch.cuda.Stream().cuda_stream
# Prepare the input data
audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
offset = torch.tensor([126], dtype=torch.int64, device="cuda")
att_cache = torch.randn(
32, 20, 128, 128, dtype=torch.float32, device="cuda")
frame_cache = torch.randn(
1, 4, 128, dtype=torch.float32, device="cuda")
self.context.set_input_shape('audio', audio.shape)
self.context.set_tensor_address('audio', audio.contiguous().data_ptr())
self.context.set_input_shape('trunc_start', trunc_start.shape)
self.context.set_tensor_address(
'trunc_start', trunc_start.contiguous().data_ptr())
self.context.set_input_shape('offset', offset.shape)
self.context.set_tensor_address(
'offset', offset.contiguous().data_ptr())
self.context.set_input_shape('att_cache', att_cache.shape)
self.context.set_tensor_address(
'att_cache', att_cache.contiguous().data_ptr())
self.context.set_input_shape('frame_cache', frame_cache.shape)
self.context.set_tensor_address(
'frame_cache', frame_cache.contiguous().data_ptr())
# Prepare the output data
att_cache_out = torch.zeros(
32, 20, 256, 128, dtype=torch.float32, device="cuda")
output = torch.zeros(64, 6400, dtype=torch.float32, device="cuda")
self.context.set_tensor_address(
'att_cache_out', att_cache_out.contiguous().data_ptr())
self.context.set_tensor_address(
'output', output.contiguous().data_ptr())
# self.context.set_optimization_profile_async(0, self.stream)
torch.cuda.synchronize()
# a = datetime.now()
self.context.execute_async_v3(self.stream)
torch.cuda.synchronize()
return output
def trt_infer(args):
trt_infer = TensorRTInfer(args.engine)
trt_infer.infer()
def onnx_infer():
ort_session = ort.InferenceSession(
"./poly/encoder.onnx", providers=['CUDAExecutionProvider'])
audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
offset = torch.tensor([126], dtype=torch.int64, device="cuda")
att_cache = torch.randn(
32, 20, 128, 128, dtype=torch.float32, device="cuda")
frame_cache = torch.randn(
1, 4, 128, dtype=torch.float32, device="cuda")
outputs = ort_session.run(
None, {"audio": audio.cpu().numpy(), "trunc_start": trunc_start.cpu().numpy(), "offset": offset.cpu().numpy(), "att_cache": att_cache.cpu().numpy(), "frame_cache": frame_cache.cpu().numpy()})
print(f"output shape={outputs[0].shape} att shape={outputs[1].shape}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-e", "--engine", default="./model.plan", help="The serialized TensorRT engine"
)
args = parser.parse_args()
trt_infer(args)
# onnx_infer()
Environment
TensorRT Version: 10.3
NVIDIA GPU: H20
NVIDIA Driver Version: 535.161.08
CUDA Version: 12.4
CUDNN Version: 8.9
Operating System:
Python Version (if applicable): 3.10.4
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 2.3.0
Baremetal or Container (if so, version):
Relevant Files
Model link:
Steps To Reproduce
Commands or scripts:
Have you tried the latest release?:
Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt
):
&&&& PASSED TensorRT.trtexec [TensorRT v100300] # trtexec --loadEngine=model.plan --verbose