|
| 1 | +from typing import Any, Callable |
| 2 | + |
| 3 | +import torch |
| 4 | +import torch.nn as nn |
| 5 | +import torch_tensorrt |
| 6 | +from torch_tensorrt._enums import dtype |
| 7 | +from torch_tensorrt.dynamo import partitioning |
| 8 | +from torch_tensorrt.dynamo._compiler import convert_module |
| 9 | +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( |
| 10 | + DYNAMO_CONVERTERS as CONVERTERS, |
| 11 | +) |
| 12 | +from torch_tensorrt.dynamo.lowering import ( |
| 13 | + get_decompositions, |
| 14 | + pre_export_lowering, |
| 15 | +) |
| 16 | +from torch_tensorrt.dynamo.partitioning._hierarchical_partitioner import ( |
| 17 | + hierarchical_adjacency_partition, |
| 18 | +) |
| 19 | +from torch_tensorrt.dynamo.utils import ( |
| 20 | + get_output_metadata, |
| 21 | +) |
| 22 | +from torchvision import models |
| 23 | + |
| 24 | + |
class InductorModule(torch.nn.Module):  # type: ignore[misc]
    """Wrap a compiled callable so it can be installed as an ``nn.Module``.

    The callable is stored on ``self.func`` and ``forward`` simply delegates
    to it, which lets the result of a function-returning compiler be assigned
    as a child of another module in place of an FX submodule.

    Args:
        func: Callable invoked by ``forward`` with the same positional and
            keyword arguments the module receives.
    """

    def __init__(self, func: Callable[..., Any]) -> None:
        super().__init__()
        self.func = func

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Call the wrapped function with ``args``/``kwargs`` and return its result."""
        return self.func(*args, **kwargs)
| 34 | + |
| 35 | + |
class SimpleModel(nn.Module):
    """Two conv -> batch-norm -> ReLU stages (3 -> 64 -> 128 channels).

    Both convolutions use a 3x3 kernel with padding 1, so the spatial size of
    the input is preserved.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registration order is kept as conv1, conv2, bn1, bn2 so that
        # parameter initialisation order and state_dict layout are unchanged.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv1/bn1/ReLU followed by conv2/bn2/ReLU to ``x``."""
        x = torch.relu(self.bn1(self.conv1(x)))
        x = torch.relu(self.bn2(self.conv2(x)))
        return x
| 52 | + |
| 53 | + |
def main() -> None:
    """Partition ``SimpleModel`` across Inductor and TensorRT and verify it.

    The script:

    1. Exports the model with ``torch.export``, applies ``pre_export_lowering``
       and runs the decompositions from ``get_decompositions()``.
    2. Splits the resulting graph with ``hierarchical_adjacency_partition``.
    3. Compiles every ``_run_on_acc_inductor*`` submodule with
       ``torch._inductor.compile`` and every ``_run_on_acc_tensorrt*``
       submodule with ``convert_module``, then swaps the compiled modules in.
    4. Runs the partitioned module and prints whether its output is close to
       the eager model's output.

    The model and inputs are moved to CUDA, so a CUDA device is required.

    Raises:
        ValueError: If an accelerated submodule has no matching call node in
            the partitioned graph, or its name matches no known backend.
        RuntimeError: If a submodule yields no output metadata or no inputs
            can be constructed for it.
    """
    # Create model
    model = SimpleModel().cuda()
    # Alternative model to try:
    # model = models.efficientnet_b0(pretrained=True).cuda()
    model = model.eval()

    # Create example input
    example_input = torch.randn(1, 3, 224, 224).cuda()

    exported_program = torch.export.export(model, (example_input,))
    exported_program = pre_export_lowering(exported_program)
    exported_program = exported_program.run_decompositions(get_decompositions())

    gm = exported_program.module()

    print("Original Model Structure:\n", gm)

    # The reference output is only used for comparison, so skip autograd
    # bookkeeping exactly as is done for the partitioned run below.
    with torch.no_grad():
        original_output = model(example_input)

    # 1. Partition the model into blocks that can be executed by different backends
    partitioned_model, _op_support = hierarchical_adjacency_partition(
        gm,
        min_block_size=1,
        backend_priority=["inductor", "tensorrt"],
        backend_support_map={
            "inductor": {
                "torch.ops.aten.convolution.default",
            },
            "tensorrt": CONVERTERS.keys(),
        },
        torch_executed_ops={
            "torch.ops.aten._native_batch_norm_legit_no_training.default"
        },
        require_full_compilation=False,
        skip_fusion=True,
    )

    print("1. Partitioned Model Structure:\n", partitioned_model)

    # 2. Compile each submodule with the corresponding backend
    # Map accelerated submodule names to the graph node that calls them.
    submodule_node_dict = {
        node.name: node
        for node in partitioned_model.graph.nodes
        if "_run_on_acc" in node.name
    }

    # Store compiled replicas of Torch subgraphs
    compiled_modules = {}

    for name, submodule in partitioned_model.named_children():
        if not isinstance(submodule, torch.fx.graph_module.GraphModule):
            continue

        if "_run_on_acc" not in name:
            # Not selected for any accelerated backend: keep it as-is on the GPU.
            submodule.to("cuda")
            continue

        if name not in submodule_node_dict:
            raise ValueError(
                f"node_name: {name} does not exist in the submodule node dictionary"
            )

        # set the submodule metadata back to the parent module_node
        metadata_list = get_output_metadata(submodule)
        if len(metadata_list) == 0:
            raise RuntimeError(f"No output metadata found for submodule: {name}")
        parent_node = submodule_node_dict[name]
        for key in ("val", "tensor_meta"):
            if key not in parent_node.meta:
                parent_node.meta[key] = [
                    metadata[key] for metadata in metadata_list if key in metadata
                ]
                # Only the first missing key is filled in.
                break

        # Get the submodule inputs for min, opt, max shapes of the graph inputs
        submodule_inputs = partitioning.construct_submodule_inputs(submodule)
        if submodule_inputs is None:
            raise RuntimeError(f"Could not construct inputs for submodule: {name}")

        # compile submodule with pytorch inductor backend
        if "_run_on_acc_inductor" in name:
            sub_inputs = [
                spec.torch_tensor.to(dtype.to(spec.dtype, t=torch.dtype)).cuda()
                for spec in submodule_inputs
            ]
            compiled_func = torch._inductor.compile(
                submodule,
                sub_inputs,
            )
            # Wrap the compiled function to be a torch.nn.Module
            compiled_submodule = InductorModule(compiled_func)

        # compile submodule with tensorrt backend
        elif "_run_on_acc_tensorrt" in name:
            compiled_submodule = convert_module(
                submodule,
                submodule_inputs,
                name=name,
            )
        else:
            raise ValueError(f"Unknown backend for submodule: {name}")

        compiled_modules[name] = compiled_submodule

    # Replace all FX Modules with compiled Modules
    # (done after the loop so children are not reassigned while iterating).
    for name, compiled_module in compiled_modules.items():
        setattr(partitioned_model, name, compiled_module)

    print("2. Compiled Model Structure:\n", partitioned_model)

    with torch.no_grad():
        partitioned_output = partitioned_model(example_input)
        print(
            "3. Verify that Partitioned output == Original output:",
            torch.allclose(
                partitioned_output, original_output, rtol=1e-2, atol=1e-2
            ),
        )
| 173 | + |
| 174 | + |
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
0 commit comments