Skip to content

Commit f8f0f55

Browse files
committed
adding test case for multiple outputs and moving op to register_meta_ops
1 parent 108b0ec commit f8f0f55

File tree

3 files changed

+49
-21
lines changed

3 files changed

+49
-21
lines changed

py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py

+21
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,24 @@ def __setstate__(self, serialized_state: List[str]) -> Any:
150150

151151
def __getstate__(self) -> Any:
152152
pass
153+
154+
155+
@torch.library.custom_op(
156+
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
157+
)
158+
def no_op_placeholder_for_execute_engine(
159+
inputs: List[torch.Tensor],
160+
abi_version: str,
161+
name: str,
162+
serialized_device_info: str,
163+
serialized_engine: str,
164+
serialized_in_binding_names: str,
165+
serialized_out_binding_names: str,
166+
serialized_hardware_compatible: str,
167+
serialized_metadata: str,
168+
serialized_target_platform: str,
169+
serialized_require_output_allocator: str,
170+
) -> List[torch.Tensor]:
171+
raise RuntimeError(
172+
"The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
173+
)

py/torch_tensorrt/runtime/_utils.py

-21
Original file line numberDiff line numberDiff line change
@@ -128,24 +128,3 @@ def _get_most_compatible_device(
128128
best_match = candidate
129129

130130
return best_match
131-
132-
133-
@torch.library.custom_op(
134-
"tensorrt::no_op_placeholder_for_execute_engine", mutates_args=()
135-
)
136-
def no_op_placeholder_for_execute_engine(
137-
inputs: List[torch.Tensor],
138-
abi_version: str,
139-
name: str,
140-
serialized_device_info: str,
141-
serialized_engine: str,
142-
serialized_in_binding_names: str,
143-
serialized_out_binding_names: str,
144-
serialized_hardware_compatible: str,
145-
serialized_metadata: str,
146-
serialized_target_platform: str,
147-
serialized_require_output_allocator: str,
148-
) -> List[torch.Tensor]:
149-
raise RuntimeError(
150-
"The saved model is cross compiled for windows in Linux, should only be loadded in Windows via torch_tensorrt.load_cross_compiled_exported_program() api."
151-
)

tests/py/dynamo/runtime/test_003_cross_compile_for_windows.py

+28
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,31 @@ def forward(self, a, b):
6363
)
6464
except Exception as e:
6565
pytest.fail(f"unexpected exception raised: {e}")
66+
67+
@unittest.skipIf(
68+
platform.system() != "Linux" or platform.architecture()[0] != "64bit",
69+
"Cross compile for windows can only be enabled on linux x86-64 platform",
70+
)
71+
@pytest.mark.unit
72+
def test_dynamo_cross_compile_for_windows_multiple_output(self):
73+
class Add(torch.nn.Module):
74+
def forward(self, a, b):
75+
return torch.add(a, b), torch.add(a, b)
76+
77+
model = Add().eval().cuda()
78+
inputs = (torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda())
79+
trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep")
80+
exp_program = torch.export.export(model, inputs)
81+
compile_spec = {
82+
"inputs": inputs,
83+
"min_block_size": 1,
84+
}
85+
try:
86+
trt_gm = torch_tensorrt.dynamo.cross_compile_for_windows(
87+
exp_program, **compile_spec
88+
)
89+
torch_tensorrt.dynamo.save_cross_compiled_exported_program(
90+
trt_gm, file_path=trt_ep_path
91+
)
92+
except Exception as e:
93+
pytest.fail(f"unexpected exception raised: {e}")

0 commit comments

Comments
 (0)