feat: Hierarchical Partitioner to support multi-backends #3539
base: main
@@ -0,0 +1,114 @@
import torch
import torch.nn as nn
import torch_tensorrt
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_ATEN_CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.lowering import (
    get_decompositions,
    pre_export_lowering,
)
from torch_tensorrt.dynamo.partitioning._adjacency_partitioner import partition
from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import (
    hierarchical_adjacency_partition,
)


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        return x


def main():
    # Create model
    model = SimpleModel().cuda()
    # model = models.efficientnet_b0(pretrained=True).cuda()
    model = model.eval()

    # Create example input
    example_input = torch.randn(1, 3, 224, 224).cuda()

    exported_program = torch.export.export(model, (example_input,))
    exported_program = pre_export_lowering(exported_program)
    exported_program = exported_program.run_decompositions(get_decompositions())

    gm = exported_program.module()

    print(gm.graph)

    original_output = model(example_input)

    # Partition the model using the adjacency partitioner
    # partitioned_model, op_support = partition(
    #     gm,
    #     verbose=True,
    #     min_block_size=1,
    #     torch_executed_ops=[
    #         torch.ops.aten.relu.default,
    #     ],
    # )

    partitioned_model, op_support = hierarchical_adjacency_partition(
        gm,
        verbose=True,
        min_block_size=1,
        backend_priority=["inductor", "tensorrt"],
        backend_support_map={
            "inductor": {
                # operator.getitem,
                torch.ops.aten.conv2d.default,
                torch.ops.aten.convolution.default,
            },
            "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
        },
        torch_executed_ops=[
            torch.ops.aten._native_batch_norm_legit_no_training.default
        ],
        require_full_compilation=False,
        skip_fusion=False,
    )

    print("\nPartitioned Model Structure:")
    print(partitioned_model)

    print("0. Original_output:", original_output)

    with torch.no_grad():
        partitioned_output = partitioned_model(example_input)
        print("1. Partitioned output:", partitioned_output)
        print(
            "Partitioned output == Original output:",
            torch.allclose(original_output, partitioned_output, 1e-2, 1e-2),
        )

    compiled_model = torch_tensorrt.compile(
        model, inputs=[example_input], min_block_size=1
    )
    with torch.no_grad():
        compiled_output = compiled_model(example_input)
        print("2. Compiled_output:", compiled_output)

    print(
        "Compiled output == Original output:",
        torch.allclose(original_output, compiled_output, 1e-2, 1e-2),
    )


if __name__ == "__main__":
    main()
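
As an aside (not part of the PR diff), one way to check which backend each partition was routed to is to walk the children of partitioned_model and match the "_run_on_acc_inductor" / "_run_on_acc_tensorrt" substrings that the compiler changes below check; the exact submodule naming scheme is an assumption here, so treat this as an illustrative sketch only.

# Illustrative sketch: infer the target backend from each partitioned submodule's name.
# The substring checks mirror those in the compiler changes below; the precise names
# produced by the hierarchical partitioner are assumed, not documented in this PR.
for name, _submodule in partitioned_model.named_children():
    if "_run_on_acc_inductor" in name:
        backend = "inductor"
    elif "_run_on_acc_tensorrt" in name:
        backend = "tensorrt"
    else:
        backend = "torch eager (fallback)"
    print(f"{name} -> {backend}")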
@@ -28,6 +28,9 @@
     interpret_module_to_result,
     repair_double_inputs,
 )
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_ATEN_CONVERTERS,
+)
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
 )
@@ -788,20 +791,49 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
             "Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments."
         )
 
+    ############ TODO: testing only ############
+    use_hierarchical_partitioner = False
+    backend_priority = ["inductor", "tensorrt"]
+    backend_support_map = {
+        "inductor": {
+            # operator.getitem,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.convolution.default,
+        },
+        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
+    }
+    #############################################
     # Partition module into components that can be TRT-accelerated
     fast_partitioner_failed = False
     # If specified, try using the fast partitioner and fall back to the global one on failure
     if settings.use_fast_partitioner:
         try:
-            logger.info("Partitioning the graph via the fast partitioner")
-            partitioned_module, supported_ops = partitioning.fast_partition(
-                gm,
-                verbose=settings.debug,
-                min_block_size=settings.min_block_size,
-                torch_executed_ops=settings.torch_executed_ops,
-                require_full_compilation=settings.require_full_compilation,
-                skip_fusion=(num_supported_ops == total_ops),
-            )
+            if use_hierarchical_partitioner:
+                logger.info(
+                    "Partitioning the graph via the fast hierarchical partitioner"
+                )
+                partitioned_module, supported_ops = (
+                    partitioning.hierarchical_adjacency_partition(
+                        gm,
+                        verbose=settings.debug,
+                        min_block_size=settings.min_block_size,
+                        torch_executed_ops=settings.torch_executed_ops,
+                        require_full_compilation=settings.require_full_compilation,
+                        skip_fusion=(num_supported_ops == total_ops),
+                        backend_priority=backend_priority,
+                        backend_support_map=backend_support_map,
+                    )
+                )
+            else:
+                logger.info("Partitioning the graph via the fast partitioner")
+                partitioned_module, supported_ops = partitioning.fast_partition(
+                    gm,
+                    verbose=settings.debug,
+                    min_block_size=settings.min_block_size,
+                    torch_executed_ops=settings.torch_executed_ops,
+                    require_full_compilation=settings.require_full_compilation,
+                    skip_fusion=(num_supported_ops == total_ops),
+                )
 
         except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
             logger.error(
@@ -836,7 +868,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         submodule_node_dict[node.name] = node
 
     # Store TRT replicas of Torch subgraphs
-    trt_modules = {}
+    compiled_modules = {}
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those
@@ -913,26 +945,61 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
         dryrun_tracker.tensorrt_graph_count += 1
         dryrun_tracker.per_subgraph_data.append(subgraph_data)
 
-        # Create TRT engines from submodule
+        # Create TRT engines / compiled models from submodule
+        # torch._logging.set_logs(inductor=logging.DEBUG)
         if not settings.dryrun:
-            trt_module = convert_module(
-                submodule,
-                submodule_inputs,
-                settings=settings,
-                name=name,
-                engine_cache=engine_cache,
-            )
+            if use_hierarchical_partitioner:
+                # compile submodule with pytorch inductor
+                if "_run_on_acc_inductor" in name:
+                    sub_inputs = []
+                    for input in submodule_inputs:
+                        sub_input = (
+                            torch.randn(input.shape)
+                            .to(dtype.to(input.dtype, t=torch.dtype))
+                            .cuda()
+                        )
+                        sub_inputs.append(sub_input)
+
+                    compiled_func = torch._inductor.compile(
+                        submodule,
+                        sub_inputs,
+                    )
+                    # Wrap the compiled function to be a torch.nn.Module
+                    compiled_submodule = FunctionWrapper(compiled_func)
+
+                elif "_run_on_acc_tensorrt" in name:
Reviewer: Is there some sort of design where the capability and conversion parts can be grouped and registered? We can add this concept to the RFC for later.

Reviewer: That way we don't have a long conditional case set; we just look up the appropriate conversion function on a standardized API.

Author: Yeah, I agree that the conversion part of different backends should be grouped, but for now I don't have too much info about other backends (like how to convert an op to that backend). We can definitely do this when we are ready to support other backends.

(A hypothetical sketch of this registry idea follows this diff hunk.)
+                    compiled_submodule = convert_module(
+                        submodule,
+                        submodule_inputs,
+                        settings=settings,
+                        name=name,
+                        engine_cache=engine_cache,
+                    )
+                else:
+                    raise ValueError(f"Unknown backend for submodule: {name}")
Comment on lines +978 to +979:

Reviewer: should there be a […]

Author: Nope, as I mentioned above, […]
+            else:
+                compiled_submodule = convert_module(
+                    submodule,
+                    submodule_inputs,
+                    settings=settings,
+                    name=name,
+                    engine_cache=engine_cache,
+                )
 
-            trt_modules[name] = trt_module
+            compiled_modules[name] = compiled_submodule
 
     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)
 
     # Replace all FX Modules with TRT Modules
-    for name, trt_module in trt_modules.items():
-        setattr(partitioned_module, name, trt_module)
+    for name, compiled_module in compiled_modules.items():
+        setattr(partitioned_module, name, compiled_module)
         if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
-            getattr(partitioned_module, name).setup_engine()
+            if use_hierarchical_partitioner:
Reviewer: Similar here, we should standardize post-processing as well.
if "_run_on_acc_tensorrt" in name: | ||
getattr(partitioned_module, name).setup_engine() | ||
else: | ||
getattr(partitioned_module, name).setup_engine() | ||
|
||
# Reset settings object to user specification after fallback to global partitioning mode | ||
if fast_partitioner_failed: | ||
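
Following the registry discussion in the review thread above, here is a minimal, purely hypothetical sketch (not part of this PR) of grouping each backend's capability (supported ops) and its conversion function behind one registration, so the per-submodule branching becomes a single lookup. All names (BackendRegistration, BACKEND_REGISTRY, register_backend, compile_submodule) are illustrative assumptions.

# Hypothetical sketch -- not part of this PR.
from dataclasses import dataclass
from typing import Callable, Dict, Sequence, Set

import torch


@dataclass
class BackendRegistration:
    # Capability: the ops this backend can accept.
    supported_ops: Set
    # Conversion: turns a partitioned submodule plus example inputs into a runnable module.
    compile_fn: Callable[[torch.fx.GraphModule, Sequence[torch.Tensor]], torch.nn.Module]


BACKEND_REGISTRY: Dict[str, BackendRegistration] = {}


def register_backend(name: str, supported_ops: Set, compile_fn: Callable) -> None:
    BACKEND_REGISTRY[name] = BackendRegistration(supported_ops, compile_fn)


def compile_submodule(name: str, submodule: torch.fx.GraphModule, inputs) -> torch.nn.Module:
    # Dispatch on the backend tag embedded in the partitioned submodule's name
    # (e.g. "_run_on_acc_inductor" / "_run_on_acc_tensorrt") instead of an if/elif chain.
    for backend, registration in BACKEND_REGISTRY.items():
        if f"_run_on_acc_{backend}" in name:
            return registration.compile_fn(submodule, inputs)
    raise ValueError(f"Unknown backend for submodule: {name}")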
@@ -1276,3 +1343,12 @@ def load_cross_compiled_exported_program(file_path: str = "") -> Any:
         )
 
     return replace_execute_engine_no_op_node(exp_program)
+
+
+class FunctionWrapper(torch.nn.Module):
+    def __init__(self, func):
+        super().__init__()
+        self.func = func
+
+    def forward(self, *args, **kwargs):
+        return self.func(*args, **kwargs)
Comment on lines +1348 to +1354:

Reviewer: Consider naming this to […]
Reviewer: skip_fusion=False slows down the partitioning a lot. Can you check if it's really needed?

Author: My understanding is that, if the number of ops in a GM is less than min_block_size, the GM would not be compiled, no matter whether it targets TRT or another backend. Do you mean it shouldn't apply to other backends? My understanding is that torch execute is not considered a backend because it doesn't need any compilation; it just runs ops in eager mode. So, if an op is in torch_executed_ops, it ignores backend_support_map and runs in torch eager anyway.

Author: Since the adjacency partitioner uses this flag, I just kept it here. Yeah, I can definitely switch it to True in the example.
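
For reference, the change the author agrees to for the example script is just flipping the flag in the hierarchical_adjacency_partition call shown earlier; all other arguments stay as in the example (sketch only):

partitioned_model, op_support = hierarchical_adjacency_partition(
    gm,
    verbose=True,
    min_block_size=1,
    backend_priority=["inductor", "tensorrt"],
    backend_support_map={
        "inductor": {
            torch.ops.aten.conv2d.default,
            torch.ops.aten.convolution.default,
        },
        "tensorrt": set(DYNAMO_ATEN_CONVERTERS.keys()),
    },
    torch_executed_ops=[
        torch.ops.aten._native_batch_norm_legit_no_training.default
    ],
    require_full_compilation=False,
    skip_fusion=True,  # avoid the extra fusion pass flagged as slow in review
)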