Ensure the correct GPU device is used in RunSince when it is invoked by a new thread #24192

Merged
merged 15 commits into from
Apr 2, 2025
21 changes: 21 additions & 0 deletions include/onnxruntime/core/framework/stream_handles.h
@@ -3,6 +3,7 @@
#pragma once

#include <functional>
#include <optional>
#include <unordered_map>
#include "core/framework/allocator.h"
#include "core/framework/ortdevice.h"
@@ -154,6 +155,12 @@
// TODO: use a better way to dispatch handles.
using CreateStreamFn = std::function<std::unique_ptr<Stream>(const OrtDevice&)>;

// This SetDevice function is used by the TRT EP and CUDA EP to handle the case where ExecutionMode::ORT_PARALLEL is enabled.
// In that case, ORT pulls a thread from the inter-op thread pool to run kernels for a given session.
// Since new threads default to device 0, while the session may be bound to a device > 0,
// RunSince calls this SetDevice function to ensure kernels run on the correct GPU device.
using SetDeviceFn = std::function<void(OrtDevice::DeviceId)>;

// an interface of a simple registry which hold the handles EP registered.
// make it interface so we can pass it through shared library based execution providers
class IStreamCommandHandleRegistry {
@@ -171,6 +178,20 @@
WaitNotificationFn fn) = 0;
// register a handle about how to create stream on given device type.
virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;

// Register a SetDevice function.
// This interface is currently used only by the TRT and CUDA EPs.
virtual void RegisterSetDeviceFn(OrtDevice::DeviceType device_type, SetDeviceFn f) {
ORT_UNUSED_PARAMETER(device_type);
ORT_UNUSED_PARAMETER(f);
}

// Get the registered SetDevice function for the given device type.
// This interface is currently used only by the TRT and CUDA EPs; it is called from RunSince during stream execution.
virtual std::optional<SetDeviceFn> GetSetDeviceFn(OrtDevice::DeviceType device_type) const {
ORT_UNUSED_PARAMETER(device_type);
return std::nullopt;
}

};

} // namespace onnxruntime
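
Taken together, the two new registry hooks form a small register-and-lookup protocol. The following is a condensed sketch, not code taken verbatim from this PR: it assumes a `registry` reference to the session's IStreamCommandHandleRegistry plus the `ctx` and `stream_idx` variables available inside RunSince, and it mirrors the registration added in cuda_stream_handle.cc and the lookup added in stream_execution_context.cc below.

// EP side: register how to bind the calling thread to a particular GPU device.
// The CUDA EP registers essentially this lambda in RegisterCudaStreamHandles.
registry.RegisterSetDeviceFn(OrtDevice::GPU,
                             [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); });

// Executor side: at the start of RunSince, which may run on a fresh inter-op thread,
// look up the function for the stream's device type and call it before any kernel runs.
if (auto* device_stream = ctx.GetDeviceStream(stream_idx)) {
  std::optional<SetDeviceFn> set_device_fn =
      ctx.GetSessionState().GetStreamHandleRegistryInstance().GetSetDeviceFn(
          device_stream->GetDevice().Type());
  if (set_device_fn.has_value()) {
    set_device_fn.value()(device_stream->GetDevice().Id());  // e.g. cudaSetDevice(1)
  }
}

With this path in place, the explicit cudaSetDevice calls previously made in CUDAExecutionProvider::OnRunStart and in the TRT EP compute functions become redundant, which is why they are removed in the hunks below.
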
13 changes: 13 additions & 0 deletions onnxruntime/core/framework/session_state.cc
@@ -54,11 +54,24 @@ class StreamCommandHandleRegistryImpl : public IStreamCommandHandleRegistry {
create_stream_map_.insert({device_type, f});
}

void RegisterSetDeviceFn(const OrtDevice::DeviceType device_type, SetDeviceFn f) override {
set_device_map_.insert({device_type, f});
}

std::optional<SetDeviceFn> GetSetDeviceFn(const OrtDevice::DeviceType device_type) const override {
auto it = set_device_map_.find(device_type);
if (it != set_device_map_.end()) {
return it->second;
}
return std::nullopt;
}

StreamCommandHandleRegistryImpl() = default;

private:
InlinedHashMap<std::string, WaitNotificationFn> notification_wait_map_;
InlinedHashMap<OrtDevice::DeviceType, CreateStreamFn> create_stream_map_;
InlinedHashMap<OrtDevice::DeviceType, SetDeviceFn> set_device_map_;
};
#endif

15 changes: 15 additions & 0 deletions onnxruntime/core/framework/stream_execution_context.cc
@@ -205,6 +205,21 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess
end = std::min(end, range->stream_pc_range[stream_idx].second);
#endif

#ifdef ORT_ENABLE_STREAM
// If a SetDevice function is registered for the device stream's device type, call it so this thread targets the correct GPU device.
// SetDevice must be called here because:
// - RunSince can be invoked from a new thread, and
// - new threads default to device 0, while the session may be bound to a device > 0.
auto device_stream = ctx.GetDeviceStream(stream_idx);
if (device_stream) {
auto set_device_fn = ctx.GetSessionState().GetStreamHandleRegistryInstance().GetSetDeviceFn(device_stream->GetDevice().Type());
if (set_device_fn.has_value()) {
auto device_id = device_stream->GetDevice().Id();
set_device_fn.value()(device_id);
}
}
#endif

while (since < end) {
if (!ctx.TaskStatus().IsOK()) {
ctx.CompleteTask();
2 changes: 0 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -420,8 +420,6 @@ Status CUDAExecutionProvider::Sync() const {
}

Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
// always set CUDA device when session::Run() in case it runs in a worker thread
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) &&
GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -266,6 +266,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
ep_info](const OrtDevice& device) {
return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle, ep_info);
});
stream_handle_registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); });
}

} // namespace onnxruntime
18 changes: 0 additions & 18 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3538,15 +3538,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

// Create compute function
compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
// The GPU device is set again here to handle multithreading scenarios.
// Consider the following:
// Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
// Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
// Since new threads default to using device 0, it’s necessary to explicitly set the correct device to ensure computations run on the intended GPU.
// Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
// and does not impact runtime performance.
CUDA_CALL_THROW(cudaSetDevice(device_id_));

Ort::KernelContext ctx(context);

TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
@@ -4221,15 +4212,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con

// Create compute function
compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
// The GPU device is set again here to handle multithreading scenarios.
// Consider the following:
// Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
// Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
// Since new threads default to using device 0, it’s necessary to explicitly set the correct device to ensure computations run on the intended GPU.
// Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
// and does not impact runtime performance.
CUDA_CALL_THROW(cudaSetDevice(device_id_));

Ort::KernelContext ctx(context);

TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
154 changes: 154 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_python_ort_parallel.py
@@ -0,0 +1,154 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import ctypes
import sys
import threading
import time
import unittest

import numpy as np
from helper import get_name

import onnxruntime as onnxrt


class ThreadObj:
def __init__(self, model_path: str, iterations: int, idx: int, num_device: int, provider_options_list: list):
self.iterations = iterations
sess_opt = onnxrt.SessionOptions()
sess_opt.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL
sess_opt.execution_mode = onnxrt.ExecutionMode.ORT_PARALLEL  # ORT will use threads from the inter-op thread pool
self.inference_session = onnxrt.InferenceSession(model_path, sess_opt, provider_options_list[idx % num_device])
self.input = {
"Input3": np.ones([1, 1, 28, 28], np.float32),
}
self.idx = idx

def warmup(self):
print(f"[THREAD {self.idx}] running warmup")
self.inference_session.run(None, self.input)
print(f"[THREAD {self.idx}] warmup done")

def run(self, thread_times, threads_complete):
for iter in range(self.iterations):
print(f"[THREAD {self.idx}] running iteration {iter}")
thread_times[self.idx] = time.time()
self.inference_session.run(None, self.input)
thread_times[self.idx] = time.time()
print(f"[THREAD {self.idx}] completed iteration {iter}")
threads_complete[0] += 1


def thread_target(obj, thread_times, threads_complete):
obj.run(thread_times, threads_complete)


# This unittest class creates 10 threads; each thread creates its own inference session and runs one warmup sequentially.
# Once all threads finish their warmup run, all threads run multiple inference runs concurrently.
class TestParallelRun(unittest.TestCase):
def test_select_ep_to_run_ort_parallel_execution_mode(self):
if "TensorrtExecutionProvider" in onnxrt.get_available_providers():
cuda_lib = self.load_cuda_lib()
device_cnt = self.cuda_device_count(cuda_lib)
assert device_cnt > 0
print(f"Number of GPUs available: {device_cnt}")
self.run_inference_with_parallel_execution_mode("TensorrtExecutionProvider", device_cnt)
elif "CUDAExecutionProvider" in onnxrt.get_available_providers():
cuda_lib = self.load_cuda_lib()
device_cnt = self.cuda_device_count(cuda_lib)
assert device_cnt > 0
print(f"Number of GPUs available: {device_cnt}")
self.run_inference_with_parallel_execution_mode("CUDAExecutionProvider", device_cnt)

def load_cuda_lib(self):
cuda_lib = None
if sys.platform == "win32":
cuda_lib = "nvcuda.dll"
elif sys.platform == "linux":
cuda_lib = "libcuda.so"
elif sys.platform == "darwin":
cuda_lib = "libcuda.dylib"

if cuda_lib is not None:
try:
return ctypes.CDLL(cuda_lib)
except OSError:
pass
return None

def cuda_device_count(self, cuda_lib):
if cuda_lib is None:
return -1
num_device = ctypes.c_int()
cuda_lib.cuInit(0)
result = cuda_lib.cuDeviceGetCount(ctypes.byref(num_device))
if result != 0:
error_str = ctypes.c_char_p()
cuda_lib.cuGetErrorString(result, ctypes.byref(error_str))
print(f"cuDeviceGetCount failed with error code {result}: {error_str.value.decode()}")
return -1
return num_device.value

def run_inference_with_parallel_execution_mode(self, ep, num_device):
provider_options = []
for i in range(num_device):
option = [
(
ep,
{
"device_id": i,
},
),
]
provider_options.append(option)

model_path = get_name("mnist.onnx")
iterations = 20
hang_time = 60

num_threads = 10
t_obj_list = []
thread_list = []

threads_complete = [0]
thread_times = [0] * num_threads

for tidx in range(num_threads):
obj = ThreadObj(model_path, iterations, tidx, num_device, provider_options)
t_obj_list.append(obj)
obj.warmup()

for t_obj in t_obj_list:
thread = threading.Thread(
target=thread_target,
daemon=True,
args=(
t_obj,
thread_times,
threads_complete,
),
)
thread.start()
thread_list.append(thread)

time.sleep(5)
while True:
for t_time in thread_times:
if time.time() - t_time < hang_time:
continue
else:
print("Hang occured, ending test")

Check warning on line 141 in onnxruntime/test/python/onnxruntime_test_python_ort_parallel.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "occured" is a misspelling of "occurred" Raw Output: ./onnxruntime/test/python/onnxruntime_test_python_ort_parallel.py:141:32: "occured" is a misspelling of "occurred"
exit(1)
if threads_complete[0] == num_threads:
break
time.sleep(5)

for thread in thread_list:
thread.join()

print("All threads completed")


if __name__ == "__main__":
unittest.main()
2 changes: 2 additions & 0 deletions tools/ci_build/build.py
Expand Up @@ -2438,6 +2438,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
if args.use_cuda:
log.info("Testing CUDA Graph feature")
run_subprocess([sys.executable, "onnxruntime_test_python_cudagraph.py"], cwd=cwd, dll_path=dll_path)
log.info("Testing running inference concurrently")
run_subprocess([sys.executable, "onnxruntime_test_python_ort_parallel.py"], cwd=cwd, dll_path=dll_path)

if args.use_dml:
log.info("Testing DML Graph feature")