Skip to content

Commit dad58f1

Browse files
committed
Fixed bugs
1 parent 29df36d commit dad58f1

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,8 +1071,6 @@ def preserve_module_specs(
10711071
f.write(trt_module.get_layer_info())
10721072

10731073
# Only set the requires_unique_output flag for the last TRT Module when user has access to the output tensor
1074-
if trt_module:
1075-
trt_module.set_output_tensors_as_unowned(True)
10761074

10771075
# Parse the graph I/O and store it in dryrun tracker
10781076
parse_graph_io(gm, dryrun_tracker)
@@ -1081,7 +1079,11 @@ def preserve_module_specs(
10811079
for name, trt_module in trt_modules.items():
10821080
setattr(partitioned_module, name, trt_module)
10831081
if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
1084-
getattr(partitioned_module, name).setup_engine()
1082+
trt_module = getattr(partitioned_module, name)
1083+
trt_module.setup_engine()
1084+
1085+
if trt_module:
1086+
trt_module.set_output_tensors_as_unowned(True)
10851087

10861088
# Reset settings object to user specification after fallback to global partitioning mode
10871089
if fast_partitioner_failed:

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,6 @@ def __init__(
225225
self.output_tensors_are_unowned = False
226226
if self.serialized_engine is not None and not self.settings.lazy_engine_init:
227227
self.setup_engine()
228-
self.is_shape_inference_io = {
229-
input_name: self.engine.is_shape_inference_io(input_name)
230-
for input_name in self.input_names
231-
}
232228

233229
def set_output_tensors_as_unowned(self, enabled: bool) -> None:
234230
"""
@@ -327,6 +323,11 @@ def setup_engine(self) -> None:
327323
if torch_tensorrt.runtime.get_cudagraphs_mode():
328324
self.cudagraph = torch.cuda.CUDAGraph()
329325

326+
self.is_shape_inference_io = {
327+
input_name: self.engine.is_shape_inference_io(input_name)
328+
for input_name in self.input_names
329+
}
330+
330331
def _check_initialized(self) -> None:
331332
if not self.initialized:
332333
raise RuntimeError("PythonTorchTensorRTModule is not initialized.")
@@ -531,6 +532,7 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
531532
or self.output_tensors_are_unowned
532533
or shape_changed
533534
):
535+
breakpoint()
534536
self.output_tensors = self.create_output_tensors()
535537
outputs = self.output_tensors
536538

0 commit comments

Comments
 (0)