3
3
#pragma once
4
4
5
5
#include < functional>
6
+ #include < optional>
6
7
#include < unordered_map>
7
8
#include " core/framework/allocator.h"
8
9
#include " core/framework/ortdevice.h"
@@ -154,6 +155,12 @@ class Notification {
154
155
// TODO: use a better way to dispatch handles.
155
156
using CreateStreamFn = std::function<std::unique_ptr<Stream>(const OrtDevice&)>;
156
157
158
+ // This SetDevice function is used by TRT EP or CUDA EP to handle the case where ExecutionMode::ORT_PARALLEL is enabled.
159
+ // In that case, ORT retrieves a thread from the thread pool to run kernels for a given session.
160
+ // Because new threads default to using device 0 while the session may be tightly bound to a device > 0,
161
+ // this SetDevice function is called in RunSince to ensure kernels run on the correct GPU device.
162
+ using SetDeviceFn = std::function<void (OrtDevice::DeviceId)>;
163
+
157
164
// an interface of a simple registry which holds the handles that EPs registered.
158
165
// make it an interface so we can pass it through shared library based execution providers
159
166
class IStreamCommandHandleRegistry {
@@ -171,6 +178,20 @@ class IStreamCommandHandleRegistry {
171
178
WaitNotificationFn fn) = 0;
172
179
// register a handle describing how to create a stream on the given device type.
173
180
virtual void RegisterCreateStreamFn (OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
181
+
182
+ // Register a SetDevice function.
183
+ // This interface is currently used by TRT EP or CUDA EP only.
184
+ virtual void RegisterSetDeviceFn (OrtDevice::DeviceType device_type, SetDeviceFn f) {
185
+ ORT_UNUSED_PARAMETER (device_type);
186
+ ORT_UNUSED_PARAMETER (f);
187
+ };
188
+
189
+ // Get a SetDevice function.
190
+ // This interface is currently used by TRT EP or CUDA EP only and is called in RunSince from stream execution.
191
+ virtual std::optional<SetDeviceFn> GetSetDeviceFn (OrtDevice::DeviceType device_type) const {
192
+ ORT_UNUSED_PARAMETER (device_type);
193
+ return std::nullopt;
194
+ };
174
195
};
175
196
176
197
} // namespace onnxruntime
0 commit comments