
Segmentation fault (core dumped) in TensorRT 10.3 when running execute_async_v3 on an H20 GPU #4395

Open
@simonzgx

Description

I used the following commands to convert an ONNX model to a TRT engine, where input.onnx is the original model:

polygraphy surgeon sanitize --fold-constants ./input.onnx  -o output.onnx
trtexec --onnx=output.onnx --saveEngine=model.plan --minShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
        --optShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
        --maxShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 --verbose=true

Then I tried to run inference with TensorRT, but hit a “Segmentation fault (core dumped)”. Below are my engine details and code:

TRT Engine
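
The engine summary below can be produced with polygraphy's inspector (the exact invocation is assumed; it is not shown in the report):

polygraphy inspect model model.plan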

[I] ==== TensorRT Engine ====
    Name: Unnamed Network 0 | Explicit Batch Engine
    
    ---- 5 Engine Input(s) ----
    {audio [dtype=float32, shape=(1, 256, 128)],
     trunc_start [dtype=int64, shape=(1,)],
     offset [dtype=int64, shape=(1,)],
     att_cache [dtype=float32, shape=(32, 20, 128, 128)],
     frame_cache [dtype=float32, shape=(1, 4, 128)]}
    
    ---- 2 Engine Output(s) ----
    {output [dtype=float32, shape=(64, 6400)],
     att_cache_out [dtype=float32, shape=(32, 20, 256, 128)]}
    
    ---- Memory ----
    Device Memory: 13281280 bytes
    
    ---- 1 Profile(s) (7 Tensor(s) Each) ----
    - Profile: 0
        Tensor: audio          (Input),  Index: 0 | Shapes: min=(1, 256, 128), opt=(1, 256, 128), max=(1, 256, 128)
        Tensor: trunc_start    (Input),  Index: 1 | Shapes: min=(1,), opt=(1,), max=(1,)
        Tensor: offset         (Input),  Index: 2 | Shapes: min=(1,), opt=(1,), max=(1,)
        Tensor: att_cache      (Input),  Index: 3 | Shapes: min=(32, 20, 128, 128), opt=(32, 20, 128, 128), max=(32, 20, 128, 128)
        Tensor: frame_cache    (Input),  Index: 4 | Shapes: min=(1, 4, 128), opt=(1, 4, 128), max=(1, 4, 128)
        Tensor: output         (Output), Index: 5 | Shape: (64, 6400)
        Tensor: att_cache_out  (Output), Index: 6 | Shape: (32, 20, 256, 128)
    
    ---- 505 Layer(s) ----
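
A note on the two int64 scalar inputs (trunc_start and offset): if the builder classified them as shape tensors, execute_async_v3 reads their values on the host, so binding device pointers for them can crash in exactly this way. A minimal check, assuming engine is the deserialized ICudaEngine:

# Sketch: print mode, dtype, shape-tensor status, and expected memory
# location for every I/O tensor of the deserialized engine.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    print(name,
          engine.get_tensor_mode(name),        # INPUT / OUTPUT
          engine.get_tensor_dtype(name),
          engine.is_shape_inference_io(name),  # True => value is read on the host
          engine.get_tensor_location(name))    # DEVICE or HOST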

My Code

import argparse

import onnxruntime as ort
import tensorrt as trt
import torch


class TensorRTInfer:
    """
    Implements inference for the Model TensorRT engine.
    """
    def __init__(self, engine_path):
        """
        :param engine_path: The path to the serialized engine to load from disk.
        """
        # Load TRT engine
        self.logger = trt.Logger(trt.Logger.INFO)
        trt.init_libnvinfer_plugins(self.logger, namespace="")
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            assert runtime
            self.engine = runtime.deserialize_cuda_engine(f.read())
        assert self.engine
        self.context = self.engine.create_execution_context()
        assert self.context

    def infer(self):
        # Raw CUDA stream handle taken from a fresh PyTorch side stream.
        self.stream = torch.cuda.Stream().cuda_stream
        # Prepare the input data
        audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
        trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
        offset = torch.tensor([126], dtype=torch.int64, device="cuda")
        att_cache = torch.randn(
            32, 20, 128, 128, dtype=torch.float32, device="cuda")
        frame_cache = torch.randn(
            1, 4, 128, dtype=torch.float32, device="cuda")
        self.context.set_input_shape('audio', audio.shape)
        self.context.set_tensor_address('audio', audio.contiguous().data_ptr())
        self.context.set_input_shape('trunc_start', trunc_start.shape)
        self.context.set_tensor_address(
            'trunc_start', trunc_start.contiguous().data_ptr())
        self.context.set_input_shape('offset', offset.shape)
        self.context.set_tensor_address(
            'offset', offset.contiguous().data_ptr())
        self.context.set_input_shape('att_cache', att_cache.shape)
        self.context.set_tensor_address(
            'att_cache', att_cache.contiguous().data_ptr())
        self.context.set_input_shape('frame_cache', frame_cache.shape)
        self.context.set_tensor_address(
            'frame_cache', frame_cache.contiguous().data_ptr())
        # Prepare the output data
        att_cache_out = torch.zeros(
            32, 20, 256, 128, dtype=torch.float32, device="cuda")
        output = torch.zeros(64, 6400, dtype=torch.float32, device="cuda")
        self.context.set_tensor_address(
            'att_cache_out', att_cache_out.contiguous().data_ptr())
        self.context.set_tensor_address(
            'output', output.contiguous().data_ptr())
        # self.context.set_optimization_profile_async(0, self.stream)
        torch.cuda.synchronize()
        self.context.execute_async_v3(self.stream)
        torch.cuda.synchronize()
        return output


def trt_infer(args):
    runner = TensorRTInfer(args.engine)
    runner.infer()


def onnx_infer():
    ort_session = ort.InferenceSession(
        "./poly/encoder.onnx", providers=['CUDAExecutionProvider'])
    audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
    trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
    offset = torch.tensor([126], dtype=torch.int64, device="cuda")
    att_cache = torch.randn(
        32, 20, 128, 128, dtype=torch.float32, device="cuda")
    frame_cache = torch.randn(
        1, 4, 128, dtype=torch.float32, device="cuda")
    outputs = ort_session.run(None, {
        "audio": audio.cpu().numpy(),
        "trunc_start": trunc_start.cpu().numpy(),
        "offset": offset.cpu().numpy(),
        "att_cache": att_cache.cpu().numpy(),
        "frame_cache": frame_cache.cpu().numpy(),
    })
    print(f"output shape={outputs[0].shape} att shape={outputs[1].shape}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e", "--engine", default="./model.plan", help="The serialized TensorRT engine"
    )
    args = parser.parse_args()
    trt_infer(args)
    # onnx_infer()
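
If trunc_start or offset does turn out to be a shape tensor, one possible workaround (a sketch only, adapting the binding code above; engine and context as in TensorRTInfer) is to bind those two inputs from host memory rather than CUDA memory:

import numpy as np

# Sketch: shape-tensor inputs must be bound to HOST memory. Keep the numpy
# arrays alive until execute_async_v3 has completed so the pointers stay valid.
trunc_start_host = np.array([2], dtype=np.int64)
offset_host = np.array([126], dtype=np.int64)
if engine.is_shape_inference_io('trunc_start'):
    context.set_tensor_address('trunc_start', trunc_start_host.ctypes.data)
if engine.is_shape_inference_io('offset'):
    context.set_tensor_address('offset', offset_host.ctypes.data)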

Environment

TensorRT Version: 10.3

NVIDIA GPU: H20

NVIDIA Driver Version: 535.161.08

CUDA Version: 12.4

CUDNN Version: 8.9

Operating System:

Python Version (if applicable): 3.10.4

Tensorflow Version (if applicable):

PyTorch Version (if applicable): 2.3.0

Baremetal or Container (if so, version):

Relevant Files

Model link:

Steps To Reproduce

Commands or scripts:

Have you tried the latest release?:

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

&&&& PASSED TensorRT.trtexec [TensorRT v100300] # trtexec --loadEngine=model.plan --verbose
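
A cross-check that may help localize the problem is comparing TensorRT against ONNX Runtime on the same sanitized ONNX file (standard polygraphy usage; this was not run in the original report):

polygraphy run output.onnx --trt --onnxrt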

Labels: triaged (issue has been triaged by maintainers)