Showing 10 changed files with 586 additions and 14 deletions.
27 changes: 27 additions & 0 deletions
py/torch_tensorrt/dynamo/lowering/passes/remove_num_users_is_0_nodes.py
@@ -0,0 +1,27 @@
import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def remove_num_users_is_0_nodes(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Remove nodes that have no users (num_users == 0) from the graph."""
    # The output node is always the last node in an FX graph.
    output_node = list(gm.graph.nodes)[-1]

    for node in gm.graph.nodes:
        if node != output_node and len(node.users) == 0:
            # Reroute uses (trivially, since there are none) and erase the dead node.
            node_input = node.all_input_nodes[0]
            node.replace_all_uses_with(node_input)
            gm.graph.erase_node(node)
            gm = clean_up_graph_after_modifications(gm)

    logger.debug(f"Removed nodes with num_users == 0:\n{gm.graph}")

    return gm
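For illustration, here is a minimal, hypothetical usage sketch of this pass. The `Toy` module and the dead `add` op are invented for this example; `torch.fx.symbolic_trace` does not eliminate dead code, so the unused node survives tracing and is then erased by the pass:

# Hypothetical sketch: exercise the pass on a toy FX graph with a dead node.
import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.remove_num_users_is_0_nodes import (
    remove_num_users_is_0_nodes,
)


class Toy(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _dead = x + 1  # traced into the graph, but its result is never used
        return x * 2


gm = torch.fx.symbolic_trace(Toy())  # graph still contains the unused add node
gm = remove_num_users_is_0_nodes(gm, CompilationSettings())
print(gm.graph)  # the add node with num_users == 0 has been erased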
34 changes: 34 additions & 0 deletions
py/torch_tensorrt/dynamo/lowering/passes/remove_sym_size_and_constrain_nodes.py
@@ -0,0 +1,34 @@
import logging

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)


def remove_sym_size_and_constrain_nodes(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Remove aten.sym_size.int and aten.sym_constrain_range_for_size.default ops from the graph."""
    count = 0
    for node in gm.graph.nodes:
        if node.op == "call_function" and (
            node.target == torch.ops.aten.sym_size.int
            or node.target == torch.ops.aten.sym_constrain_range_for_size.default
        ):
            # Reroute any consumers to the node's input, then drop the node itself.
            node_input = node.all_input_nodes[0]
            node.replace_all_uses_with(node_input)
            gm.graph.erase_node(node)
            count += 1

    if count > 0:
        gm = clean_up_graph_after_modifications(gm)

        logger.debug(
            f"Removed {count} aten.sym_size.int or aten.sym_constrain_range_for_size.default nodes:\n{gm.graph}"
        )

    return gm
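As context for why these nodes appear at all, a hypothetical sketch follows (module `M` and the dynamic-shape spec are invented for illustration): exporting with a dynamic batch dimension typically materializes `aten.sym_size.int` calls in the graph, which this pass then strips during lowering:

# Hypothetical sketch: dynamic-shape export produces aten.sym_size.int nodes.
import torch


class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.reshape(x.shape[0], -1)  # x.shape[0] becomes a sym_size.int node


ep = torch.export.export(
    M(),
    (torch.randn(4, 8),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
)
gm = ep.module()
sym_nodes = [
    n
    for n in gm.graph.nodes
    if n.op == "call_function" and n.target == torch.ops.aten.sym_size.int
]
print(sym_nodes)  # the nodes remove_sym_size_and_constrain_nodes would erase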
@@ -0,0 +1,51 @@
import logging
from typing import Any, Union

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
    CudaGraphsTorchTensorRTModule,
)

logger = logging.getLogger(__name__)


class _OutputAllocatorContextManager(object):
    """
    Helper class to set up output_allocator
    """

    def __init__(
        self, module: Union[torch.fx.GraphModule, CudaGraphsTorchTensorRTModule]
    ) -> None:
        # A CUDA graphs wrapper is itself the runtime module; otherwise,
        # collect the TRT-accelerated submodules of the graph module.
        if isinstance(module, CudaGraphsTorchTensorRTModule):
            rt_mods = [module]
        else:
            rt_mods = []

            for name, rt_mod in module.named_children():
                if "_run_on_acc" in name and isinstance(
                    rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
                ):
                    rt_mods.append(rt_mod)

        self.rt_mods = rt_mods

    def set_output_allocator_output(self, enable: bool) -> None:
        for mod in self.rt_mods:
            mod.set_output_allocator_outputs(enable)

    def __enter__(self) -> "_OutputAllocatorContextManager":
        # Enable output_allocator for TRT submodules
        self.set_output_allocator_output(True)
        return self

    def __exit__(self, *args: Any) -> None:
        # Disable output_allocator
        self.set_output_allocator_output(False)


def enable_output_allocator(
    module: torch.fx.GraphModule,
) -> _OutputAllocatorContextManager:
    return _OutputAllocatorContextManager(module)
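A hypothetical usage sketch follows. It assumes `trt_gm` is a module produced by `torch_tensorrt.dynamo.compile(...)` and that this helper is re-exported under `torch_tensorrt.runtime`, alongside context managers such as `enable_cudagraphs`; neither assumption is confirmed by this diff:

# Hypothetical sketch: run a compiled module with output allocators enabled.
# Assumes `trt_gm` was produced by torch_tensorrt.dynamo.compile(...) and that
# enable_output_allocator is re-exported from torch_tensorrt.runtime.
import torch
import torch_tensorrt

inputs = (torch.randn(4, 8).cuda(),)

with torch_tensorrt.runtime.enable_output_allocator(trt_gm):
    # Inside the context, every "_run_on_acc" TRT submodule routes its outputs
    # through the output allocator; on exit the flag is reset.
    out = trt_gm(*inputs)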