Commit 0e7f4fe
Unifying the FX and TS Runtimes (#1404)
* feat: Adding profiling support to the runtime
* refactor: A new TRTModule implementation using the internal runtime, which should give TS support for free
* feat: Let Input generate random tensors following the spec
* feat!(//core/runtime): Allow the Runtime to use binding names to align I/O
  BREAKING CHANGE: This commit contains an ABI version upgrade, meaning that existing compiled modules will not work with this runtime. Recompilation with a newer version of Torch-TensorRT will fix this. This also amends the C++ API to allow users to explicitly set binding names in the order they will be passed in and are expected to be returned; that change is backwards compatible with the current API.
* fix(//core/runtime): Resolving some issues with the runtime ABI
* feat(//core/runtime): Adding a TRT layer profiler
* feat(//py): Exposed the new runtime in Python
* feat(//py/torch_tensorrt/fx): Compliant TRTModule implementation based on the shared Torch-TensorRT runtime
* refactor: CudaDevice -> RTDevice, for better distinction from the compile-time device
* feat(//examples): Demo that you can compile using FX then deploy in TS
* refactor(//py/torch_tensorrt): Updates to existing APIs for use in FX
* feat(//core/runtime): Encode the TRT engine in base64 instead of raw bytes
* feat(//py/torch_tensorrt/fx): Adding the option to use the experimental runtime
* fix(//core/runtime): Fixing a bug where an exception thrown in a downstream constructor would cause a segfault
* feat(//py/torch_tensorrt/TRTModule): Allow state_dict extraction
* chore: Addressing merge conflicts
* chore: Lint
* chore: Remove print statements
* fix: Fix the CMake build
* refactor: Add a suffix to the TRTModuleNext class while it's experimental
* docs: Update docs and examples
* refactor: Reorder the API since everything but the engine is optional; also add a new destructor to order cleanup

Signed-off-by: Naren Dasan <[email protected]>
1 parent e7bb8c2 commit 0e7f4fe


50 files changed: +1704 -281 lines
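
Because the FX and TS frontends now share a single runtime, a module compiled through FX and saved as TorchScript can be executed from plain libtorch. A minimal deployment sketch, assuming the module has been saved as "trt.ts" and the Torch-TensorRT runtime library is linked so the embedded engine op is registered (the file name and input shape are illustrative):

// Hedged sketch: load and run an FX-compiled, TorchScript-saved module from libtorch.
#include <iostream>
#include <torch/script.h>

int main() {
  torch::jit::Module mod = torch::jit::load("trt.ts"); // deserializes the embedded TRT engine
  mod.to(torch::kCUDA);
  auto input = torch::randn({1, 3, 224, 224}, torch::kCUDA); // shape is an assumption
  auto output = mod.forward({input}).toTensor();
  std::cout << output.sizes() << std::endl;
  return 0;
}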

CMakeLists.txt (+2, -2)

@@ -2,8 +2,8 @@
 cmake_minimum_required(VERSION 3.17)
 project(Torch-TensorRT LANGUAGES CXX)

-# use c++17
-set(CMAKE_CXX_STANDARD 17)
+# use c++14 like PyTorch
+set(CMAKE_CXX_STANDARD 14)

 # Build the libraries with -fPIC
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)

core/compiler.cpp (+26, -8)

@@ -31,11 +31,17 @@ void AddEngineToGraph(
     torch::jit::script::Module mod,
     std::shared_ptr<torch::jit::Graph>& g,
     const std::string& serialized_engine,
-    runtime::CudaDevice& device_info,
+    runtime::RTDevice& device_info,
+    const std::vector<std::string>& input_binding_names,
+    const std::vector<std::string>& output_binding_names,
     std::string engine_id = "",
     bool fallback = false) {
   auto engine_ptr = c10::make_intrusive<runtime::TRTEngine>(
-      mod._ivalue()->name() + "_engine_" + engine_id, serialized_engine, device_info);
+      mod._ivalue()->name() + "_engine_" + engine_id,
+      serialized_engine,
+      device_info,
+      input_binding_names,
+      output_binding_names);
   // Get required metadata about the engine out
   auto num_io = engine_ptr->num_io;
   auto name = engine_ptr->name;

@@ -162,8 +168,16 @@ partitioning::GraphAndMapping BuildHybridGraph(
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_info, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_info.engine_settings.device;
-      auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-      AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
+      auto cuda_device = runtime::RTDevice(device_spec.gpu_id, device_spec.device_type);
+      AddEngineToGraph(
+          new_mod,
+          temp_g,
+          engine,
+          cuda_device,
+          std::vector<std::string>(),
+          std::vector<std::string>(),
+          trt_engine_id.str(),
+          true);

       seg_block.update_graph(temp_g);
     }

@@ -279,7 +293,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   torch::jit::Module new_mod(mod._ivalue()->name() + "_trt");

   auto device_spec = cfg.convert_info.engine_settings.device;
-  auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+  auto cuda_device = runtime::RTDevice(device_spec.gpu_id, device_spec.device_type);

   for (const torch::jit::Method& method : mod.get_methods()) {
     if (method.name().compare("forward") == 0) {

@@ -327,7 +341,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
           "Not all operations in graph are supported by the compiler");
       // TODO find the right
       auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
-      AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+      AddEngineToGraph(new_mod, new_g, engine, cuda_device, std::vector<std::string>(), std::vector<std::string>());
     }
     auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
     auto schema = util::GenerateGraphSchema(new_method->name(), new_g);

@@ -338,12 +352,16 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   return new_mod;
 }

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device) {
+torch::jit::script::Module EmbedEngineInNewModule(
+    const std::string& engine,
+    runtime::RTDevice cuda_device,
+    const std::vector<std::string>& input_binding_names,
+    const std::vector<std::string>& output_binding_names) {
   std::ostringstream engine_id;
   engine_id << reinterpret_cast<const int*>(&engine);
   torch::jit::script::Module new_mod("tensorrt_engine_mod_" + engine_id.str());
   auto new_g = std::make_shared<torch::jit::Graph>();
-  AddEngineToGraph(new_mod, new_g, engine, cuda_device);
+  AddEngineToGraph(new_mod, new_g, engine, cuda_device, input_binding_names, output_binding_names);
   auto new_method = new_mod._ivalue()->compilation_unit()->create_function("forward", new_g);
   auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
   new_mod.type()->addMethod(new_method);

core/compiler.h (+5, -1)

@@ -28,7 +28,11 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::

 torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, CompileSpec cfg);

-torch::jit::script::Module EmbedEngineInNewModule(const std::string& engine, runtime::CudaDevice cuda_device);
+torch::jit::script::Module EmbedEngineInNewModule(
+    const std::string& engine,
+    runtime::RTDevice cuda_device,
+    const std::vector<std::string>& input_binding_names,
+    const std::vector<std::string>& output_binding_names);

 void set_device(const int gpu_id);
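
The new overload above is easiest to see in use. A hedged sketch (not code from this commit; the path, namespace qualification, and binding names are illustrative assumptions) of wrapping a prebuilt serialized engine while pinning I/O order, as the hunks in core/compiler.cpp and core/compiler.h now allow:

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "core/compiler.h"

torch::jit::script::Module WrapEngine(const std::string& path) {
  std::ifstream file(path, std::ios::binary);
  std::stringstream blob;
  blob << file.rdbuf(); // raw serialized TRT engine

  auto device = torch_tensorrt::core::runtime::RTDevice(0, nvinfer1::DeviceType::kGPU);
  std::vector<std::string> input_bindings{"input_0"};   // order inputs will be passed
  std::vector<std::string> output_bindings{"output_0"}; // order outputs are returned
  // Empty vectors fall back to the engine's own binding order, as in BuildHybridGraph above.
  return torch_tensorrt::core::EmbedEngineInNewModule(
      blob.str(), device, input_bindings, output_bindings);
}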

core/conversion/converters/impl/expand.cpp (+3, -2)

@@ -374,12 +374,13 @@ auto expand_registrations TORCHTRT_UNUSED =

       // Collapse repeated dimension back into desired dimension
       std::vector<int64_t> collapse_shape_vec;
-      for (int k = 0; k < repeat_shape_dims.nbDims; k++) {
+      for (int64_t k = 0; k < repeat_shape_dims.nbDims; k++) {
         if (k == dim) {
-          int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[++k];
+          int64_t collapse_dim = repeat_shape_dims.d[k] * repeat_shape_dims.d[k + 1];
           // Set dim size to -1 if repeat is being done on dynamic dim
           collapse_dim = std::max(collapse_dim, (int64_t)-1);
           collapse_shape_vec.push_back(collapse_dim);
+          k++;
         } else {
           collapse_shape_vec.push_back(repeat_shape_dims.d[k]);
         }
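
The second hunk fixes more than style: in the old line, k is both incremented and read within a single expression, and the operands of * are unsequenced, so repeat_shape_dims.d[k] * repeat_shape_dims.d[++k] is undefined behavior. A standalone illustration (not Torch-TensorRT code) of the well-defined pattern the diff switches to:

#include <cstdint>
#include <iostream>

int main() {
  int64_t d[] = {2, 3, 4};
  int64_t k = 0;
  // Undefined behavior, as in the old code: int64_t bad = d[k] * d[++k];
  int64_t collapse_dim = d[k] * d[k + 1]; // well-defined: reads d[0] and d[1]
  k++;                                    // then skip the dimension just folded in
  std::cout << collapse_dim << ", next k = " << k << std::endl; // prints: 6, next k = 1
  return 0;
}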

core/conversion/converters/impl/select.cpp (+1, -1)

@@ -287,7 +287,7 @@ auto select_registrations TORCHTRT_UNUSED =

       std::vector<nvinfer1::ITensor*> tensors;
       std::vector<int32_t> adv_idx_indices;
-      for (auto i = 0; i < ts.size(); i++) {
+      for (size_t i = 0; i < ts.size(); i++) {
         auto t = ts[i];
         if (t.isTensor()) {
           auto torch_tensor = t.toTensor().to(torch::kInt32);

core/runtime/BUILD (+14, -2)

@@ -13,16 +13,23 @@ config_setting(
 cc_library(
     name = "runtime",
     srcs = [
-        "CudaDevice.cpp",
         "DeviceList.cpp",
+        "RTDevice.cpp",
         "TRTEngine.cpp",
+        "TRTEngineProfiler.cpp",
         "execute_engine.cpp",
         "register_jit_hooks.cpp",
         "runtime.cpp",
     ],
     hdrs = [
+        "RTDevice.h",
+        "TRTEngine.h",
+        "TRTEngineProfiler.h",
         "runtime.h",
     ],
+    linkopts = [
+        "-lstdc++fs",
+    ],
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",

@@ -36,6 +43,11 @@ cc_library(

 pkg_tar(
     name = "include",
-    srcs = ["runtime.h"],
+    srcs = [
+        "RTDevice.h",
+        "TRTEngine.h",
+        "TRTEngineProfiler.h",
+        "runtime.h",
+    ],
     package_dir = "core/runtime/",
 )
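
Note: the new -lstdc++fs link option is presumably required because the added TRTEngineProfiler uses the C++ filesystem library to write trace output, which GCC before version 9 ships as a separate libstdc++fs archive; the CMake change below adds the matching stdc++fs dependency.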

core/runtime/CMakeLists.txt (+7, -2)

@@ -2,15 +2,19 @@ set(lib_name "core_runtime")
 add_library(${lib_name} OBJECT)

 set(CXX_SRCS
-    "${CMAKE_CURRENT_SOURCE_DIR}/CudaDevice.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/DeviceList.cpp"
-    "${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.cpp"
+    "${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp"
     "${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp"
 )

 set(HEADER_FILES
+    "${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
+    "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
+    "${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
     "${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
 )

@@ -29,6 +33,7 @@ target_link_libraries(${lib_name}
     TensorRT::nvinfer
     torch
     core_util
+    stdc++fs
 )

 # Install

core/runtime/DeviceList.cpp (+3, -3)

@@ -15,19 +15,19 @@ DeviceList::DeviceList() {
   }

   for (int i = 0; i < num_devices; i++) {
-    device_list[i] = CudaDevice(i, nvinfer1::DeviceType::kGPU);
+    device_list[i] = RTDevice(i, nvinfer1::DeviceType::kGPU);
   }

   // REVIEW: DO WE CARE ABOUT DLA?

   LOG_DEBUG("Runtime:\n Available CUDA Devices: \n" << this->dump_list());
 }

-void DeviceList::insert(int device_id, CudaDevice cuda_device) {
+void DeviceList::insert(int device_id, RTDevice cuda_device) {
   device_list[device_id] = cuda_device;
 }

-CudaDevice DeviceList::find(int device_id) {
+RTDevice DeviceList::find(int device_id) {
   return device_list[device_id];
 }

core/runtime/CudaDevice.cpp → core/runtime/RTDevice.cpp (renamed, +8, -8)

@@ -11,10 +11,10 @@ const std::string DEVICE_INFO_DELIM = "%";

 typedef enum { ID_IDX = 0, SM_MAJOR_IDX, SM_MINOR_IDX, DEVICE_TYPE_IDX, DEVICE_NAME_IDX } SerializedDeviceInfoIndex;

-CudaDevice::CudaDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}
+RTDevice::RTDevice() : id{-1}, major{-1}, minor{-1}, device_type{nvinfer1::DeviceType::kGPU} {}

-CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
-  CudaDevice cuda_device;
+RTDevice::RTDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
+  RTDevice cuda_device;
   cudaDeviceProp device_prop;

   // Device ID

@@ -41,7 +41,7 @@ CudaDevice::CudaDevice(int64_t gpu_id, nvinfer1::DeviceType device_type) {
 // NOTE: Serialization Format for Device Info:
 // id%major%minor%(enum)device_type%device_name

-CudaDevice::CudaDevice(std::string device_info) {
+RTDevice::RTDevice(std::string device_info) {
   LOG_DEBUG("Deserializing Device Info: " << device_info);

   std::vector<std::string> tokens;

@@ -66,7 +66,7 @@ CudaDevice::CudaDevice(std::string device_info) {
   LOG_DEBUG("Deserialized Device Info: " << *this);
 }

-CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
+RTDevice& RTDevice::operator=(const RTDevice& other) {
   id = other.id;
   major = other.major;
   minor = other.minor;

@@ -75,7 +75,7 @@ CudaDevice& CudaDevice::operator=(const CudaDevice& other) {
   return (*this);
 }

-std::string CudaDevice::serialize() {
+std::string RTDevice::serialize() {
   std::vector<std::string> content;
   content.resize(DEVICE_NAME_IDX + 1);

@@ -98,13 +98,13 @@ std::string CudaDevice::serialize() {
   return serialized_device_info;
 }

-std::string CudaDevice::getSMCapability() const {
+std::string RTDevice::getSMCapability() const {
   std::stringstream ss;
   ss << major << "." << minor;
   return ss.str();
 }

-std::ostream& operator<<(std::ostream& os, const CudaDevice& device) {
+std::ostream& operator<<(std::ostream& os, const RTDevice& device) {
   os << "Device(ID: " << device.id << ", Name: " << device.device_name << ", SM Capability: " << device.major << '.'
      << device.minor << ", Type: " << device.device_type << ')';
   return os;

core/runtime/RTDevice.h (new file, +29)

@@ -0,0 +1,29 @@
+#pragma once
+#include <string>
+#include "NvInfer.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace runtime {
+
+struct RTDevice {
+  int64_t id; // CUDA device id
+  int64_t major; // CUDA compute major version
+  int64_t minor; // CUDA compute minor version
+  nvinfer1::DeviceType device_type;
+  std::string device_name;
+
+  RTDevice();
+  RTDevice(int64_t gpu_id, nvinfer1::DeviceType device_type);
+  RTDevice(std::string serialized_device_info);
+  ~RTDevice() = default;
+  RTDevice(const RTDevice& other) = default;
+  RTDevice& operator=(const RTDevice& other);
+  std::string serialize();
+  std::string getSMCapability() const;
+  friend std::ostream& operator<<(std::ostream& os, const RTDevice& device);
+};
+
+} // namespace runtime
+} // namespace core
+} // namespace torch_tensorrt
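
Given the serialization format documented in RTDevice.cpp (id%major%minor%(enum)device_type%device_name), a hedged round-trip sketch of the new header; output values depend on the local GPU, and the snippet assumes it is built inside the Torch-TensorRT tree:

#include <iostream>
#include <string>

#include "core/runtime/RTDevice.h"

int main() {
  using torch_tensorrt::core::runtime::RTDevice;

  RTDevice dev(0, nvinfer1::DeviceType::kGPU); // queries device 0 properties via CUDA
  std::cout << dev << std::endl;               // e.g. Device(ID: 0, Name: ..., SM Capability: 8.6, Type: GPU)

  std::string serialized = dev.serialize();    // "%"-delimited form used by the runtime ABI
  RTDevice roundtrip(serialized);              // reconstruct from the serialized string
  std::cout << roundtrip << std::endl;
  return 0;
}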
