
Commit 8be37c8

chilo-msamarin16 authored and committed
[TensorRT EP] Call cudaSetDevice at compute function for handling multithreading scenario (#24010)
The GPU device is set again at compute time (inside the compute function) to handle multithreading scenarios. Consider the following:

- Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0).
- Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.

Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.

Example code:

````python
provider = [
    [
        ('TensorrtExecutionProvider', {
            'device_id': 0,
        }),
    ],
    [
        ('TensorrtExecutionProvider', {
            'device_id': 1,
        }),
    ]
]

class ThreadObj():
    def __init__(self, model_path: str, iterations: int, idx: int):
        ...
        sess_opt = ort.SessionOptions()
        self.inference_session = ort.InferenceSession(model_path, sess_opt, provider[idx % 2])

    def warmup(self):
        self.inference_session.run(None, self.input)

    def run(self, thread_times, threads_complete):
        for iter in range(self.iterations):
            self.inference_session.run(None, self.input)

def thread_target(obj, thread_times, threads_complete):
    obj.run(thread_times, threads_complete)

...

iterations = 500
num_threads = 13
t_obj_list = []
thread_list = []

for tidx in range(num_threads):
    obj = ThreadObj(model_path, iterations, tidx)
    t_obj_list.append(obj)
    obj.warmup()

for t_obj in t_obj_list:
    thread = threading.Thread(target=thread_target, daemon=True, args=(t_obj, thread_times, threads_complete,))
    thread.start()
    thread_list.append(thread)

...
````

Note: Based on our measurements (using CUDA events) on the A100 GPU with CUDA 12, the execution time for `cudaSetDevice` is approximately 0.004 ms, which is negligible and does not impact runtime performance.
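The root cause generalizes to any runtime that keeps a per-thread "current device". As a rough illustration that needs no GPU, the sketch below mimics the CUDA runtime's per-thread current device with `threading.local()` (the names `get_device`/`set_device` are hypothetical stand-ins, not CUDA or onnxruntime APIs): a freshly spawned thread sees the default device 0 until it calls the setter itself, which is why the compute function re-issues `cudaSetDevice` rather than relying on the device chosen at session-creation time.

```python
import threading

# Stand-in for the CUDA runtime: the "current device" is per-thread
# state that defaults to 0 in every newly created thread.
_state = threading.local()

def get_device() -> int:
    # A thread that never called set_device sees the default device 0.
    return getattr(_state, "device", 0)

def set_device(device_id: int) -> None:
    _state.device = device_id

results = {}

def compute(session_device_id: int) -> None:
    # Without this call, this worker thread would observe device 0
    # even though the session was created for another device.
    set_device(session_device_id)
    results[session_device_id] = get_device()

threads = [threading.Thread(target=compute, args=(d,)) for d in (0, 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

If the `set_device` call inside `compute` were removed, both workers would report device 0, mirroring the bug this commit fixes.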
1 parent 26a2a96 commit 8be37c8

File tree

1 file changed: +18 -0 lines changed


onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

````diff
@@ -3487,6 +3487,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
 
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it’s necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);
 
     TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
````
````diff
@@ -4161,6 +4170,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
 
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it’s necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);
 
     TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
````
