Skip to content

Commit caa369f

Browse files
authored
[Backend] TRT cast GPU input from int64 to int32, output from int32 to int64, and Windows support building CUDA files (#426)
* TRT cast int64 to int32 * windows cmake build cuda src * fix windows cmake error when build cuda src * add a notice in windows gpu build doc * cmake add cuda std=11 * TRT cast output from int32 to int64 * nits * trt get original input output dtype
1 parent 04704c8 commit caa369f

File tree

9 files changed

+181
-25
lines changed

9 files changed

+181
-25
lines changed

CMakeLists.txt

+14-11
Original file line numberDiff line numberDiff line change
@@ -92,16 +92,6 @@ if(BUILD_ON_JETSON)
9292
set(ENABLE_ORT_BACKEND ON)
9393
endif()
9494

95-
# Whether to build CUDA source files in fastdeploy
96-
# CUDA source files include CUDA preprocessing, TRT plugins, etc.
97-
if(WITH_GPU AND UNIX)
98-
set(BUILD_CUDA_SRC ON)
99-
enable_language(CUDA)
100-
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
101-
else()
102-
set(BUILD_CUDA_SRC OFF)
103-
endif()
104-
10595
# config GIT_URL with github mirrors to speed up dependent repos clone
10696
option(GIT_URL "Git URL to clone dependent repos" ${GIT_URL})
10797
if(NOT GIT_URL)
@@ -177,6 +167,7 @@ configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/core/config.h.
177167
configure_file(${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc.in ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/pybind/main.cc)
178168
file(GLOB_RECURSE ALL_DEPLOY_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/*.cc)
179169
file(GLOB_RECURSE FDTENSOR_FUNC_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cc)
170+
file(GLOB_RECURSE FDTENSOR_FUNC_CUDA_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/function/*.cu)
180171
file(GLOB_RECURSE DEPLOY_ORT_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/ort/*.cc)
181172
file(GLOB_RECURSE DEPLOY_PADDLE_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/paddle/*.cc)
182173
file(GLOB_RECURSE DEPLOY_POROS_SRCS ${PROJECT_SOURCE_DIR}/${CSRCS_DIR_NAME}/fastdeploy/backends/poros/*.cc)
@@ -320,6 +311,18 @@ if(WITH_GPU)
320311
endif()
321312
endif()
322313

314+
# Whether to build CUDA source files in fastdeploy
315+
# CUDA source files include CUDA preprocessing, TRT plugins, etc.
316+
if(WITH_GPU)
317+
set(BUILD_CUDA_SRC ON)
318+
enable_language(CUDA)
319+
set(CMAKE_CUDA_STANDARD 11)
320+
set(CUDA_PROPAGATE_HOST_FLAGS FALSE)
321+
list(APPEND ALL_DEPLOY_SRCS ${FDTENSOR_FUNC_CUDA_SRCS})
322+
else()
323+
set(BUILD_CUDA_SRC OFF)
324+
endif()
325+
323326
if(ENABLE_TRT_BACKEND)
324327
if(APPLE OR ANDROID OR IOS)
325328
message(FATAL_ERROR "Cannot enable tensorrt backend in mac/ios/android os, please set -DENABLE_TRT_BACKEND=OFF.")
@@ -463,7 +466,7 @@ endif()
463466
set_target_properties(${LIBRARY_NAME} PROPERTIES VERSION ${FASTDEPLOY_VERSION})
464467
if(MSVC)
465468
# disable warnings for dll export
466-
target_compile_options(${LIBRARY_NAME} PRIVATE /wd4251)
469+
target_compile_options(${LIBRARY_NAME} PRIVATE "$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CXX>>:/wd4251>$<$<BUILD_INTERFACE:$<COMPILE_LANGUAGE:CUDA>>:-Xcompiler=/wd4251>")
467470
endif()
468471
target_link_libraries(${LIBRARY_NAME} ${DEPEND_LIBS})
469472

docs/cn/build_and_install/gpu.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ FastDeploy当前在GPU环境支持Paddle Inference、ONNX Runtime和TensorRT,
1414

1515
## C++ SDK编译安装
1616

17-
### Linux
17+
### Linux
1818

1919
Linux上编译需满足
2020
- gcc/g++ >= 5.4(推荐8.2)
@@ -48,6 +48,8 @@ Windows编译需要满足条件
4848
- cuda >= 11.2
4949
- cudnn >= 8.2
5050

51+
注意:安装CUDA时需要勾选`Visual Studio Integration`, 或者手动将`C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\visual_studio_integration\MSBuildExtensions\`文件夹下的4个文件复制到`C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\`文件夹。否则执行cmake命令时可能会遇到`No CUDA toolset found`报错。
52+
5153
在Windows菜单中,找到`x64 Native Tools Command Prompt for VS 2019`打开,执行如下命令
5254

5355
```bat

docs/en/build_and_install/gpu.md

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ FastDeploy supports Paddle Inference, ONNX Runtime and TensorRT in the GPU envir
1010
| TensorRT | Windows(x64)<br>Linux(x64) | Paddle/ONNX | Support GPU only, and compilation switch is `ENABLE_TRT_BACKEND`. The default is OFF |
1111
| OpenVINO | Windows(x64)<br>Linux(x64) | Paddle/ONNX | Support CPU only, and compilation switch is `ENABLE_OPENVINO_BACKEND`. The default is OFF |
1212

13-
Note:
13+
Note:
1414

1515
When the environment is GPU, please set `WITH_GPU` as ON and specify `CUDA_DIRECTORY`. If TensorRT integration is needed, please specify `TRT_DIRECTORY` as well.
1616

@@ -51,6 +51,8 @@ Prerequisite for Compiling on Windows:
5151
- cuda >= 11.2
5252
- cudnn >= 8.2
5353

54+
Notice: Make sure `Visual Studio Integration` is installed during CUDA installation, or manually copy the 4 files under `C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.2\extras\visual_studio_integration\MSBuildExtensions\` into `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\MSBuild\Microsoft\VC\v160\BuildCustomizations\`. Otherwise, you may run into `No CUDA toolset found` error during cmake.
55+
5456
Launch the x64 Native Tools Command Prompt for VS 2019 from the Windows Start Menu and run the following commands:
5557

5658
```

fastdeploy/backends/tensorrt/trt_backend.cc

+51-11
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
// limitations under the License.
1414

1515
#include "fastdeploy/backends/tensorrt/trt_backend.h"
16+
#include "fastdeploy/function/cuda_cast.h"
1617

1718
#include <cstring>
19+
#include <unordered_map>
1820

1921
#include "NvInferRuntime.h"
2022
#include "fastdeploy/utils/utils.h"
@@ -234,6 +236,7 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
234236
inputs_desc_[i].name = name;
235237
inputs_desc_[i].shape.assign(shape.begin(), shape.end());
236238
inputs_desc_[i].dtype = ReaderDtypeToTrtDtype(onnx_reader.inputs[i].dtype);
239+
inputs_desc_[i].original_dtype = ReaderDtypeToFDDtype(onnx_reader.inputs[i].dtype);
237240
auto info = ShapeRangeInfo(shape);
238241
info.name = name;
239242
auto iter_min = option.min_shape.find(name);
@@ -256,6 +259,8 @@ bool TrtBackend::InitFromOnnx(const std::string& model_file,
256259
outputs_desc_[i].shape.assign(shape.begin(), shape.end());
257260
outputs_desc_[i].dtype =
258261
ReaderDtypeToTrtDtype(onnx_reader.outputs[i].dtype);
262+
outputs_desc_[i].original_dtype =
263+
ReaderDtypeToFDDtype(onnx_reader.outputs[i].dtype);
259264
}
260265

261266
if (option_.external_stream_) {
@@ -315,9 +320,29 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
315320
FDERROR << "Failed to Infer with TensorRT." << std::endl;
316321
return false;
317322
}
323+
for (size_t i = 0; i < outputs->size(); ++i) {
324+
// if the final output tensor's dtype is different from the model output tensor's dtype,
325+
// then we need cast the data to the final output's dtype
326+
auto model_output_dtype = GetFDDataType(outputs_device_buffer_[(*outputs)[i].name].dtype());
327+
if ((*outputs)[i].dtype != model_output_dtype) {
328+
FDTensor output_tensor;
329+
output_tensor.SetExternalData((*outputs)[i].shape, model_output_dtype,
330+
outputs_device_buffer_[(*outputs)[i].name].data(),
331+
Device::GPU);
332+
333+
casted_output_tensors_[(*outputs)[i].name].Resize((*outputs)[i].shape, (*outputs)[i].dtype,
334+
(*outputs)[i].name, Device::GPU);
335+
CudaCast(output_tensor, &casted_output_tensors_[(*outputs)[i].name], stream_);
336+
} else {
337+
casted_output_tensors_[(*outputs)[i].name].SetExternalData(
338+
(*outputs)[i].shape, model_output_dtype,
339+
outputs_device_buffer_[(*outputs)[i].name].data(),
340+
Device::GPU);
341+
}
342+
}
318343
for (size_t i = 0; i < outputs->size(); ++i) {
319344
FDASSERT(cudaMemcpyAsync((*outputs)[i].Data(),
320-
outputs_device_buffer_[(*outputs)[i].name].data(),
345+
casted_output_tensors_[(*outputs)[i].name].Data(),
321346
(*outputs)[i].Nbytes(), cudaMemcpyDeviceToHost,
322347
stream_) == 0,
323348
"[ERROR] Error occurs while copy memory from GPU to CPU.");
@@ -329,6 +354,17 @@ bool TrtBackend::Infer(std::vector<FDTensor>& inputs,
329354
}
330355

331356
void TrtBackend::GetInputOutputInfo() {
357+
// Read the original dtypes from inputs_desc_ and outputs_desc_
358+
std::unordered_map<std::string, FDDataType> inputs_original_dtype_map;
359+
std::unordered_map<std::string, FDDataType> outputs_original_dtype_map;
360+
for (size_t i = 0; i < inputs_desc_.size(); ++i) {
361+
inputs_original_dtype_map[inputs_desc_[i].name] = inputs_desc_[i].original_dtype;
362+
}
363+
for (size_t i = 0; i < outputs_desc_.size(); ++i) {
364+
outputs_original_dtype_map[outputs_desc_[i].name] = outputs_desc_[i].original_dtype;
365+
}
366+
367+
// Re-read the tensor infos from TRT model and write into inputs_desc_ and outputs_desc_
332368
std::vector<TrtValueInfo>().swap(inputs_desc_);
333369
std::vector<TrtValueInfo>().swap(outputs_desc_);
334370
inputs_desc_.clear();
@@ -339,11 +375,14 @@ void TrtBackend::GetInputOutputInfo() {
339375
auto shape = ToVec(engine_->getBindingDimensions(i));
340376
auto dtype = engine_->getBindingDataType(i);
341377
if (engine_->bindingIsInput(i)) {
342-
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
378+
auto original_dtype = inputs_original_dtype_map.count(name) ? inputs_original_dtype_map[name] : GetFDDataType(dtype);
379+
inputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
343380
inputs_device_buffer_[name] = FDDeviceBuffer(dtype);
344381
} else {
345-
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype});
382+
auto original_dtype = outputs_original_dtype_map.count(name) ? outputs_original_dtype_map[name] : GetFDDataType(dtype);
383+
outputs_desc_.emplace_back(TrtValueInfo{name, shape, dtype, original_dtype});
346384
outputs_device_buffer_[name] = FDDeviceBuffer(dtype);
385+
casted_output_tensors_[name] = FDTensor();
347386
}
348387
}
349388
bindings_.resize(num_binds);
@@ -358,11 +397,12 @@ void TrtBackend::SetInputs(const std::vector<FDTensor>& inputs) {
358397

359398
if (item.device == Device::GPU) {
360399
if (item.dtype == FDDataType::INT64) {
361-
// TODO(liqi): cast int64 to int32
362-
// TRT don't support INT64
363-
FDASSERT(false,
364-
"TRT don't support INT64 input on GPU, "
365-
"please use INT32 input");
400+
inputs_device_buffer_[item.name].resize(dims);
401+
FDTensor input_tensor;
402+
input_tensor.SetExternalData(item.shape, FDDataType::INT32,
403+
inputs_device_buffer_[item.name].data(),
404+
Device::GPU);
405+
CudaCast(item, &input_tensor, stream_);
366406
} else {
367407
// no copy
368408
inputs_device_buffer_[item.name].SetExternalData(dims, item.Data());
@@ -413,7 +453,7 @@ void TrtBackend::AllocateOutputsBuffer(std::vector<FDTensor>* outputs) {
413453
std::vector<int64_t> shape(output_dims.d,
414454
output_dims.d + output_dims.nbDims);
415455
(*outputs)[ori_idx].is_pinned_memory = option_.enable_pinned_memory;
416-
(*outputs)[ori_idx].Resize(shape, GetFDDataType(outputs_desc_[i].dtype),
456+
(*outputs)[ori_idx].Resize(shape, outputs_desc_[i].original_dtype,
417457
outputs_desc_[i].name);
418458

419459
// Allocate output buffer memory
@@ -629,7 +669,7 @@ TensorInfo TrtBackend::GetInputInfo(int index) {
629669
info.name = inputs_desc_[index].name;
630670
info.shape.assign(inputs_desc_[index].shape.begin(),
631671
inputs_desc_[index].shape.end());
632-
info.dtype = GetFDDataType(inputs_desc_[index].dtype);
672+
info.dtype = inputs_desc_[index].original_dtype;
633673
return info;
634674
}
635675

@@ -649,7 +689,7 @@ TensorInfo TrtBackend::GetOutputInfo(int index) {
649689
info.name = outputs_desc_[index].name;
650690
info.shape.assign(outputs_desc_[index].shape.begin(),
651691
outputs_desc_[index].shape.end());
652-
info.dtype = GetFDDataType(outputs_desc_[index].dtype);
692+
info.dtype = outputs_desc_[index].original_dtype;
653693
return info;
654694
}
655695

fastdeploy/backends/tensorrt/trt_backend.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ namespace fastdeploy {
5757
struct TrtValueInfo {
5858
std::string name;
5959
std::vector<int> shape;
60-
nvinfer1::DataType dtype;
60+
nvinfer1::DataType dtype; // dtype of TRT model
61+
FDDataType original_dtype; // dtype of original ONNX/Paddle model
6162
};
6263

6364
struct TrtBackendOption {
@@ -141,6 +142,13 @@ class TrtBackend : public BaseBackend {
141142
// Also will update the range information while inferencing
142143
std::map<std::string, ShapeRangeInfo> shape_range_info_;
143144

145+
// If the final output tensor's dtype is different from the
146+
// model output tensor's dtype, then we need cast the data
147+
// to the final output's dtype.
148+
// E.g. When trt model output tensor is int32, but final tensor is int64
149+
// This map stores the casted tensors.
150+
std::map<std::string, FDTensor> casted_output_tensors_;
151+
144152
void GetInputOutputInfo();
145153
bool CreateTrtEngineFromOnnx(const std::string& onnx_model_buffer);
146154
bool BuildTrtEngine();

fastdeploy/backends/tensorrt/utils.cc

+20
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,26 @@ nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype) {
104104
return nvinfer1::DataType::kFLOAT;
105105
}
106106

107+
FDDataType ReaderDtypeToFDDtype(int reader_dtype) {
108+
if (reader_dtype == 0) {
109+
return FDDataType::FP32;
110+
} else if (reader_dtype == 1) {
111+
return FDDataType::FP64;
112+
} else if (reader_dtype == 2) {
113+
return FDDataType::UINT8;
114+
} else if (reader_dtype == 3) {
115+
return FDDataType::INT8;
116+
} else if (reader_dtype == 4) {
117+
return FDDataType::INT32;
118+
} else if (reader_dtype == 5) {
119+
return FDDataType::INT64;
120+
} else if (reader_dtype == 6) {
121+
return FDDataType::FP16;
122+
}
123+
FDASSERT(false, "Received unexpected data type of %d", reader_dtype);
124+
return FDDataType::FP32;
125+
}
126+
107127
std::vector<int> ToVec(const nvinfer1::Dims& dim) {
108128
std::vector<int> out(dim.d, dim.d + dim.nbDims);
109129
return out;

fastdeploy/backends/tensorrt/utils.h

+7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ FDDataType GetFDDataType(const nvinfer1::DataType& dtype);
5555

5656
nvinfer1::DataType ReaderDtypeToTrtDtype(int reader_dtype);
5757

58+
FDDataType ReaderDtypeToFDDtype(int reader_dtype);
59+
5860
std::vector<int> ToVec(const nvinfer1::Dims& dim);
5961

6062
template <typename T>
@@ -153,6 +155,11 @@ class FDGenericBuffer {
153155
//!
154156
size_t nbBytes() const { return this->size() * TrtDataTypeSize(mType); }
155157

158+
//!
159+
//! \brief Returns the dtype of the buffer.
160+
//!
161+
nvinfer1::DataType dtype() const { return mType; }
162+
156163
//!
157164
//! \brief Set user memory buffer for TRT Buffer
158165
//!

fastdeploy/function/cuda_cast.cu

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "fastdeploy/function/cuda_cast.h"
16+
17+
namespace fastdeploy {
18+
19+
template <typename T_IN, typename T_OUT>
20+
__global__ void CudaCastKernel(const T_IN* in, T_OUT* out, int edge) {
21+
int position = blockDim.x * blockIdx.x + threadIdx.x;
22+
if (position >= edge) return;
23+
out[position] = (T_OUT)in[position];
24+
}
25+
26+
void CudaCast(const FDTensor& in, FDTensor* out, cudaStream_t stream) {
27+
int jobs = in.Numel();
28+
int threads = 256;
29+
int blocks = ceil(jobs / (float)threads);
30+
if (in.dtype == FDDataType::INT64 && out->dtype == FDDataType::INT32) {
31+
CudaCastKernel<int64_t, int32_t><<<blocks, threads, 0, stream>>>(
32+
reinterpret_cast<int64_t*>(const_cast<void*>(in.Data())),
33+
reinterpret_cast<int32_t*>(out->MutableData()),
34+
jobs);
35+
} else if (in.dtype == FDDataType::INT32 && out->dtype == FDDataType::INT64) {
36+
CudaCastKernel<int32_t, int64_t><<<blocks, threads, 0, stream>>>(
37+
reinterpret_cast<int32_t*>(const_cast<void*>(in.Data())),
38+
reinterpret_cast<int64_t*>(out->MutableData()),
39+
jobs);
40+
} else {
41+
FDASSERT(false, "CudaCast only support input INT64, output INT32.");
42+
}
43+
}
44+
45+
} // namespace fastdeploy

fastdeploy/function/cuda_cast.h

+29
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include "fastdeploy/core/fd_tensor.h"
18+
19+
namespace fastdeploy {
20+
21+
/** Cast the type of the data in GPU buffer.
22+
@param in The input tensor.
23+
@param out The output tensor
24+
@param stream CUDA stream
25+
*/
26+
FASTDEPLOY_DECL void CudaCast(const FDTensor& in, FDTensor* out,
27+
cudaStream_t stream);
28+
29+
} // namespace fastdeploy

0 commit comments

Comments
 (0)