Got Segmentation fault (core dumped) of TensorRT 10.3 when running execute_async_v3 on GPU H20 #4395

Open
simonzgx opened this issue Mar 21, 2025 · 4 comments

Comments


simonzgx commented Mar 21, 2025

Description

I used the following commands to convert an ONNX model to a TRT engine, where the input.onnx file is the original model:

polygraphy surgeon sanitize --fold-constants ./input.onnx  -o output.onnx
trtexec --onnx=output.onnx --saveEngine=model.plan --minShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
        --optShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 \
        --maxShapes=audio:1x256x128,att_cache:32x20x128x128,frame_cache:1x4x128,trunc_start:1,offset:1 --verbose=true
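In the command above, min, opt and max are identical because every input has a fixed shape. The three flags can therefore be generated from a single table, which rules out a typo in any one of them. This is a minimal Python sketch. `shape_spec` and `trtexec_shape_flags` are hypothetical helpers, not part of trtexec or Polygraphy. The `name:d0xd1x...` format is copied from the command above.

```python
from typing import Dict, List, Tuple

# Static shape of every engine input, copied from the trtexec command above
SHAPES: Dict[str, Tuple[int, ...]] = {
    "audio": (1, 256, 128),
    "att_cache": (32, 20, 128, 128),
    "frame_cache": (1, 4, 128),
    "trunc_start": (1,),
    "offset": (1,),
}


def shape_spec(shapes: Dict[str, Tuple[int, ...]]) -> str:
    """Format shapes the way the command above does: name:d0xd1x...,name2:..."""
    return ",".join(
        f"{name}:{'x'.join(str(d) for d in dims)}" for name, dims in shapes.items()
    )


def trtexec_shape_flags(shapes: Dict[str, Tuple[int, ...]]) -> List[str]:
    """Emit identical --minShapes/--optShapes/--maxShapes flags (static shapes)."""
    spec = shape_spec(shapes)
    return [f"--{kind}Shapes={spec}" for kind in ("min", "opt", "max")]


if __name__ == "__main__":
    cmd = ["trtexec", "--onnx=output.onnx", "--saveEngine=model.plan",
           *trtexec_shape_flags(SHAPES), "--verbose=true"]
    print(" ".join(cmd))
```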

Then I tried to run inference with TensorRT but hit a “Segmentation fault (core dumped)” error. Below are the engine details and my code:

TRT Engine

[I] ==== TensorRT Engine ====
    Name: Unnamed Network 0 | Explicit Batch Engine
    
    ---- 5 Engine Input(s) ----
    {audio [dtype=float32, shape=(1, 256, 128)],
     trunc_start [dtype=int64, shape=(1,)],
     offset [dtype=int64, shape=(1,)],
     att_cache [dtype=float32, shape=(32, 20, 128, 128)],
     frame_cache [dtype=float32, shape=(1, 4, 128)]}
    
    ---- 2 Engine Output(s) ----
    {output [dtype=float32, shape=(64, 6400)],
     att_cache_out [dtype=float32, shape=(32, 20, 256, 128)]}
    
    ---- Memory ----
    Device Memory: 13281280 bytes
    
    ---- 1 Profile(s) (7 Tensor(s) Each) ----
    - Profile: 0
        Tensor: audio                  (Input), Index: 0 | Shapes: min=(1, 256, 128), opt=(1, 256, 128), max=(1, 256, 128)
        Tensor: trunc_start            (Input), Index: 1 | Shapes: min=(1,), opt=(1,), max=(1,)
        Tensor: offset                 (Input), Index: 2 | Shapes: min=(1,), opt=(1,), max=(1,)
        Tensor: att_cache              (Input), Index: 3 | Shapes: min=(32, 20, 128, 128), opt=(32, 20, 128, 128), max=(32, 20, 128, 128)
        Tensor: frame_cache            (Input), Index: 4 | Shapes: min=(1, 4, 128), opt=(1, 4, 128), max=(1, 4, 128)
        Tensor: output                (Output), Index: 5 | Shape: (64, 6400)
        Tensor: att_cache_out         (Output), Index: 6 | Shape: (32, 20, 256, 128)
    
    ---- 505 Layer(s) ----
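The byte size of each I/O tensor follows directly from the summary above: volume times element size, which is 4 bytes for float32 and 8 for int64. Every buffer bound with `set_tensor_address` should hold at least this many bytes, so the numbers are a quick way to check the allocations in the scripts below. An undersized buffer leads to out-of-bounds reads or writes. That is a general failure mode, not a diagnosis of this issue. This is a small NumPy sketch. The table is transcribed from the summary, and `tensor_nbytes` is a hypothetical helper.

```python
import numpy as np

# I/O tensors as listed in the engine summary above: name -> (dtype, shape)
IO_TENSORS = {
    "audio": (np.float32, (1, 256, 128)),
    "trunc_start": (np.int64, (1,)),
    "offset": (np.int64, (1,)),
    "att_cache": (np.float32, (32, 20, 128, 128)),
    "frame_cache": (np.float32, (1, 4, 128)),
    "output": (np.float32, (64, 6400)),
    "att_cache_out": (np.float32, (32, 20, 256, 128)),
}


def tensor_nbytes(dtype, shape) -> int:
    """Bytes needed for a dense tensor: product of dims times element size."""
    volume = 1
    for d in shape:
        volume *= int(d)
    return volume * np.dtype(dtype).itemsize


if __name__ == "__main__":
    for name, (dtype, shape) in IO_TENSORS.items():
        print(f"{name:<14} {np.dtype(dtype).name:<8} {shape!s:<20} "
              f"{tensor_nbytes(dtype, shape):>12,} bytes")
```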

My Code

import argparse

import tensorrt as trt
import torch
import onnxruntime as ort


class TensorRTInfer:
    """
    Implements inference for the Model TensorRT engine.
    """
    def __init__(self, engine_path):
        """
        :param engine_path: The path to the serialized engine to load from disk.
        """
        # Load TRT engine
        self.logger = trt.Logger(trt.Logger.INFO)
        trt.init_libnvinfer_plugins(self.logger, namespace="")
        with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
            assert runtime
            self.engine = runtime.deserialize_cuda_engine(f.read())
        assert self.engine
        self.context = self.engine.create_execution_context()
        assert self.context

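    def infer(self):
        # Keep the torch.cuda.Stream object alive together with its raw handle
        self.torch_stream = torch.cuda.Stream()
        self.stream = self.torch_stream.cuda_stream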
        # Prepare the input data
        audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
        trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
        offset = torch.tensor([126], dtype=torch.int64, device="cuda")
        att_cache = torch.randn(
            32, 20, 128, 128, dtype=torch.float32, device="cuda")
        frame_cache = torch.randn(
            1, 4, 128, dtype=torch.float32, device="cuda")
        self.context.set_input_shape('audio', audio.shape)
        self.context.set_tensor_address('audio', audio.contiguous().data_ptr())
        self.context.set_input_shape('trunc_start', trunc_start.shape)
        self.context.set_tensor_address(
            'trunc_start', trunc_start.contiguous().data_ptr())
        self.context.set_input_shape('offset', offset.shape)
        self.context.set_tensor_address(
            'offset', offset.contiguous().data_ptr())
        self.context.set_input_shape('att_cache', att_cache.shape)
        self.context.set_tensor_address(
            'att_cache', att_cache.contiguous().data_ptr())
        self.context.set_input_shape('frame_cache', frame_cache.shape)
        self.context.set_tensor_address(
            'frame_cache', frame_cache.contiguous().data_ptr())
        # Prepare the output data
        att_cache_out = torch.zeros(
            32, 20, 256, 128, dtype=torch.float32, device="cuda")
        output = torch.zeros(64, 6400, dtype=torch.float32, device="cuda")
        self.context.set_tensor_address(
            'att_cache_out', att_cache_out.contiguous().data_ptr())
        self.context.set_tensor_address(
            'output', output.contiguous().data_ptr())
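        # Make sure the input tensors created on the default stream are ready
        torch.cuda.synchronize()
        if not self.context.execute_async_v3(self.stream):
            raise RuntimeError("execute_async_v3 failed to enqueue inference")
        # Wait for inference on self.stream to finish before reading outputs
        torch.cuda.synchronize()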
        return output


def trt_infer(args):

    trt_infer = TensorRTInfer(args.engine)
    trt_infer.infer()


def onnx_infer():
    ort_session = ort.InferenceSession(
        "./poly/encoder.onnx", providers=['CUDAExecutionProvider'])
    audio = torch.randn(1, 256, 128, dtype=torch.float32, device="cuda")
    trunc_start = torch.tensor([2], dtype=torch.int64, device="cuda")
    offset = torch.tensor([126], dtype=torch.int64, device="cuda")
    att_cache = torch.randn(
        32, 20, 128, 128, dtype=torch.float32, device="cuda")
    frame_cache = torch.randn(
        1, 4, 128, dtype=torch.float32, device="cuda")
    outputs = ort_session.run(
        None, {"audio": audio.cpu().numpy(), "trunc_start": trunc_start.cpu().numpy(), "offset": offset.cpu().numpy(), "att_cache": att_cache.cpu().numpy(), "frame_cache": frame_cache.cpu().numpy()})
    print(f"output shape={outputs[0].shape} att shape={outputs[1].shape}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e", "--engine", default="./model.plan", help="The serialized TensorRT engine"
    )
    args = parser.parse_args()
    trt_infer(args)
    # onnx_infer()
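One hypothesis is worth ruling out, although nothing in this thread confirms it. If `trunc_start` or `offset` feed shape computations inside the network, TensorRT treats them as shape-inference inputs. TensorRT's documentation describes such inputs as being read from host memory. Binding a CUDA pointer to one of them, as this script does, would make TensorRT dereference a device address on the CPU, and that can crash with a segmentation fault. The fact that `trtexec --loadEngine=model.plan` passes is consistent with this, because trtexec places the inputs itself, but the evidence is only circumstantial. The helper below has a hypothetical name. It lists the I/O tensors for which `ICudaEngine.is_shape_inference_io()` returns True and assumes only a TensorRT 10 engine object.

```python
from typing import List


def shape_inference_io_names(engine) -> List[str]:
    """List I/O tensors that TensorRT uses for shape inference.

    `engine` is assumed to be a tensorrt.ICudaEngine (TensorRT 10); only
    num_io_tensors, get_tensor_name() and is_shape_inference_io() are called.
    Inputs returned here are expected to be bound to host (CPU) memory.
    """
    names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    return [name for name in names if engine.is_shape_inference_io(name)]


# Possible use in TensorRTInfer.infer() (sketch, untested):
#   host_names = shape_inference_io_names(self.engine)
#   if "trunc_start" in host_names:
#       trunc_start_cpu = torch.tensor([2], dtype=torch.int64)  # stays on CPU
#       self.context.set_tensor_address("trunc_start", trunc_start_cpu.data_ptr())
#   # keep trunc_start_cpu alive until inference has finished
```

If the list comes back empty, this hypothesis can be dropped. `engine.get_tensor_location(name) == trt.TensorLocation.HOST` is an alternative way to run the same check.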

Environment

TensorRT Version: 10.3

NVIDIA GPU: H20

NVIDIA Driver Version: 535.161.08

CUDA Version: 12.4

CUDNN Version: 8.9

Operating System:

Python Version (if applicable): 3.10.4

Tensorflow Version (if applicable):

PyTorch Version (if applicable): 2.3.0

Baremetal or Container (if so, version):

Relevant Files

Model link:

Steps To Reproduce

Commands or scripts:

Have you tried the latest release?:

Can this model run on other frameworks? For example run ONNX model with ONNXRuntime (polygraphy run <model.onnx> --onnxrt):

&&&& PASSED TensorRT.trtexec [TensorRT v100300] # trtexec --loadEngine=model.plan --verbose
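Once the engine runs from Python, the natural next step for the question above is to feed identical inputs to ONNX Runtime and TensorRT and compare the results. In the script above, each path draws its own random inputs, so they cannot be compared directly. This is a minimal NumPy sketch. `compare_outputs` is a hypothetical helper, and the default tolerances are illustrative assumptions, not recommended values.

```python
import numpy as np


def compare_outputs(ref: dict, test: dict, rtol: float = 1e-3, atol: float = 1e-4) -> dict:
    """Compare same-named outputs; returns {name: (max_abs_diff, allclose)}."""
    report = {}
    for name, expected in ref.items():
        expected = np.asarray(expected)
        actual = np.asarray(test[name])
        if actual.shape != expected.shape:
            raise ValueError(f"{name}: shape {actual.shape} != {expected.shape}")
        # Compute the difference in float64 to avoid extra rounding error
        diff = np.abs(actual.astype(np.float64) - expected.astype(np.float64))
        max_diff = float(diff.max()) if diff.size else 0.0
        close = bool(np.allclose(actual, expected, rtol=rtol, atol=atol))
        report[name] = (max_diff, close)
    return report
```

ONNX Runtime returns its outputs as a list, so they can be keyed by name with `{o.name: v for o, v in zip(ort_session.get_outputs(), outputs)}` before being passed in as `ref`.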
@lix19937

@simonzgx
Author

Quoting @lix19937: “You can refer to https://github.com/lix19937/tensorrt-insight/blob/main/tool/infer_from_engine_v10x.py”

Thanks for your reply. Based on the code you provided, I updated my code as follows:

import os
from collections import OrderedDict  # keep the order of the tensors implicitly
from pathlib import Path

import numpy as np
import tensorrt as trt
from cuda import cudart

# yapf:disable

trt_file = Path("model.plan")
audio = np.arange(1 * 256 * 128, dtype=np.float32).reshape(1, 256, 128)                  # inference input data
trunc_start = np.array([2], dtype=np.int64)
offset = np.array([126], dtype=np.int64)
att_cache = np.arange(32 * 20 * 128 * 128, dtype=np.float32).reshape(32, 20, 128, 128)
frame_cache = np.arange(4 * 128, dtype=np.float32).reshape(1, 4, 128)


def run():
    # create Logger, available level: VERBOSE, INFO, WARNING, ERROR, INTERNAL_ERROR
    logger = trt.Logger(trt.Logger.ERROR)
    # load engine from file and skip the building process if it exists
    if trt_file.exists():
        with open(trt_file, "rb") as f:
            engine_bytes = f.read()
        if not engine_bytes:
            print("Fail getting serialized engine")
            return
        print("Succeed getting serialized engine")
    # build a serialized network from scratch
    else:
        # create Builder
        builder = trt.Builder(logger)
        # create BuilderConfig to set build options
        config = builder.create_builder_config()
        # create Network
        network = builder.create_network()
        # create OptimizationProfile if using Dynamic-Shape mode
        profile = builder.create_optimization_profile()
        # set workspace for the building process (all GPU memory is used by default)
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

        # set input tensor of the network ("inputT0" is a placeholder name;
        # this toy branch only runs when model.plan does not exist)
        input_tensor = network.add_input("inputT0", trt.float32, [-1, -1, -1])
        # set dynamic shape range of the input tensor
        profile.set_shape(input_tensor.name, [1, 1, 1], [3, 4, 5], [6, 8, 10])
        # add the Optimization Profile into the BuilderConfig
        config.add_optimization_profile(profile)

        # this toy network has a single identity layer, so the output equals the input
        identity_layer = network.add_identity(input_tensor)
        # mark the tensor for output
        network.mark_output(identity_layer.get_output(0))

        # create a serialized network from the network
        engine_bytes = builder.build_serialized_network(network, config)
        if engine_bytes is None:
            print("Fail building engine")
            return
        print("Succeed building engine")
        # save the serialized network as a binary file
        with open(trt_file, "wb") as f:
            f.write(engine_bytes)
            print(f"Succeed saving engine ({trt_file})")

    engine = trt.Runtime(logger).deserialize_cuda_engine(
        engine_bytes)          # create inference engine
    if engine is None:
        print("Fail getting engine for inference")
        return
    print("Succeed getting engine for inference")
    # create Execution Context from the engine (analogy to a GPU context, or a CPU process)
    context = engine.create_execution_context()

    tensor_name_list = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]

    # set runtime size of input tensor if using Dynamic-Shape mode
    context.set_input_shape("audio", audio.shape)
    context.set_input_shape("trunc_start", trunc_start.shape)
    context.set_input_shape("offset", offset.shape)
    context.set_input_shape("att_cache", att_cache.shape)
    context.set_input_shape("frame_cache", frame_cache.shape)

    # Print information of input / output tensors
    for name in tensor_name_list:
        mode = engine.get_tensor_mode(name)
        data_type = engine.get_tensor_dtype(name)
        buildtime_shape = engine.get_tensor_shape(name)
        runtime_shape = context.get_tensor_shape(name)
        print(f"{'Input ' if mode == trt.TensorIOMode.INPUT else 'Output'}->{data_type}, {buildtime_shape}, {runtime_shape}, {name}")

    # prepare the memory buffer on host and device
    buffer = OrderedDict()
    for name in tensor_name_list:
        data_type = engine.get_tensor_dtype(name)
        runtime_shape = context.get_tensor_shape(name)
        n_byte = trt.volume(runtime_shape) * \
            np.dtype(trt.nptype(data_type)).itemsize
        host_buffer = np.empty(runtime_shape, dtype=trt.nptype(data_type))
        device_buffer = cudart.cudaMalloc(n_byte)[1]
        buffer[name] = [host_buffer, device_buffer, n_byte]
    # set runtime data, MUST use np.ascontiguousarray, it is a SERIOUS lesson
    buffer["audio"][0] = np.ascontiguousarray(audio)
    buffer["trunc_start"][0] = np.ascontiguousarray(trunc_start)
    buffer["offset"][0] = np.ascontiguousarray(offset)
    buffer["att_cache"][0] = np.ascontiguousarray(att_cache)
    buffer["frame_cache"][0] = np.ascontiguousarray(frame_cache)

    for name in tensor_name_list:
        # bind address of device buffer to context
        context.set_tensor_address(name, buffer[name][1])

    # copy input data from host to device
    for name in tensor_name_list:
        if engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
            cudart.cudaMemcpy(buffer[name][1], buffer[name][0].ctypes.data,
                              buffer[name][2], cudart.cudaMemcpyKind.cudaMemcpyHostToDevice)

    # do inference computation on the default (legacy) stream 0; the blocking
    # cudaMemcpy calls below wait for it to finish
    context.execute_async_v3(0)

    # copy output data from device to host
    for name in tensor_name_list:
        if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
            cudart.cudaMemcpy(buffer[name][0].ctypes.data, buffer[name][1],
                              buffer[name][2], cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost)

    for name in tensor_name_list:
        print(name)
        print(buffer[name][0])

    # free the GPU memory buffer after all work
    for _, device_buffer, _ in buffer.values():
        cudart.cudaFree(device_buffer)


if __name__ == "__main__":
    os.system("rm -rf *.trt")  # note: this does not delete model.plan

    # first run: loads model.plan if it exists, otherwise builds the toy engine
    run()
    # second run: loads the engine from model.plan
    run()

    print("Finish")

but the segmentation fault (core dumped) persists.
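In the updated script, each host buffer is replaced by the user's array (`buffer[name][0] = np.ascontiguousarray(...)`), but the copy size `buffer[name][2]` still comes from the engine. If an array's dtype or shape differed from the engine's, for example int32 instead of int64 for `trunc_start`, `cudaMemcpy` would read past the end of the host array, which can segfault on its own. The dtypes in this script do match, so this is unlikely to be the cause here, but the check is cheap. This is a minimal sketch, and `checked_host_array` is a hypothetical helper.

```python
import numpy as np


def checked_host_array(value, expected_dtype, expected_shape) -> np.ndarray:
    """Return `value` as a C-contiguous array whose nbytes matches the engine's.

    Raises instead of letting cudaMemcpy read past the end of a host array
    whose dtype or shape differs from what the engine expects.
    """
    arr = np.ascontiguousarray(value)
    expected_dtype = np.dtype(expected_dtype)
    expected_shape = tuple(int(d) for d in expected_shape)
    if arr.dtype != expected_dtype:
        raise TypeError(f"dtype {arr.dtype} != expected {expected_dtype}")
    if arr.shape != expected_shape:
        raise ValueError(f"shape {arr.shape} != expected {expected_shape}")
    return arr


# Possible use in run() (sketch, untested; assumes trt.Dims is iterable):
#   dtype = trt.nptype(engine.get_tensor_dtype("audio"))
#   shape = context.get_tensor_shape("audio")
#   buffer["audio"][0] = checked_host_array(audio, dtype, shape)
```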

@lix19937

Try a common ONNX model such as resnet50.onnx: build a plan from it and run my script to check whether it passes.

@simonzgx
Author


I wrote a very simple 2-layer LSTM model, converted it to a TRT engine the same way, and it ran fine. I can also try ResNet50.
