
Commit 619b9a0

fix: Allow full model compilation with collection Inputs
- Allow users to specify full model compilation when using `input_signature`, which allows for complex collection-based inputs
- Enable the "pseudo-partitioning" phase for input collections as well as output collections
- Update `OutputIsCollection` to include dictionary outputs, and add function `InputIsCollection` to detect collection-based inputs during graph compilation
- Remove automatic fallback for collection pack/unpack operations when using the `input_signature` argument
- Add collections tests to ensure full compilation is respected for input and output collections
1 parent c2126b1 commit 619b9a0
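
With this change, grouped (collection) inputs no longer force partial compilation. A minimal usage sketch of what is now allowed, assuming a scripted module whose forward takes a tuple of two tensors (the module path and shapes below are placeholders):

import torch
import torch_tensorrt as torchtrt

# Placeholder scripted module: forward(self, z: Tuple[Tensor, Tensor])
model = torch.jit.load("tuple_input_output_scripted.jit.pt").eval().to("cuda")
x = torch.randn((1, 3, 224, 224)).to("cuda")

# Before this commit, combining input_signature with
# require_full_compilation=True raised an error; it now compiles fully.
trt_mod = torchtrt.ts.compile(
    model,
    input_signature=((torchtrt.Input(x.shape), torchtrt.Input(x.shape)),),
    enabled_precisions={torch.float},
    min_block_size=1,
    require_full_compilation=True,
)
out = trt_mod((x, x))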

File tree

7 files changed: +165 −59 lines changed

core/compiler.cpp

+2 −1

@@ -352,8 +352,9 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Determine if the block is convertible/has collection output, and based on the result,
   // whether full compilation can be expected
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
+  auto inputIsCollection = conversion::InputIsCollection(g->block());
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  auto requires_collection_handling = (isBlockConvertible && outputIsCollection);
+  auto requires_collection_handling = (isBlockConvertible && (inputIsCollection || outputIsCollection));
 
   // Determine whether user specifications necessitate partitioning
   auto isFallbackRequested = userRequestedFallback(cfg);
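
The partitioning decision can be paraphrased in Python for intuition (a hedged restatement of the condition above, not code from this commit; the booleans stand in for the three C++ helpers):

# Stand-ins for VerifyConverterSupportForBlock / InputIsCollection / OutputIsCollection:
is_block_convertible = True
input_is_collection = True    # e.g. the block takes a List[Tensor]
output_is_collection = False

# Pseudo-partitioning now also triggers for collection inputs,
# not just collection outputs:
requires_collection_handling = is_block_convertible and (
    input_is_collection or output_is_collection
)
print(requires_collection_handling)  # True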

core/conversion/conversion.cpp

+11 −1

@@ -556,10 +556,20 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
   return convertable_ops;
 }
 
+bool InputIsCollection(const torch::jit::Block* b) {
+  for (auto in : b->inputs()) {
+    if (in->type()->kind() == torch::jit::TypeKind::TupleType || in->type()->kind() == torch::jit::TypeKind::ListType) {
+      return true;
+    }
+  }
+  return false;
+}
+
 bool OutputIsCollection(const torch::jit::Block* b) {
   for (auto out : b->outputs()) {
     if (out->type()->kind() == torch::jit::TypeKind::TupleType ||
-        out->type()->kind() == torch::jit::TypeKind::ListType) {
+        out->type()->kind() == torch::jit::TypeKind::ListType ||
+        out->type()->kind() == torch::jit::TypeKind::DictType) {
       return true;
     }
   }
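
For intuition, the type-kind checks above can be mirrored from Python against a scripted module's graph. A hedged sketch (ListIn is an illustrative module, not part of this commit):

import torch
from typing import List, Tuple

class ListIn(torch.nn.Module):
    def forward(self, xs: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        return xs[0] + xs[1], xs[0] - xs[1]

g = torch.jit.script(ListIn().eval()).graph

# Graph inputs start with `self` (ClassType), followed by the user inputs.
in_kinds = [v.type().kind() for v in g.inputs()]    # e.g. ['ClassType', 'ListType']
out_kinds = [v.type().kind() for v in g.outputs()]  # e.g. ['TupleType']

# Same predicate as InputIsCollection / the extended OutputIsCollection:
input_is_collection = any(k in ("TupleType", "ListType") for k in in_kinds)
output_is_collection = any(k in ("TupleType", "ListType", "DictType") for k in out_kinds)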

core/conversion/conversion.h

+2
@@ -26,6 +26,8 @@ std::string ConvertBlockToEngine(
 
 bool OpSupported(const torch::jit::Node* n);
 
+bool InputIsCollection(const torch::jit::Block* b);
+
 bool OutputIsCollection(const torch::jit::Block* b);
 
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors = false);

cpp/src/compile_spec.cpp

+0 −21

@@ -74,27 +74,6 @@ torchtrt::core::CompileSpec init_compile_spec(CompileSpec& external) {
     LOG_WARNING("Input signature parsing is an experimental feature, behavior and APIs may change");
     to_internal_input_signature(external.graph_inputs.input_signature, converted_input_signature);
     torchtrt::core::CompileSpec internal(converted_input_signature);
-
-    TORCHTRT_CHECK(
-        !external.require_full_compilation,
-        "Grouped inputs currently requires partial compilation to be enabled, \
-        this restriction will be relaxed in a future release");
-
-    LOG_DEBUG("Grouped inputs currently requires additional settings to enable the feature");
-    LOG_DEBUG(
-        "Adding the following ops to torch_executed_ops:" << std::endl
-        << " - aten::__getitem__" << std::endl
-        << " - prim::ListConstruct" << std::endl
-        << " - prim::ListUnpack" << std::endl
-        << " - prim::TupleIndex" << std::endl
-        << " - prim::TupleConstruct" << std::endl
-        << " - prim::TupleUnpack");
-    external.torch_executed_ops.push_back("aten::__getitem__");
-    external.torch_executed_ops.push_back("prim::ListConstruct");
-    external.torch_executed_ops.push_back("prim::ListUnpack");
-    external.torch_executed_ops.push_back("prim::TupleIndex");
-    external.torch_executed_ops.push_back("prim::TupleConstruct");
-    external.torch_executed_ops.push_back("prim::TupleUnpack");
     return internal;
   }
 }

py/torch_tensorrt/ts/_compile_spec.py

+1 −36

@@ -268,42 +268,7 @@ def _parse_compile_spec(compile_spec_: Dict[str, Any]) -> _ts_C.CompileSpec:
             "Input signature parsing is an experimental feature, behavior and APIs may change",
         )
         signature = _parse_input_signature(compile_spec["input_signature"])
-        info.input_signature = _C.InputSignature(signature)  # py_object
-
-        if not compile_spec["torch_fallback"]["enabled"]:
-            raise ValueError(
-                "Grouped inputs currently requires partial compilation to be enabled, this restriction will be relaxed in a future release"
-            )
-
-        log(
-            Level.Debug,
-            "Grouped inputs currently requires additional settings to enable the feature",
-        )
-        log(
-            Level.Debug,
-            """Adding the following ops to torch_executed_ops:
-    - aten::__getitem__
-    - prim::ListConstruct
-    - prim::ListUnpack
-    - prim::TupleIndex
-    - prim::TupleConstruct
-    - prim::TupleUnpack
-""",
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "aten::__getitem__"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::ListConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::ListUnpack")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append("prim::TupleIndex")
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleConstruct"
-        )
-        compile_spec["torch_fallback"]["forced_fallback_ops"].append(
-            "prim::TupleUnpack"
-        )
+        info.input_signature = _C.InputSignature(signature)
 
     else:
         raise KeyError(
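
Since neither the C++ nor the Python path injects the collection pack/unpack ops anymore, users who still want those ops to run in Torch (the old behavior) can opt in explicitly. A hedged sketch, with the model and shapes as placeholders:

import torch_tensorrt as torchtrt

trt_mod = torchtrt.ts.compile(
    model,  # placeholder: a scripted module taking a list of two tensors
    input_signature=([torchtrt.Input((1, 3, 224, 224)), torchtrt.Input((1, 3, 224, 224))],),
    min_block_size=1,
    # Reinstate the previously automatic fallback by hand:
    torch_executed_ops=[
        "aten::__getitem__",
        "prim::ListConstruct",
        "prim::ListUnpack",
        "prim::TupleIndex",
        "prim::TupleConstruct",
        "prim::TupleUnpack",
    ],
)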

tests/cpp/test_collections.cpp

+62

@@ -404,3 +404,65 @@ TEST(CppAPITests, TestCollectionComplexModel) {
   ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
       out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
 }
+
+TEST(CppAPITests, TestCollectionFullCompilationComplexModel) {
+  std::string path = "tests/modules/list_input_tuple_output_scripted.jit.pt";
+  torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf);
+  std::vector<at::Tensor> inputs;
+  inputs.push_back(in0);
+
+  torch::jit::Module mod;
+  try {
+    // Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(path);
+  } catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+  }
+  mod.eval();
+  mod.to(torch::kCUDA);
+
+  std::vector<torch::jit::IValue> inputs_;
+
+  for (auto in : inputs) {
+    inputs_.push_back(torch::jit::IValue(in.clone()));
+  }
+
+  std::vector<torch::jit::IValue> complex_inputs;
+  auto input_list = c10::impl::GenericList(c10::TensorType::get());
+  input_list.push_back(inputs_[0]);
+  input_list.push_back(inputs_[0]);
+
+  torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
+
+  complex_inputs.push_back(input_list_ivalue);
+
+  auto out = mod.forward(complex_inputs);
+
+  auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
+
+  auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
+
+  c10::TypePtr elementType = input_shape_ivalue.type();
+  auto list = c10::impl::GenericList(elementType);
+  list.push_back(input_shape_ivalue);
+  list.push_back(input_shape_ivalue);
+
+  torch::jit::IValue complex_input_shape(list);
+  std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
+  torch::jit::IValue complex_input_shape2(input_tuple2);
+
+  auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
+  compile_settings.min_block_size = 1;
+  compile_settings.require_full_compilation = true;
+
+  // FP16 execution
+  compile_settings.enabled_precisions = {torch::kHalf};
+  // Compile module
+  auto trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings);
+  auto trt_out = trt_mod.forward(complex_inputs);
+
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[0].toTensor(), trt_out.toTuple()->elements()[0].toTensor()));
+  ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(
+      out.toTuple()->elements()[1].toTensor(), trt_out.toTuple()->elements()[1].toTensor()));
+}

tests/py/api/test_collections.py

+87

@@ -194,6 +194,34 @@ def test_compile(self):
                 msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
             )
 
+    def test_compile_full_compilation(self):
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/tuple_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                (torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)),
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"tuple_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
 
 class TestListInputOutput(unittest.TestCase):
     def test_compile(self):
@@ -225,6 +253,36 @@ def test_compile(self):
                 msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
             )
 
+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
 
 class TestListInputTupleOutput(unittest.TestCase):
     def test_compile(self):
@@ -255,6 +313,35 @@ def test_compile(self):
                 msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
             )
 
+    def test_compile_full_compilation(self):
+
+        self.input = torch.randn((1, 3, 224, 224)).to("cuda")
+        self.model = (
+            torch.jit.load(MODULE_DIR + "/list_input_tuple_output_scripted.jit.pt")
+            .eval()
+            .to("cuda")
+        )
+
+        compile_spec = {
+            "input_signature": (
+                [torchtrt.Input(self.input.shape), torchtrt.Input(self.input.shape)],
+            ),
+            "device": torchtrt.Device("gpu:0"),
+            "enabled_precisions": {torch.float},
+            "min_block_size": 1,
+            "require_full_compilation": True,
+        }
+
+        trt_mod = torchtrt.ts.compile(self.model, **compile_spec)
+        trt_out = trt_mod((self.input, self.input))
+        pyt_out = self.model((self.input, self.input))
+        for (t, p) in zip(trt_out, pyt_out):
+            cos_sim = cosine_similarity(t, p)
+            self.assertTrue(
+                cos_sim > COSINE_THRESHOLD,
+                msg=f"list_input_tuple_output_scripted TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
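
The new Python cases use unittest, so they can be run directly, e.g. python tests/py/api/test_collections.py from the repository root (assuming the scripted modules under tests/modules have already been generated).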
