Skip to content

Commit bdec720

Browse files
committed
add tests and rebase
1 parent d8509fa commit bdec720

File tree

5 files changed

+304
-8
lines changed

5 files changed

+304
-8
lines changed

examples/dynamo/hierarchical_partitioner_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@ def main():
7373
# 1. Partition the model into blocks that can be executed by different backends
7474
partitioned_model, op_support = hierarchical_adjacency_partition(
7575
gm,
76-
verbose=True,
7776
min_block_size=1,
7877
backend_priority=["inductor", "tensorrt"],
7978
backend_support_map={

py/torch_tensorrt/dynamo/partitioning/_adjacency_partitioner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,6 @@ def partition(
261261
262262
Args:
263263
gm: FX GraphModule to partition
264-
verbose: Bool representing whether to print operator support
265264
min_block_size: Minimum number of operators per TRT-Engine Block
266265
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
267266
require_full_compilation: Require that all computational operators be run in TRT

py/torch_tensorrt/dynamo/partitioning/_global_partitioner.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ def partition(
210210
211211
Args:
212212
gm: FX GraphModule to partition
213-
verbose: Bool representing whether to print operator support
214213
min_block_size: Minimum number of operators per TRT-Engine Block
215214
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
216215
require_full_compilation: Whether to require that all operators be run in TRT

py/torch_tensorrt/dynamo/partitioning/_hierarchical_partitioner.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
is_node_output_tensor,
1919
)
2020
from torch_tensorrt.dynamo._defaults import (
21-
DEBUG,
2221
MIN_BLOCK_SIZE,
2322
REQUIRE_FULL_COMPILATION,
2423
)
@@ -527,7 +526,6 @@ class FxNetSplitterInternalError(Exception):
527526

528527
def hierarchical_adjacency_partition(
529528
gm: torch.fx.GraphModule,
530-
verbose: bool = DEBUG,
531529
min_block_size: int = MIN_BLOCK_SIZE,
532530
torch_executed_ops: Collection[Target] = set(),
533531
backend_support_map: Optional[Dict[str, Collection[Target]]] = None,
@@ -540,7 +538,6 @@ def hierarchical_adjacency_partition(
540538
541539
Args:
542540
gm: FX GraphModule to partition
543-
verbose: Bool representing whether to print operator support
544541
min_block_size: Minimum number of operators per TRT-Engine Block
545542
backend_support_map: Dictionary mapping backend names to sets of supported operators
546543
backend_priority: Ordered list of backend names, from highest to lowest priority
@@ -583,7 +580,6 @@ def hierarchical_adjacency_partition(
583580

584581
partitioned_graph = partitioner.partition_graph()
585582

586-
if verbose:
587-
supported_ops.print_support_overview(partitioner.num_accelerated_subgraphs)
583+
supported_ops.print_support_overview(partitioner.num_accelerated_subgraphs)
588584

589585
return partitioned_graph, supported_ops
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
from copy import deepcopy
2+
3+
import numpy as np
4+
import torch
5+
from torch.testing._internal.common_utils import TestCase, run_tests
6+
from torch_tensorrt.dynamo import partitioning
7+
8+
9+
class TestHierarchicalAdjacencyPartitioning(TestCase):
    """Tests for ``partitioning.hierarchical_adjacency_partition``.

    Each test builds a small graph, partitions it, and then inspects the names
    of the direct child modules of the partitioned result (``_run_on_acc``,
    ``_run_on_acc_inductor``, ``_run_on_acc_tensorrt``, ``_run_on_gpu``) to
    check how the graph was split.
    """

    @staticmethod
    def _count_submodules(partitioned_graph: torch.nn.Module, marker: str) -> int:
        """Return the number of direct children whose name contains ``marker``."""
        return sum(marker in name for name, _ in partitioned_graph.named_children())

    @staticmethod
    def _tally_subgraphs(partitioned_graph: torch.nn.Module, markers) -> dict:
        """Count direct children per name marker.

        Args:
            partitioned_graph: Module whose direct children are classified.
            markers: Ordered name fragments; a child is attributed to the
                first marker found in its name.

        Returns:
            Dict mapping each marker to the number of children matching it.

        Raises:
            ValueError: If a child name contains none of the markers.
        """
        counts = dict.fromkeys(markers, 0)
        for name, _ in partitioned_graph.named_children():
            for marker in markers:
                if marker in name:
                    counts[marker] += 1
                    break
            else:
                # No marker matched this child.
                raise ValueError(f"Unknown backend: {name}")
        return counts

    def _export_simple_model(self):
        """Export ``SimpleModel`` on CUDA and return the lowered, decomposed module.

        Applies ``pre_export_lowering`` and then ``run_decompositions`` with
        the decompositions returned by ``get_decompositions``.
        """
        from torch_tensorrt.dynamo.lowering import (
            get_decompositions,
            pre_export_lowering,
        )

        model = self.SimpleModel().cuda().eval()
        example_input = torch.randn(1, 3, 224, 224).cuda()

        exported_program = torch.export.export(model, (example_input,))
        exported_program = pre_export_lowering(exported_program)
        exported_program = exported_program.run_decompositions(get_decompositions())
        return exported_program.module()

    def test_hierarchical_adjacency_partition_fully_supported_one_op(self):
        """A single add op with default settings yields no accelerated submodule."""

        class FullySupportedOneOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                return torch.ops.aten.add.Tensor(x, y)

        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            deepcopy(fx_graph),
        )
        self.assertEqual(
            self._count_submodules(partitioned_graph, "_run_on_acc"),
            0,
            "Single operators should not be segmented",
        )

    def test_hierarchical_adjacency_partition_fully_supported_one_op_require_full_compilation(
        self,
    ):
        """With ``require_full_compilation=True`` a single add op is segmented."""

        class FullySupportedOneOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                return torch.ops.aten.add.Tensor(x, y)

        fx_graph = torch.fx.symbolic_trace(FullySupportedOneOp())
        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            deepcopy(fx_graph), require_full_compilation=True
        )
        self.assertEqual(
            self._count_submodules(partitioned_graph, "_run_on_acc"),
            1,
            "Single operators can be segmented if full compilation is required",
        )

    def test_hierarchical_adjacency_partition_fully_supported_multi_op(self):
        """A chain of four aten ops with ``min_block_size=2`` yields one segment."""

        class FullySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_ = torch.ops.aten.sub.Tensor(x, y)
                # NOTE(review): aten.cat is usually given a sequence of tensors;
                # here it receives two positional tensors. This test only
                # traces and partitions the graph (it is never executed) --
                # confirm the call form is intentional.
                concat_ = torch.ops.aten.cat.default(x, sum_)
                relu_ = torch.ops.aten.relu.default(concat_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            deepcopy(fx_graph), min_block_size=2
        )
        self.assertEqual(
            self._count_submodules(partitioned_graph, "_run_on_acc"),
            1,
            "All operators are supported, there should be one segment",
        )

    def test_hierarchical_adjacency_partition_partially_supported_multi_op(self):
        """Aten ops separated by ``np.sum`` calls are expected to give two segments."""

        class PartiallySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_1 = torch.ops.aten.add.Tensor(x, y)
                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
                # NOTE(review): the np.sum calls appear to be the "unsupported
                # operators" the assertion message refers to -- confirm how
                # they show up in the traced graph.
                sum_ = np.sum(sum_1) + np.sum(sum_2)
                relu_ = torch.ops.aten.relu.default(sum_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            deepcopy(fx_graph), min_block_size=2
        )
        self.assertEqual(
            self._count_submodules(partitioned_graph, "_run_on_acc"),
            2,
            "Unsupported operators interleave supported ones, expected 2 segments",
        )

    def test_hierarchical_adjacency_partition_partially_supported_with_torch_executed_ops(
        self,
    ):
        """Ops listed in ``torch_executed_ops`` must not appear in accelerated submodules.

        ``aten.add.Tensor`` is forced to Torch; ``relu`` and ``pow`` must still
        be found inside ``_run_on_acc`` submodules.
        """

        class PartiallySupportedMultiOp(torch.nn.Module):
            def __init__(self, *args, **kwargs) -> None:
                super().__init__(*args, **kwargs)

            def forward(self, x, y):
                sum_1 = torch.ops.aten.add.Tensor(x, y)
                sum_2 = torch.ops.aten.add.Tensor(x, sum_1)
                sum_ = torch.ops.aten.add.Tensor(sum_1, sum_2)
                relu_ = torch.ops.aten.relu.default(sum_)
                pow_ = torch.ops.aten.pow.Tensor_Scalar(relu_, 2)
                return pow_

        torch_executed_ops = {torch.ops.aten.add.Tensor}

        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            deepcopy(fx_graph),
            min_block_size=1,
            torch_executed_ops=torch_executed_ops,
        )

        unexpected_ops = torch_executed_ops
        expected_ops = {torch.ops.aten.relu.default, torch.ops.aten.pow.Tensor_Scalar}

        # Collect every call_function target found inside accelerated submodules.
        accelerated_targets = set()
        for name, submodule in partitioned_graph.named_children():
            if "_run_on_acc" not in name:
                continue
            accelerated_targets.update(
                node.target
                for node in submodule.graph.nodes
                if node.op == "call_function"
            )

        unexpected_ops_seen = accelerated_targets & unexpected_ops
        expected_ops_unseen = expected_ops - accelerated_targets

        self.assertEqual(
            len(unexpected_ops_seen),
            0,
            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
        )
        self.assertEqual(
            len(expected_ops_unseen),
            0,
            f"The following expected ops were not encountered: {expected_ops_unseen}",
        )

    class SimpleModel(torch.nn.Module):
        """Two conv -> batch-norm -> relu stages (3 -> 64 -> 128 channels)."""

        def __init__(self):
            super().__init__()
            self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
            self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
            self.bn1 = torch.nn.BatchNorm2d(64)
            self.bn2 = torch.nn.BatchNorm2d(128)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = torch.relu(x)
            x = self.conv2(x)
            x = self.bn2(x)
            x = torch.relu(x)
            return x

    def test_hierarchical_adjacency_partition_with_two_backends(self):
        """Inductor (convolution only) and TensorRT each receive two subgraphs."""
        from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
            DYNAMO_CONVERTERS as CONVERTERS,
        )

        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            self._export_simple_model(),
            min_block_size=1,
            backend_priority=["inductor", "tensorrt"],
            backend_support_map={
                "inductor": {
                    "torch.ops.aten.convolution.default",
                },
                "tensorrt": CONVERTERS.keys(),
            },
        )

        # Marker order mirrors the precedence used to attribute a child.
        counts = self._tally_subgraphs(
            partitioned_graph,
            ("_run_on_acc_inductor", "_run_on_acc_tensorrt"),
        )

        self.assertEqual(
            counts["_run_on_acc_inductor"],
            2,
            "There should be 2 subgraphs running on inductor backend",
        )
        self.assertEqual(
            counts["_run_on_acc_tensorrt"],
            2,
            "There should be 2 subgraph running on tensorrt backend",
        )

    def test_hierarchical_adjacency_partition_with_two_backends_with_torch_executed_ops(
        self,
    ):
        """With batch norm forced to Torch, each of the three targets gets two subgraphs."""
        from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
            DYNAMO_CONVERTERS as CONVERTERS,
        )

        partitioned_graph, _ = partitioning.hierarchical_adjacency_partition(
            self._export_simple_model(),
            min_block_size=1,
            backend_priority=["inductor", "tensorrt"],
            backend_support_map={
                "inductor": {
                    "torch.ops.aten.convolution.default",
                },
                "tensorrt": CONVERTERS.keys(),
            },
            torch_executed_ops={
                "torch.ops.aten._native_batch_norm_legit_no_training.default"
            },
        )

        # Marker order mirrors the precedence used to attribute a child.
        counts = self._tally_subgraphs(
            partitioned_graph,
            ("_run_on_acc_inductor", "_run_on_acc_tensorrt", "_run_on_gpu"),
        )

        self.assertEqual(
            counts["_run_on_gpu"],
            2,
            "There should be 2 subgraphs running on torch gpu backend",
        )
        self.assertEqual(
            counts["_run_on_acc_inductor"],
            2,
            "There should be 2 subgraphs running on inductor backend",
        )
        self.assertEqual(
            counts["_run_on_acc_tensorrt"],
            2,
            "There should be 2 subgraph running on tensorrt backend",
        )
300+
301+
302+
# Script entry point: when this file is executed directly, run the tests via
# run_tests (imported from torch.testing._internal.common_utils).
if __name__ == "__main__":
303+
run_tests()

0 commit comments

Comments
 (0)