diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 523ebbfae807a..00f53b96f931a 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3538,6 +3538,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0).
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);
     TensorrtFuncState* trt_state = reinterpret_cast<TensorrtFuncState*>(state);
@@ -4212,6 +4221,15 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
   // Create compute function
   compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+    // The GPU device is set again here to handle multithreading scenarios.
+    // Consider the following:
+    // Users can create multiple threads to initialize separate inference sessions on different devices (not just the default device 0).
+    // Later, additional threads may be spawned to execute inference_session.Run(), which calls this compute function.
+    // Since new threads default to using device 0, it's necessary to explicitly set the correct device to ensure computations run on the intended GPU.
+    // Note: Based on our measurements on the A100 GPU with CUDA 12, the execution time for cudaSetDevice is approximately 0.004 ms, which is negligible
+    // and does not impact runtime performance.
+    CUDA_CALL_THROW(cudaSetDevice(device_id_));
+
     Ort::KernelContext ctx(context);
     TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);