Commit 0b15b99

chilo-ms authored and Sean Ye committed
[TensorRT EP] Call cudaSetDevice at compute function for handling multithreading scenario (microsoft#24010)
The GPU device is set again at compute time (inside the compute function) to handle multithreading scenarios. Consider the following: users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0). Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function. Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.

Example code:

```python
provider = [
    [
        ('TensorrtExecutionProvider', {
            'device_id': 0,
        }),
    ],
    [
        ('TensorrtExecutionProvider', {
            'device_id': 1,
        }),
    ]
]

class ThreadObj():
    def __init__(self, model_path: str, iterations: int, idx: int):
        ...
        sess_opt = ort.SessionOptions()
        self.inference_session = ort.InferenceSession(model_path, sess_opt, provider[idx % 2])

    def warmup(self):
        self.inference_session.run(None, self.input)

    def run(self, thread_times, threads_complete):
        for iter in range(self.iterations):
            self.inference_session.run(None, self.input)

def thread_target(obj, thread_times, threads_complete):
    obj.run(thread_times, threads_complete)

...

iterations = 500
num_threads = 13
t_obj_list = []
thread_list = []

for tidx in range(num_threads):
    obj = ThreadObj(model_path, iterations, tidx)
    t_obj_list.append(obj)
    obj.warmup()

for t_obj in t_obj_list:
    thread = threading.Thread(target=thread_target, daemon=True, args=(t_obj, thread_times, threads_complete,))
    thread.start()
    thread_list.append(thread)

...
```

Note: Based on our measurements (using CUDA events) on the A100 GPU with CUDA 12, the execution time for `cudaSetDevice` is approximately 0.004 ms, which is negligible and does not impact runtime performance.
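The key premise above — that a freshly created host thread targets device 0 until the device is set explicitly from that thread — can be illustrated without a GPU. The sketch below is purely illustrative: `current_device` and `set_device` are hypothetical stand-ins, with `threading.local` mimicking the CUDA runtime's per-thread "current device" state.

```python
import threading

# Hypothetical stand-in for CUDA's per-thread current device: each host
# thread sees device 0 until set_device() is called on that same thread.
_tls = threading.local()

def current_device() -> int:
    # New threads default to device 0, mirroring CUDA runtime semantics.
    return getattr(_tls, "device", 0)

def set_device(device_id: int) -> None:
    # Analogous to cudaSetDevice: affects only the calling thread.
    _tls.device = device_id

def compute(device_id: int, results: list) -> None:
    # Without this explicit call, the worker thread would stay on the
    # default device 0 regardless of which session spawned it.
    set_device(device_id)
    results.append(current_device())

results = []
threads = [threading.Thread(target=compute, args=(i, results)) for i in (0, 1)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(sorted(results))  # prints [0, 1]: each thread ran on its intended device
```

This is why the compute function itself calls `cudaSetDevice(device_id_)`: whichever thread ends up invoking `Run()` is redirected to the session's device at compute time.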
1 parent c8a0537 commit 0b15b99

1 file changed, +18 −0 lines changed

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

```diff
@@ -3385,6 +3385,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
 
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0).
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);
 
     TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
```
```diff
@@ -4055,6 +4064,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
 
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0).
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);
 
     TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
```
