@@ -60,35 +60,43 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
60
60
TensorrtLogger trt_logger = GetTensorrtLogger (false );
61
61
initLibNvInferPlugins (&trt_logger, " " );
62
62
63
- #if defined(_MSC_VER)
64
- #pragma warning(push)
65
- #pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::*' was declared deprecated
66
- #endif
67
-
68
63
int num_plugin_creator = 0 ;
69
- auto plugin_creators = getPluginRegistry ()->getPluginCreatorList (&num_plugin_creator);
64
+ auto plugin_creators = getPluginRegistry ()->getAllCreators (&num_plugin_creator);
70
65
std::unordered_set<std::string> registered_plugin_names;
71
66
72
67
for (int i = 0 ; i < num_plugin_creator; i++) {
73
68
auto plugin_creator = plugin_creators[i];
74
- std::string plugin_name (plugin_creator->getPluginName ());
75
- LOGS_DEFAULT (VERBOSE) << " [TensorRT EP] " << plugin_name << " , version : " << plugin_creator->getPluginVersion ();
69
+ nvinfer1::AsciiChar const * plugin_name = nullptr ;
70
+ if (std::strcmp (plugin_creators[i]->getInterfaceInfo ().kind , " PLUGIN CREATOR_V1" ) == 0 ) {
71
+ #if defined(_MSC_VER)
72
+ #pragma warning(push)
73
+ #pragma warning(disable : 4996) // Ignore warning C4996: 'nvinfer1::*' was declared deprecated
74
+ #endif
75
+ auto plugin_creator_v1 = static_cast <nvinfer1::IPluginCreator const *>(plugin_creator);
76
+ plugin_name = plugin_creator_v1->getPluginName ();
77
+ LOGS_DEFAULT (VERBOSE) << " [TensorRT EP] " << plugin_name << " , version : " << plugin_creator_v1->getPluginVersion ();
78
+ #if defined(_MSC_VER)
79
+ #pragma warning(pop)
80
+ #endif
81
+ } else if (std::strcmp (plugin_creators[i]->getInterfaceInfo ().kind , " PLUGIN CREATOR_V3ONE" ) == 0 ) {
82
+ auto plugin_creator_v3 = static_cast <nvinfer1::IPluginCreatorV3One const *>(plugin_creator);
83
+ plugin_name = plugin_creator_v3->getPluginName ();
84
+ LOGS_DEFAULT (VERBOSE) << " [TensorRT EP][V3ONE] " << plugin_name << " , version : " << plugin_creator_v3->getPluginVersion ();
85
+ } else {
86
+ ORT_THROW (" Unknown plugin creator type" );
87
+ }
76
88
77
89
// plugin has different versions and we only register once
78
90
if (registered_plugin_names.find (plugin_name) != registered_plugin_names.end ()) {
79
91
continue ;
80
92
}
81
93
82
94
created_custom_op_list.push_back (std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider , nullptr )); // Make sure TensorRTCustomOp object won't be cleaned up
83
- created_custom_op_list.back ().get ()->SetName (plugin_creator-> getPluginName () );
95
+ created_custom_op_list.back ().get ()->SetName (plugin_name );
84
96
custom_op_domain->custom_ops_ .push_back (created_custom_op_list.back ().get ());
85
97
registered_plugin_names.insert (plugin_name);
86
98
}
87
99
88
- #if defined(_MSC_VER)
89
- #pragma warning(pop)
90
- #endif
91
-
92
100
custom_op_domain->domain_ = " trt.plugins" ;
93
101
domain_list.push_back (custom_op_domain.get ());
94
102
} catch (const std::exception &) {
0 commit comments