Skip to content

Commit 71a2155

Browse files
committed
Merge remote-tracking branch 'origin/main' into snnn/ci
2 parents 9d0355a + bc7b07d commit 71a2155

File tree

13 files changed

+436
-146
lines changed

13 files changed

+436
-146
lines changed

cmake/onnxruntime_providers_vitisai.cmake

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
)
2121
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs})
2222
onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs})
23-
onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} ${GSL_TARGET} safeint_interface flatbuffers::flatbuffers)
23+
onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} ${GSL_TARGET} safeint_interface flatbuffers::flatbuffers Boost::mp11)
2424
target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED})
2525
if(MSVC)
2626
onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp)

include/onnxruntime/core/framework/stream_handles.h

+21
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#pragma once
44

55
#include <functional>
6+
#include <optional>
67
#include <unordered_map>
78
#include "core/framework/allocator.h"
89
#include "core/framework/ortdevice.h"
@@ -154,6 +155,12 @@ class Notification {
154155
// TODO: use a better way to dispatch handles.
155156
using CreateStreamFn = std::function<std::unique_ptr<Stream>(const OrtDevice&)>;
156157

158+
// This SetDevice function is used by TRT EP or CUDA EP to handle the case where ExecutionMode::ORT_PARALLEL is enabled.
159+
// In that case, ORT retrieves a thread from the thread pool to run kernels for a given session.
160+
// Since new threads default to using device 0, but the session may be tightly bound to a device > 0,
161+
// this SetDevice function is called in RunSince to ensure that kernels run on the correct GPU device.
162+
using SetDeviceFn = std::function<void(OrtDevice::DeviceId)>;
163+
157164
// An interface for a simple registry which holds the handles that EPs have registered.
158165
// It is an interface so that it can be passed through shared-library-based execution providers.
159166
class IStreamCommandHandleRegistry {
@@ -171,6 +178,20 @@ class IStreamCommandHandleRegistry {
171178
WaitNotificationFn fn) = 0;
172179
// Register a handle describing how to create a stream on a given device type.
173180
virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
181+
182+
// Register a SetDevice function.
183+
// This interface is currently used by TRT EP or CUDA EP only.
184+
virtual void RegisterSetDeviceFn(OrtDevice::DeviceType device_type, SetDeviceFn f) {
185+
ORT_UNUSED_PARAMETER(device_type);
186+
ORT_UNUSED_PARAMETER(f);
187+
};
188+
189+
// Get a SetDevice function.
190+
// This interface is currently used by TRT EP or CUDA EP only and is called in RunSince from stream execution.
191+
virtual std::optional<SetDeviceFn> GetSetDeviceFn(OrtDevice::DeviceType device_type) const {
192+
ORT_UNUSED_PARAMETER(device_type);
193+
return std::nullopt;
194+
};
174195
};
175196

176197
} // namespace onnxruntime

0 commit comments

Comments
 (0)