diff --git a/include/onnxruntime/core/framework/stream_handles.h b/include/onnxruntime/core/framework/stream_handles.h
index 01631e1fb2aa6..402ea2da2148c 100644
--- a/include/onnxruntime/core/framework/stream_handles.h
+++ b/include/onnxruntime/core/framework/stream_handles.h
@@ -3,6 +3,7 @@
 
 #pragma once
 #include <functional>
+#include <optional>
 #include <memory>
 #include "core/framework/allocator.h"
 #include "core/framework/ortdevice.h"
@@ -154,6 +155,12 @@ class Notification {
 // TODO: use a better way to dispatch handles.
 using CreateStreamFn = std::function<std::unique_ptr<Stream>(const OrtDevice&)>;
 
+// This SetDevice function is used by the TRT EP and CUDA EP to handle the case where ExecutionMode::ORT_PARALLEL is enabled.
+// In that mode, ORT retrieves a thread from the inter-op thread pool to run kernels for a given session.
+// Since a new thread defaults to device 0, while the session may be bound to a device > 0,
+// the registered SetDevice function is called in RunSince to ensure kernels run on the correct GPU device.
+using SetDeviceFn = std::function<void(OrtDevice::DeviceId)>;
+
 // an interface of a simple registry which hold the handles EP registered.
 // make it interface so we can pass it through shared library based execution providers
 class IStreamCommandHandleRegistry {
@@ -171,6 +178,20 @@ class IStreamCommandHandleRegistry {
                               WaitNotificationFn fn) = 0;
   // register a handle about how to create stream on given device type.
   virtual void RegisterCreateStreamFn(OrtDevice::DeviceType device_type, CreateStreamFn f) = 0;
+
+  // Register a SetDevice function.
+  // This interface is currently used by the TRT EP and CUDA EP only.
+  virtual void RegisterSetDeviceFn(OrtDevice::DeviceType device_type, SetDeviceFn f) {
+    ORT_UNUSED_PARAMETER(device_type);
+    ORT_UNUSED_PARAMETER(f);
+  };
+
+  // Get the registered SetDevice function.
+  // This interface is currently used by the TRT EP and CUDA EP only and is called in RunSince during stream execution.
+  virtual std::optional<SetDeviceFn> GetSetDeviceFn(OrtDevice::DeviceType device_type) const {
+    ORT_UNUSED_PARAMETER(device_type);
+    return std::nullopt;
+  };
 };
 
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc
index a8c1a7515ae82..d174d6cc72ead 100644
--- a/onnxruntime/core/framework/session_state.cc
+++ b/onnxruntime/core/framework/session_state.cc
@@ -54,11 +54,24 @@ class StreamCommandHandleRegistryImpl : public IStreamCommandHandleRegistry {
     create_stream_map_.insert({device_type, f});
   }
 
+  void RegisterSetDeviceFn(const OrtDevice::DeviceType device_type, SetDeviceFn f) override {
+    set_device_map_.insert({device_type, f});
+  }
+
+  std::optional<SetDeviceFn> GetSetDeviceFn(const OrtDevice::DeviceType device_type) const override {
+    auto it = set_device_map_.find(device_type);
+    if (it != set_device_map_.end()) {
+      return it->second;
+    }
+    return std::nullopt;
+  }
+
   StreamCommandHandleRegistryImpl() = default;
 
  private:
   InlinedHashMap<int32_t, WaitNotificationFn> notification_wait_map_;
   InlinedHashMap<OrtDevice::DeviceType, CreateStreamFn> create_stream_map_;
+  InlinedHashMap<OrtDevice::DeviceType, SetDeviceFn> set_device_map_;
 };
 
 #endif
diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc
index dd7f4d35b34bd..e8beb98749028 100644
--- a/onnxruntime/core/framework/stream_execution_context.cc
+++ b/onnxruntime/core/framework/stream_execution_context.cc
@@ -205,6 +205,21 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& session_scope,
   end = std::min(end, range->stream_pc_range[stream_idx].second);
 #endif
 
+#ifdef ORT_ENABLE_STREAM
+  // If the device stream has a corresponding SetDevice function registered, the GPU device needs to be set explicitly here.
+  // The reason SetDevice should be called here is:
+  // - RunSince can be invoked from a new thread, and
+  // - new threads default to device 0, while the session may be bound to a device > 0.
+  auto device_stream = ctx.GetDeviceStream(stream_idx);
+  if (device_stream) {
+    auto set_device_fn = ctx.GetSessionState().GetStreamHandleRegistryInstance().GetSetDeviceFn(device_stream->GetDevice().Type());
+    if (set_device_fn.has_value()) {
+      auto device_id = device_stream->GetDevice().Id();
+      set_device_fn.value()(device_id);
+    }
+  }
+#endif
+
   while (since < end) {
     if (!ctx.TaskStatus().IsOK()) {
       ctx.CompleteTask();
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 886fddd8f8a27..635eb67bbedd0 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -420,8 +420,6 @@ Status CUDAExecutionProvider::Sync() const {
 }
 
 Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
-  // always set CUDA device when session::Run() in case it runs in a worker thread
-  CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
   CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
   if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) &&
       GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
index e9b159516dad9..51fd2c67b7478 100644
--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
+++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -266,6 +266,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
                                                                 ep_info](const OrtDevice& device) {
       return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false,
                                           external_cudnn_handle, external_cublas_handle, ep_info);
     });
+  stream_handle_registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) { CUDA_CALL_THROW(cudaSetDevice(id)); });
 }
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 00f53b96f931a..523ebbfae807a 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3538,15 +3538,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
 
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
-    // The GPU device is set again here to handle multithreading scenarios.
-    // Consider the following:
-    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
-    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
-    // Since new threads default to using device 0, it’s necessary to explicitly set the correct device to ensure computations run on the intended GPU.
-    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
-    // and does not impact runtime performance.
-    CUDA_CALL_THROW(cudaSetDevice(device_id_));
-
     Ort::KernelContext ctx(context);
 
     TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
@@ -4221,15 +4212,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
 
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
-    // The GPU device is set again here to handle multithreading scenarios.
-    // Consider the following:
-    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
-    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
-    // Since new threads default to using device 0, it’s necessary to explicitly set the correct device to ensure computations run on the intended GPU.
-    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
-    // and does not impact runtime performance.
-    CUDA_CALL_THROW(cudaSetDevice(device_id_));
-
     Ort::KernelContext ctx(context);
 
     TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
diff --git a/onnxruntime/test/python/onnxruntime_test_python_ort_parallel.py b/onnxruntime/test/python/onnxruntime_test_python_ort_parallel.py
new file mode 100644
index 0000000000000..c28bfb930e417
--- /dev/null
+++ b/onnxruntime/test/python/onnxruntime_test_python_ort_parallel.py
@@ -0,0 +1,154 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import ctypes
+import sys
+import threading
+import time
+import unittest
+
+import numpy as np
+from helper import get_name
+
+import onnxruntime as onnxrt
+
+
+class ThreadObj:
+    def __init__(self, model_path: str, iterations: int, idx: int, num_device: int, provider_options_list: list):
+        self.iterations = iterations
+        sess_opt = onnxrt.SessionOptions()
+        sess_opt.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_DISABLE_ALL
+        sess_opt.execution_mode = onnxrt.ExecutionMode.ORT_PARALLEL  # ORT will use threads from the inter-op thread pool
+        self.inference_session = onnxrt.InferenceSession(model_path, sess_opt, provider_options_list[idx % num_device])
+        self.input = {
+            "Input3": np.ones([1, 1, 28, 28], np.float32),
+        }
+        self.idx = idx
+
+    def warmup(self):
+        print(f"[THREAD {self.idx}] running warmup")
+        self.inference_session.run(None, self.input)
+        print(f"[THREAD {self.idx}] warmup done")
+
+    def run(self, thread_times, threads_complete):
+        for iter in range(self.iterations):
+            print(f"[THREAD {self.idx}] running iteration {iter}")
+            thread_times[self.idx] = time.time()
+            self.inference_session.run(None, self.input)
+            thread_times[self.idx] = time.time()
+            print(f"[THREAD {self.idx}] completed iteration {iter}")
+        threads_complete[0] += 1
+
+
+def thread_target(obj, thread_times, threads_complete):
+    obj.run(thread_times, threads_complete)
+
+
+# This unittest class creates 10 threads; each thread creates its own inference session and runs one warmup pass sequentially.
+# Once all threads have finished their warmup runs, they run multiple inference iterations concurrently.
+class TestParallelRun(unittest.TestCase):
+    def test_select_ep_to_run_ort_parallel_execution_mode(self):
+        if "TensorrtExecutionProvider" in onnxrt.get_available_providers():
+            cuda_lib = self.load_cuda_lib()
+            device_cnt = self.cuda_device_count(cuda_lib)
+            assert device_cnt > 0
+            print(f"Number of GPUs available: {device_cnt}")
+            self.run_inference_with_parallel_execution_mode("TensorrtExecutionProvider", device_cnt)
+        elif "CUDAExecutionProvider" in onnxrt.get_available_providers():
+            cuda_lib = self.load_cuda_lib()
+            device_cnt = self.cuda_device_count(cuda_lib)
+            assert device_cnt > 0
+            print(f"Number of GPUs available: {device_cnt}")
+            self.run_inference_with_parallel_execution_mode("CUDAExecutionProvider", device_cnt)
+
+    def load_cuda_lib(self):
+        cuda_lib = None
+        if sys.platform == "win32":
+            cuda_lib = "nvcuda.dll"
+        elif sys.platform == "linux":
+            cuda_lib = "libcuda.so"
+        elif sys.platform == "darwin":
+            cuda_lib = "libcuda.dylib"
+
+        if cuda_lib is not None:
+            try:
+                return ctypes.CDLL(cuda_lib)
+            except OSError:
+                pass
+        return None
+
+    def cuda_device_count(self, cuda_lib):
+        if cuda_lib is None:
+            return -1
+        num_device = ctypes.c_int()
+        cuda_lib.cuInit(0)
+        result = cuda_lib.cuDeviceGetCount(ctypes.byref(num_device))
+        if result != 0:
+            error_str = ctypes.c_char_p()
+            cuda_lib.cuGetErrorString(result, ctypes.byref(error_str))
+            print(f"cuDeviceGetCount failed with error code {result}: {error_str.value.decode()}")
+            return -1
+        return num_device.value
+
+    def run_inference_with_parallel_execution_mode(self, ep, num_device):
+        provider_options = []
+        for i in range(num_device):
+            option = [
+                (
+                    ep,
+                    {
+                        "device_id": i,
+                    },
+                ),
+            ]
+            provider_options.append(option)
+
+        model_path = get_name("mnist.onnx")
+        iterations = 20
+        hang_time = 60
+
+        num_threads = 10
+        t_obj_list = []
+        thread_list = []
+
+        threads_complete = [0]
+        thread_times = [0] * num_threads
+
+        for tidx in range(num_threads):
+            obj = ThreadObj(model_path, iterations, tidx, num_device, provider_options)
+            t_obj_list.append(obj)
+            obj.warmup()
+
+        for t_obj in t_obj_list:
+            thread = threading.Thread(
+                target=thread_target,
+                daemon=True,
+                args=(
+                    t_obj,
+                    thread_times,
+                    threads_complete,
+                ),
+            )
+            thread.start()
+            thread_list.append(thread)
+
+        time.sleep(5)
+        while True:
+            for t_time in thread_times:
+                if time.time() - t_time < hang_time:
+                    continue
+                else:
+                    print("Hang occurred, ending test")
+                    exit(1)
+            if threads_complete[0] == num_threads:
+                break
+            time.sleep(5)
+
+        for thread in thread_list:
+            thread.join()
+
+        print("All threads completed")
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index c667df0369c91..698de85c5984b 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -2438,6 +2438,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs):
             if args.use_cuda:
                 log.info("Testing CUDA Graph feature")
                 run_subprocess([sys.executable, "onnxruntime_test_python_cudagraph.py"], cwd=cwd, dll_path=dll_path)
+                log.info("Testing running inference concurrently")
+                run_subprocess([sys.executable, "onnxruntime_test_python_ort_parallel.py"], cwd=cwd, dll_path=dll_path)
 
             if args.use_dml:
                 log.info("Testing DML Graph feature")
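
Note (illustrative only, not part of the diff): an execution provider other than CUDA or TensorRT could opt into the same mechanism by registering its own callback through the new RegisterSetDeviceFn API when it registers its stream handles, mirroring the CUDA registration above. The names RegisterMyEpStreamHandles, MY_EP_CALL_THROW, and myEpSetDevice below are hypothetical placeholders, not existing APIs.

    // Hypothetical sketch: hook a device-selection callback into the stream handle registry so that
    // RunSince sets the correct device before launching kernels on a pool thread in ORT_PARALLEL mode.
    void RegisterMyEpStreamHandles(IStreamCommandHandleRegistry& registry, OrtDevice::DeviceType device_type) {
      // ... this EP's RegisterWaitFn / RegisterCreateStreamFn calls would go here ...
      registry.RegisterSetDeviceFn(device_type, [](OrtDevice::DeviceId id) {
        MY_EP_CALL_THROW(myEpSetDevice(id));  // hypothetical device-runtime call, analogous to cudaSetDevice
      });
    }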