Skip to content

Commit a14d586

Browse files
authored
Support load TensorRT V3 plugin (#24211)
### Description TensorRT V3 plugin is not able to load in TensorRT EP. The change deprecates `getPluginCreatorList` with `getAllCreators` to load V1 and V3 plugin creators. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> Support load TensorRT plugin. Reference: https://github.com/NVIDIA/TensorRT/blob/8c6d69ddec0b2feff12f55472dc5d55cb6861d53/python/src/infer/pyPlugin.cpp#L2971C1-L2995C6
1 parent 55aa03c commit a14d586

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,35 +60,43 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
6060
TensorrtLogger trt_logger = GetTensorrtLogger(false);
6161
initLibNvInferPlugins(&trt_logger, "");
6262

63-
#if defined(_MSC_VER)
64-
#pragma warning(push)
65-
#pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::*' was declared deprecated
66-
#endif
67-
6863
int num_plugin_creator = 0;
69-
auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
64+
auto plugin_creators = getPluginRegistry()->getAllCreators(&num_plugin_creator);
7065
std::unordered_set<std::string> registered_plugin_names;
7166

7267
for (int i = 0; i < num_plugin_creator; i++) {
7368
auto plugin_creator = plugin_creators[i];
74-
std::string plugin_name(plugin_creator->getPluginName());
75-
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();
69+
nvinfer1::AsciiChar const* plugin_name = nullptr;
70+
if (std::strcmp(plugin_creators[i]->getInterfaceInfo().kind, "PLUGIN CREATOR_V1") == 0) {
71+
#if defined(_MSC_VER)
72+
#pragma warning(push)
73+
#pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::*' was declared deprecated
74+
#endif
75+
auto plugin_creator_v1 = static_cast<nvinfer1::IPluginCreator const*>(plugin_creator);
76+
plugin_name = plugin_creator_v1->getPluginName();
77+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator_v1->getPluginVersion();
78+
#if defined(_MSC_VER)
79+
#pragma warning(pop)
80+
#endif
81+
} else if (std::strcmp(plugin_creators[i]->getInterfaceInfo().kind, "PLUGIN CREATOR_V3ONE") == 0) {
82+
auto plugin_creator_v3 = static_cast<nvinfer1::IPluginCreatorV3One const*>(plugin_creator);
83+
plugin_name = plugin_creator_v3->getPluginName();
84+
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP][V3ONE] " << plugin_name << ", version : " << plugin_creator_v3->getPluginVersion();
85+
} else {
86+
ORT_THROW("Unknown plugin creator type");
87+
}
7688

7789
// plugin has different versions and we only register once
7890
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
7991
continue;
8092
}
8193

8294
created_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up
83-
created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName());
95+
created_custom_op_list.back().get()->SetName(plugin_name);
8496
custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get());
8597
registered_plugin_names.insert(plugin_name);
8698
}
8799

88-
#if defined(_MSC_VER)
89-
#pragma warning(pop)
90-
#endif
91-
92100
custom_op_domain->domain_ = "trt.plugins";
93101
domain_list.push_back(custom_op_domain.get());
94102
} catch (const std::exception&) {

0 commit comments

Comments
 (0)