You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[TensorRT EP] Call cudaSetDevice at compute function for handling multithreading scenario (#24010)
The GPU device is set again at compute function/compute time to handle
multithreading scenarios.
Consider the following:
Users can create multiple threads to initialize separate inference
sessions on different devices (not just the default device 0)
Later, additional threads may be spawned to execute
inference_session.Run(), which calls this compute function.
Since new threads default to using device 0, it’s necessary to
explicitly set the correct device to ensure computations run on the
intended GPU.
Example code:
````python
provider = [
[
('TensorrtExecutionProvider', {
'device_id': 0,
}),
],
[
('TensorrtExecutionProvider', {
'device_id': 1,
}),
]
]
class ThreadObj():
def __init__(self, model_path: str, iterations: int, idx: int):
...
sess_opt = ort.SessionOptions()
self.inference_session = ort.InferenceSession(model_path, sess_opt, provider[idx % 2])
def warmup(self):
self.inference_session.run(None, self.input)
def run(self, thread_times, threads_complete):
for iter in range(self.iterations):
self.inference_session.run(None, self.input)
def thread_target(obj, thread_times, threads_complete):
obj.run(thread_times, threads_complete)
...
iterations = 500
num_threads = 13
t_obj_list = []
thread_list = []
for tidx in range(num_threads):
obj = ThreadObj(model_path, iterations, tidx)
t_obj_list.append(obj)
obj.warmup()
for t_obj in t_obj_list:
thread = threading.Thread(target=thread_target, daemon=True, args=(t_obj,thread_times,threads_complete,))
thread.start()
thread_list.append(thread)
...
````
Note: Based on our measurements (using cuda event) on the A100 GPU with
CUDA 12, the execution time for `cudaSetDevice` is approximately 0.004
ms, which is negligible and does not impact runtime performance.
0 commit comments