implement with ctx manager

zewenli98 committed Feb 26, 2025
1 parent af62bf1 commit ad04cf9
Showing 10 changed files with 586 additions and 14 deletions.
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3584,7 +3584,7 @@ def aten_ops_full(
)


@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default, supports_dynamic_shapes=True)
def aten_ops_nonzero(
ctx: ConversionContext,
target: Target,
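Since the converter now declares dynamic-shape support, a hedged compile sketch of what this enables (mirroring the dynamic-shape tests added below; the toy module is hypothetical, while torch_tensorrt.compile, ir="dynamo", and torch_tensorrt.Input are existing API):

import torch
import torch_tensorrt


class NonZero(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.aten.nonzero.default(x)


# With supports_dynamic_shapes=True, aten.nonzero can be converted even when
# the module is compiled over a dynamic input-shape range.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 2), opt_shape=(5, 10), max_shape=(20, 40), dtype=torch.float
    )
]
trt_mod = torch_tensorrt.compile(NonZero().eval().cuda(), ir="dynamo", inputs=inputs)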
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -13,6 +13,8 @@
from .remove_assert_nodes import remove_assert_nodes
from .remove_detach import remove_detach
from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones
from .remove_num_users_is_0_nodes import remove_num_users_is_0_nodes
from .remove_sym_size_and_constrain_nodes import remove_sym_size_and_constrain_nodes
from .repair_input_as_output import repair_input_as_output
from .replace_max_pool_with_indices import replace_max_pool_with_indices
from .view_to_reshape import view_to_reshape
@@ -29,6 +31,8 @@
view_to_reshape,
remove_assert_nodes,
accumulate_fp32_matmul,
remove_sym_size_and_constrain_nodes,
remove_num_users_is_0_nodes,
]
)

27 changes: 27 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py
@@ -0,0 +1,27 @@
import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def remove_num_users_is_0_nodes(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Remove ops that [num_users=0] in the graph"""
output_node = list(gm.graph.nodes)[-1]

for node in gm.graph.nodes:
if node != output_node and len(node.users) == 0:
node_input = node.all_input_nodes[0]
node.replace_all_uses_with(node_input)
gm.graph.erase_node(node)
gm = clean_up_graph_after_modifications(gm)

logger.debug(f"Removed ops that [num_users=0] nodes:\n{gm.graph}")

return gm
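A minimal sketch of this pass in action (the absolute import path follows the pass_utils import above; the traced toy function is hypothetical):

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.remove_num_users_is_0_nodes import (
    remove_num_users_is_0_nodes,
)


def f(x: torch.Tensor) -> torch.Tensor:
    _dead = x + 1  # traced into the graph, but its result is never used
    return x * 2


gm = torch.fx.symbolic_trace(f)
# Before: the graph still contains the unused add node (num_users=0).
gm = remove_num_users_is_0_nodes(gm, CompilationSettings())
# After: only the placeholder, mul, and output nodes remain.
print(gm.graph)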
34 changes: 34 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py
@@ -0,0 +1,34 @@
import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def remove_sym_size_and_constrain_nodes(
gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
"""Remove aten.sym_size.int and aten.sym_constrain_range_for_size.default ops in the graph"""
count = 0
for node in gm.graph.nodes:
if node.op == "call_function" and (
node.target == torch.ops.aten.sym_size.int
or node.target == torch.ops.aten.sym_constrain_range_for_size.default
):
node_input = node.all_input_nodes[0]
node.replace_all_uses_with(node_input)
gm.graph.erase_node(node)
count += 1

if count > 0:
gm = clean_up_graph_after_modifications(gm)

logger.debug(
f"Removed {count} aten.sym_size.int or aten.sym_constrain_range_for_size.default nodes:\n{gm.graph}"
)

return gm
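For background, a hedged sketch of where these nodes come from: in recent PyTorch versions, exporting an op with a data-dependent output shape (such as aten.nonzero) introduces unbacked symbolic dimensions, which the exported graph typically guards with aten.sym_constrain_range_for_size.default and queries via sym_size; this pass strips them before TRT conversion. The toy module is hypothetical:

import torch


class NonZeroSum(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        nz = torch.ops.aten.nonzero.default(x)  # data-dependent output shape
        return torch.ops.aten.sum.dim_IntList(nz, [0])


ep = torch.export.export(NonZeroSum(), (torch.randn(8, 8),))
# The exported graph typically contains sym_constrain_range_for_size nodes
# guarding nonzero's unbacked output dimension; this pass removes them.
print(ep.graph_module.graph)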
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -32,6 +32,7 @@ def __init__(
self._input_buffers: List[torch.Tensor] = []
self._output_buffers: List[torch.Tensor] = []
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
self.use_output_allocator_outputs = False
self.shape_key: Optional[str] = None
self._caller_stream: Optional[torch.cuda.Stream] = None
self._engine_stream: Optional[torch.cuda.Stream] = None
@@ -73,8 +74,16 @@ def __del__(self) -> None:
if self.cudagraph:
self.cudagraph.reset()

def set_output_allocator_outputs(self, enable: bool) -> None:
self.use_output_allocator_outputs = enable

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
if cudagraphs_enabled and self.use_output_allocator_outputs:
raise RuntimeError(
"There are non-TRT submodules in the module. OutputAllocator is not compatible with modules with non-TRT submodules."
)

if cudagraphs_enabled:
shape_changed = self.validate_input_shapes(inputs)
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
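A hedged sketch of the interaction this guard rejects (torch_tensorrt.runtime.enable_cudagraphs is existing API, enable_output_allocator is added in this commit; the wrapped module and input are hypothetical):

import torch_tensorrt

# Hypothetical: model is a compiled module with graph breaks, x a CUDA input.
with torch_tensorrt.runtime.enable_cudagraphs(model) as cg_model:
    with torch_tensorrt.runtime.enable_output_allocator(cg_model):
        cg_model(x)  # expected to raise the RuntimeError shown above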
46 changes: 36 additions & 10 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -202,10 +202,13 @@ def __init__(
torch_tensorrt.runtime.get_cudagraphs_mode()
)

self.engine_is_dds = engine_is_dds
self.cudagraphs_enabled = False
self.pre_allocated_outputs: List[torch.Tensor] = []
self.use_pre_allocated_outputs = False

self.engine_is_dds = engine_is_dds
self.output_allocator: Optional[DynamicOutputAllocator] = None
self.use_output_allocator_outputs = False

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()
@@ -401,6 +404,9 @@ def create_output_tensors(self) -> List[torch.Tensor]:
def set_pre_allocated_outputs(self, enable: bool) -> None:
self.use_pre_allocated_outputs = enable

def set_output_allocator_outputs(self, enable: bool) -> None:
self.use_output_allocator_outputs = enable

def create_output_allocator(self) -> None:
if self.output_allocator is None:
output_dtypes_dict = {}
@@ -410,15 +416,14 @@ def create_output_allocator(self) -> None:

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
shape_changed = self.validate_input_shapes(inputs)
(
need_cudagraphs_record,
can_use_pre_allocated_outputs,
need_cudagraphs_reset,
) = self.runtime_states.set_runtime_states(
cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
self.cudagraphs_enabled, self.use_pre_allocated_outputs, shape_changed
)

if need_cudagraphs_reset and self.cudagraph:
@@ -441,7 +446,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
), f"Wrong number of inputs, expect {len(self.input_names)} get {len(contiguous_inputs)}."

self.setup_input_tensors(
contiguous_inputs, cudagraphs_enabled, need_cudagraphs_record
contiguous_inputs, self.cudagraphs_enabled, need_cudagraphs_record
)

if shape_changed:
@@ -477,7 +482,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
if need_cudagraphs_record:
self._output_buffers[o] = outputs[o].clone()

if cudagraphs_enabled:
if self.cudagraphs_enabled:
self.context.set_tensor_address(
output_name, self._output_buffers[o].data_ptr()
)
@@ -503,7 +508,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
self._engine_stream.wait_stream(self._caller_stream)

with torch.cuda.stream(self._engine_stream):
if cudagraphs_enabled:
if self.cudagraphs_enabled:
if need_cudagraphs_record:
self.cudagraph = torch.cuda.CUDAGraph()

@@ -535,7 +540,7 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
if self.use_pre_allocated_outputs:
self.pre_allocated_outputs = self.create_output_tensors()

if cudagraphs_enabled:
if self.cudagraphs_enabled:
for idx, o in enumerate(outputs):
o.copy_(self._output_buffers[idx])

@@ -545,7 +550,9 @@ def run_cuda_graph() -> torch.Tensor | Tuple[torch.Tensor, ...]:
return outputs

def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
torch_tensorrt.runtime.set_cudagraphs_mode(False)
assert (
not torch_tensorrt.runtime.get_cudagraphs_mode()
), "CUDA Graphs are not compatible with OutputAllocator."
with (
torch.autograd.profiler.record_function(
"PythonTorchTensorRTModule:ProcessInputs"
@@ -625,6 +632,8 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:

return outputs

self.cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()

# Run forward function
contiguous_inputs: List[torch.Tensor] = [
(i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
@@ -670,9 +679,26 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
logger.warning(f"Moved all input Tensors to cuda:{device_id}")

if self.engine_is_dds:
if self.cudagraphs_enabled:
raise RuntimeError(
"The module is Data-Dependent Shape (DDS). It has to be handled by OutputAllocator which is not compatible with CUDA Graphs. Please disable CUDA Graphs."
)
logger.debug(
"The module is Data-Dependent Shape (DDS). Using output allocator."
)
return run_output_allocator()
else:
return run_cuda_graph()
if self.cudagraphs_enabled and self.use_output_allocator_outputs:
raise RuntimeError(
"Both CUDA Graphs and OutputAllocator are enabled. Please disable either one."
)
if self.use_output_allocator_outputs:
logger.debug("Using output allocator.")
return run_output_allocator()
logger.debug(
f"Using standard execution with cudagraphs={self.cudagraphs_enabled}."
)
return run_standard_execution()

def enable_profiling(self, profiler: "trt.IProfiler" = None) -> None:
"""
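Taken together, the new dispatch in forward resolves as follows (an illustrative restatement of the branches above, not code from the commit):

# engine_is_dds and cudagraphs_enabled          -> RuntimeError
# engine_is_dds                                 -> run_output_allocator()
# cudagraphs_enabled and output-allocator mode  -> RuntimeError
# output-allocator mode requested               -> run_output_allocator()
# otherwise                                     -> run_standard_execution()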
1 change: 1 addition & 0 deletions py/torch_tensorrt/runtime/__init__.py
@@ -9,5 +9,6 @@
set_cudagraphs_mode,
)
from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode
from torch_tensorrt.runtime._output_allocator import enable_output_allocator
from torch_tensorrt.runtime._pre_allocated_outputs import enable_pre_allocated_outputs
from torch_tensorrt.runtime._weight_streaming import weight_streaming
51 changes: 51 additions & 0 deletions py/torch_tensorrt/runtime/_output_allocator.py
@@ -0,0 +1,51 @@
import logging
from typing import Any, Union

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
CudaGraphsTorchTensorRTModule,
)

logger = logging.getLogger(__name__)


class _OutputAllocatorContextManager(object):
"""
Helper class to set up output_allocator
"""

def __init__(
self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
) -> None:
if isinstance(module, CudaGraphsTorchTensorRTModule):
rt_mods = [module]
else:
rt_mods = []

for name, rt_mod in module.named_children():
if "_run_on_acc" in name and isinstance(
rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
):
rt_mods.append(rt_mod)

self.rt_mods = rt_mods

def set_output_allocator_output(self, enable: bool) -> None:
for mod in self.rt_mods:
mod.set_output_allocator_outputs(enable)

def __enter__(self) -> "_OutputAllocatorContextManager":
# Enable output_allocator for TRT submodules
self.set_output_allocator_output(True)
return self

def __exit__(self, *args: Any) -> None:
# Disable output_allocator
self.set_output_allocator_output(False)


def enable_output_allocator(
module: torch.fx.GraphModule,
) -> _OutputAllocatorContextManager:
return _OutputAllocatorContextManager(module)
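A brief usage sketch (enable_output_allocator and its export are from this commit; the compiled module and input are hypothetical):

import torch_tensorrt

# Hypothetical: trt_module is a compiled module containing a data-dependent
# shape op such as aten.nonzero, and x is a CUDA input tensor.
with torch_tensorrt.runtime.enable_output_allocator(trt_module):
    out = trt_module(x)
# On exit, output-allocator mode is disabled again on all TRT submodules.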
51 changes: 48 additions & 3 deletions tests/py/dynamo/conversion/test_nonzero_aten.py
@@ -17,7 +17,7 @@ class TestNonZeroConverter(DispatchTestCase):
((2, 3, 4, 5), torch.float),
]
)
def test_non_zero(self, input_shape, dtype):
def test_nonzero_dds(self, input_shape, dtype):
class NonZero(nn.Module):
# This is a DDS network
def forward(self, input):
@@ -39,7 +39,7 @@ def forward(self, input):
((2, 3, 4, 5), torch.float),
]
)
def test_non_zero(self, input_shape, dtype):
def test_nonzero_non_dds(self, input_shape, dtype):
class NonZero(nn.Module):
# This is a static network
def forward(self, input):
@@ -78,7 +78,7 @@ def forward(self, input):
),
]
)
def test_nonzero_dynamic_shape(self, _, min_shape, opt_shape, max_shape, dtype):
def test_nonzero_dynamic_shape_dds(self, _, min_shape, opt_shape, max_shape, dtype):
class NonZero(nn.Module):
def forward(self, input):
return torch.ops.aten.nonzero.default(input)
Expand All @@ -94,6 +94,51 @@ def forward(self, input):

self.run_test_with_dynamic_shape(NonZero(), input_specs)

@parameterized.expand(
[
(
"1d",
(1,),
(10,),
(100,),
torch.int32,
),
(
"2d",
(1, 2),
(5, 10),
(20, 40),
torch.float16,
),
(
"3d",
(1, 2, 3),
(5, 10, 20),
(30, 40, 50),
torch.float,
),
]
)
def test_nonzero_dynamic_shape_non_dds(
self, _, min_shape, opt_shape, max_shape, dtype
):
class NonZero(nn.Module):
def forward(self, input):
out = torch.ops.aten.nonzero.default(input)
out = torch.ops.aten.sum.dim_IntList(out, 0)
return out

input_specs = [
Input(
min_shape=min_shape,
opt_shape=opt_shape,
max_shape=max_shape,
dtype=dtype,
),
]

self.run_test_with_dynamic_shape(NonZero(), input_specs)


if __name__ == "__main__":
run_tests()