
Commit 6bb6d79

[TensorRT EP] Call cudaSetDevice at compute function for handling multithreading scenario (#24010)
The GPU device is set again at compute time to handle multithreading scenarios. Consider the following:

- Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0).
- Later, additional threads may be spawned to execute `inference_session.run()`, which calls this compute function.

Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.

Example code:

````python
provider = [
    [
        ('TensorrtExecutionProvider', {
            'device_id': 0,
        }),
    ],
    [
        ('TensorrtExecutionProvider', {
            'device_id': 1,
        }),
    ]
]

class ThreadObj():
    def __init__(self, model_path: str, iterations: int, idx: int):
        ...
        sess_opt = ort.SessionOptions()
        self.inference_session = ort.InferenceSession(model_path, sess_opt, provider[idx % 2])

    def warmup(self):
        self.inference_session.run(None, self.input)

    def run(self, thread_times, threads_complete):
        for _ in range(self.iterations):
            self.inference_session.run(None, self.input)

def thread_target(obj, thread_times, threads_complete):
    obj.run(thread_times, threads_complete)

...

iterations = 500
num_threads = 13
t_obj_list = []
thread_list = []

for tidx in range(num_threads):
    obj = ThreadObj(model_path, iterations, tidx)
    t_obj_list.append(obj)
    obj.warmup()

for t_obj in t_obj_list:
    thread = threading.Thread(
        target=thread_target,
        daemon=True,
        args=(t_obj, thread_times, threads_complete),
    )
    thread.start()
    thread_list.append(thread)

...
````

Note: Based on our measurements (using CUDA events) on the A100 GPU with CUDA 12, the execution time for `cudaSetDevice` is approximately 0.004 ms, which is negligible and does not impact runtime performance.
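The key behavior here is that CUDA's "current device" is per host thread, so a freshly spawned thread sees the default device 0 no matter what other threads selected. The following is a minimal sketch of that semantics using Python's `threading.local` as an analogy (it is not ONNX Runtime or CUDA code; `set_device`/`current_device` are hypothetical stand-ins for `cudaSetDevice`/`cudaGetDevice`):

```python
import threading

# CUDA tracks the "current device" per host thread. threading.local()
# behaves analogously: each thread sees the default until it explicitly
# sets its own value.
_state = threading.local()
DEFAULT_DEVICE = 0  # analogous to CUDA's default device 0

def current_device():
    # A new thread has no thread-local value yet -> falls back to default.
    return getattr(_state, "device", DEFAULT_DEVICE)

def set_device(device_id):
    # Analogous to cudaSetDevice: affects only the calling thread.
    _state.device = device_id

def worker(device_id, observed):
    # Before setting, the spawned thread is on the default device,
    # even though the main thread already selected a different one.
    observed.append(("before", current_device()))
    set_device(device_id)  # mirrors the cudaSetDevice call in compute_func
    observed.append(("after", current_device()))

observed = []
t = threading.Thread(target=worker, args=(1, observed))
set_device(3)  # main thread's choice does not leak into the new thread
t.start()
t.join()
print(observed)  # new thread starts on the default device, then switches
```

This is why setting the device once at session-initialization time is not enough: any thread that later calls `inference_session.run()` must have the device set again on its own thread, which is exactly what the added `cudaSetDevice` call in the compute function does.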
1 parent 5bd3163 commit 6bb6d79

File tree

1 file changed

+18
-0
lines changed


onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

+18
```diff
@@ -3538,6 +3538,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);

     TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
```
```diff
@@ -4212,6 +4221,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con

   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0)
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);

     TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
```
