import argparse
import time

import torch
import torch_tensorrt as torchtrt
import torchvision
from pyinstrument import Profiler

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)

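# Time 1000 forward passes end to end; optionally capture a pyinstrument
# profile of the loop and write it out as HTML.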
def benchmark_model(model, input, label, profile=False):
    if profile:
        profiler = Profiler(interval=0.01)
        profiler.start()
    # Synchronize before and after timing so queued GPU work is fully counted.
    torch.cuda.synchronize()
    start_time = time.time()
    for _ in range(1000):
        model(*input)
    torch.cuda.synchronize()
    end_time = time.time()
    print(f"{label} 1000 runs: {end_time - start_time:.4f} seconds")
    if profile:
        profiler.stop()
        profiler.write_html(
            f"/home/other/{label.replace(' ', '_')}.html", timeline=False, show_all=True
        )


def main(args):
    profile = args.profile
    use_python_runtime = args.use_python_runtime
    model_name = args.model

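    # Emit verbose Torch-TensorRT debug logs while the engines are built.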
    with torchtrt.dynamo.Debugger(log_level="debug", engine_builder_monitor=False):

        model = (
            torchvision.models.__dict__[model_name](weights="DEFAULT").eval().to("cuda")
        )
        input = [torch.randn((1, 3, 224, 224)).to("cuda")]

        exp_program = torch.export.export(model, tuple(input), strict=True)
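        # Compile the whole exported graph into a single TensorRT engine
        # (no graph breaks).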
        trt_mod2 = torchtrt.dynamo.compile(
            exp_program,
            tuple(input),
            use_python_runtime=use_python_runtime,
            enabled_precisions={torch.float},
            min_block_size=1,
            immutable_weights=False,
            reuse_cached_engines=False,
        )

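        # Run aten.relu in PyTorch instead of TensorRT; this partitions the
        # model into multiple TRT engines with a graph break at every relu.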
        trt_mod1 = torchtrt.dynamo.compile(
            exp_program,
            tuple(input),
            use_python_runtime=use_python_runtime,
            enabled_precisions={torch.float},
            min_block_size=1,
            immutable_weights=False,
            torch_executed_ops={torch.ops.aten.relu.default},
            reuse_cached_engines=False,
        )

        # Serialize both modules with AOTInductor and reload them. This path is
        # only exercised with the default C++ runtime (skipped when
        # --use_python_runtime is set).
        if not use_python_runtime:
            torchtrt.save(
                trt_mod1,
                "/home/other/aoti.pt2",
                output_format="aot_inductor",
                inputs=input,
                retrace=True,
            )
            aoti_model_gb = torch._inductor.aoti_load_package("/home/other/aoti.pt2")
            torchtrt.save(
                trt_mod2,
                "/home/other/aoti_no_gb.pt2",
                output_format="aot_inductor",
                inputs=input,
                retrace=True,
            )
            aoti_model_no_gb = torch._inductor.aoti_load_package(
                "/home/other/aoti_no_gb.pt2"
            )

        # Warmup runs to avoid measuring first-run overheads
        for _ in range(100):
            trt_mod1(*input)
            trt_mod2(*input)
            model(*input)
            if not use_python_runtime:
                aoti_model_gb(*input)
                aoti_model_no_gb(*input)
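
    # Timed runs (outside the Debugger context): TRT with vs. without graph
    # breaks, plus both AOTI packages when using the C++ runtime.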
    time.sleep(1)
    benchmark_model(trt_mod1, input, "trt_mod1 (with graph break)", profile=profile)
    benchmark_model(trt_mod2, input, "trt_mod2 (without graph break)", profile=profile)
    if not use_python_runtime:
        benchmark_model(aoti_model_gb, input, "aoti_model_gb", profile=profile)
        benchmark_model(aoti_model_no_gb, input, "aoti_model_no_gb", profile=profile)

    out1 = trt_mod1(*input)
    out2 = trt_mod2(*input)
    if not use_python_runtime:
        out3 = aoti_model_gb(*input)
        out4 = aoti_model_no_gb(*input)

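    # Normalize outputs to a tuple so single- and multi-output models are
    # compared the same way.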
    def _to_tuple(x):
        if isinstance(x, (tuple, list)):
            return tuple(x)
        return (x,)

    outs1 = _to_tuple(out1)
    outs2 = _to_tuple(out2)
    if not use_python_runtime:
        outs3 = _to_tuple(out3)
        outs4 = _to_tuple(out4)

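    # Compare element-wise with loose tolerances to absorb small numeric
    # differences between the TRT and AOTI runtimes.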
    def compare_outputs(a, b, name1="A", name2="B"):
        if len(a) != len(b):
            print(f"Number of outputs differ: {len(a)} vs {len(b)}")
            return False
        all_equal = True
        for i, (x, y) in enumerate(zip(a, b)):
            if not torch.allclose(x, y, atol=1e-3, rtol=1e-3):
                print(f"Output {i} differs between {name1} and {name2}")
                print(f"Max diff: {torch.max(torch.abs(x - y))}")
                print(f"Mean diff: {torch.mean(torch.abs(x - y))}")
                all_equal = False
        if all_equal:
            print(f"All outputs match between {name1} and {name2}")
        return all_equal

    compare_outputs(outs1, outs2, "trt_mod1", "trt_mod2")
    if not use_python_runtime:
        compare_outputs(outs1, outs3, "trt_mod1", "aoti_model_gb")
        compare_outputs(outs1, outs4, "trt_mod1", "aoti_model_no_gb")
        compare_outputs(outs2, outs3, "trt_mod2", "aoti_model_gb")


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--profile", action="store_true")
    arg_parser.add_argument("--use_python_runtime", action="store_true")
    arg_parser.add_argument(
        "--model", type=str, default="resnet18", choices=["resnet18", "resnet152"]
    )
    args = arg_parser.parse_args()
    main(args)
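
# Example invocations (the script filename below is a placeholder; profile and
# AOTI output paths are hardcoded to /home/other/ above):
#   python benchmark_graph_breaks.py --model resnet18
#   python benchmark_graph_breaks.py --model resnet152 --profile
#   python benchmark_graph_breaks.py --use_python_runtime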