Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,8 @@ void setup_input_tensors(
std::vector<at::Tensor> inputs,
c10::intrusive_ptr<TRTEngine> compiled_engine,
bool cudagraphs_enabled,
bool need_cudagraphs_record) {
// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;
bool need_cudagraphs_record,
std::list<std::vector<int64_t>>& inputShapeTensorValues) {
std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);

for (size_t i = 0; i < inputs.size(); i++) {
Expand All @@ -115,9 +114,10 @@ void setup_input_tensors(

auto dims = core::util::toDims(inputs[i].sizes());
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
bool is_shape_tensor = compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str());
LOG_DEBUG("Input Name: " << name << " Shape: " << dims << " isShapeInferenceIO: " << is_shape_tensor);

if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
if (is_shape_tensor) {
// Shape tensor inputs are casted to int64 explicitly.
// Refer to
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
Expand Down Expand Up @@ -233,6 +233,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

std::vector<at::Tensor> outputs(compiled_engine->num_io.second);

// Shape tensor CPU buffers must outlive inferShapes() and enqueueV3()
std::list<std::vector<int64_t>> inputShapeTensorValues;

// Intialize inputs and outputs to be available throughout the succeeding scopes
{ // Input Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
Expand All @@ -241,7 +244,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, inputShapeTensorValues);
// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
std::vector<char const*> names(io_size);
Expand Down Expand Up @@ -364,14 +367,17 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
};

auto run_output_allocator = [&]() {
// Shape tensor CPU buffers must outlive inferShapes() and enqueueV3()
std::list<std::vector<int64_t>> inputShapeTensorValues;

{ // Input Setup
std::unique_ptr<torch::autograd::profiler::RecordProfile> input_profiler_guard;
if (compiled_engine->profile_execution) {
input_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

setup_input_tensors(inputs, compiled_engine, false, false);
setup_input_tensors(inputs, compiled_engine, false, false, inputShapeTensorValues);
// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
std::vector<char const*> names(io_size);
Expand Down
104 changes: 65 additions & 39 deletions py/torch_tensorrt/dynamo/conversion/_symbolic_shape_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,30 +50,45 @@ def extract_symbolic_shape_expressions(
return None

input_val = input_node.meta["val"]
if not isinstance(input_val, torch.Tensor):
logger.debug(
f"Input node '{input_node.name}': type={type(input_val)}, val={input_val}"
)
if isinstance(input_val, torch.Tensor):
shape_exprs = []
for dim_size in input_val.shape:
if isinstance(dim_size, torch.SymInt):
shape_exprs.append(dim_size.node.expr)
else:
shape_exprs.append(int(dim_size))

input_info.append(
{
"shape_exprs": shape_exprs,
"dtype": input_val.dtype,
"name": input_node.name,
}
)
elif isinstance(input_val, (torch.SymInt, torch.SymFloat, int, float, bool)):
if isinstance(input_val, (torch.SymInt, int)):
scalar_dtype = torch.int64
elif isinstance(input_val, (torch.SymFloat, float)):
scalar_dtype = torch.float64
else:
scalar_dtype = torch.bool
input_info.append(
{
"shape_exprs": [],
"dtype": scalar_dtype,
"name": input_node.name,
"is_scalar": True,
}
)
else:
logger.warning(
"When processing symbolic shapes for TensorRT engine, input is not a tensor"
f"When processing symbolic shapes for TensorRT engine, unsupported input type: {type(input_val)}"
)
return None

# Extract shape as sympy expressions (can be pickled)
shape_exprs = []
for dim_size in input_val.shape:
if isinstance(dim_size, torch.SymInt):
# Store the sympy expression, which can be pickled
shape_exprs.append(dim_size.node.expr)
else:
# Store concrete integer
shape_exprs.append(int(dim_size))

input_info.append(
{
"shape_exprs": shape_exprs,
"dtype": input_val.dtype,
"name": input_node.name,
}
)

# Extract output values from output node
output_args = output_node.args[0]
if not isinstance(output_args, (tuple, list)):
Expand All @@ -89,29 +104,40 @@ def extract_symbolic_shape_expressions(
return None

out_val = out_arg.meta["val"]
if not isinstance(out_val, torch.Tensor):
if isinstance(out_val, torch.Tensor):
shape_exprs = []
for dim_size in out_val.shape:
if isinstance(dim_size, torch.SymInt):
shape_exprs.append(dim_size.node.expr)
else:
shape_exprs.append(int(dim_size))

output_info.append(
{
"shape_exprs": shape_exprs,
"dtype": out_val.dtype,
}
)
elif isinstance(out_val, (torch.SymInt, torch.SymFloat, int, float, bool)):
if isinstance(out_val, (torch.SymInt, int)):
scalar_dtype = torch.int64
elif isinstance(out_val, (torch.SymFloat, float)):
scalar_dtype = torch.float64
else:
scalar_dtype = torch.bool
output_info.append(
{
"shape_exprs": [],
"dtype": scalar_dtype,
"is_scalar": True,
}
)
else:
logger.warning(
"When processing symbolic shapes for TensorRT engine, output is not a tensor"
f"When processing symbolic shapes for TensorRT engine, unsupported output type: {type(out_val)}"
)
return None

# Extract shape as sympy expressions (can be pickled)
shape_exprs = []
for dim_size in out_val.shape:
if isinstance(dim_size, torch.SymInt):
# Store the sympy expression, which can be pickled
shape_exprs.append(dim_size.node.expr)
else:
# Store concrete integer
shape_exprs.append(int(dim_size))

output_info.append(
{
"shape_exprs": shape_exprs,
"dtype": out_val.dtype,
}
)

if not output_info:
return None

Expand Down
204 changes: 204 additions & 0 deletions tests/py/dynamo/models/test_symint_scalar_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""
Tests for SymInt scalar input handling in symbolic shape capture and TRT compilation.

These tests verify that when Dynamo partitions an FX graph such that a SymInt
(e.g., from targets.size(0)) becomes a bare scalar placeholder input to the TRT
subgraph, the symbolic shape extraction and compilation succeed.

This covers the fix in _symbolic_shape_capture.py where non-tensor inputs
(SymInt, int, float, bool) are handled gracefully instead of aborting extraction.
"""

import unittest

import pytest
import torch
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity

assertions = unittest.TestCase()


@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_symint_from_size_used_in_reshape(use_python_runtime):
    """
    Verify that a SymInt produced by tensor.size(0) works as a reshape
    argument when Dynamo hands it to the TRT subgraph as a bare scalar
    placeholder.

    This exercises the core pattern from issue #4107: targets.size(0)
    yields a SymInt that the TRT partition receives as a scalar input
    and then consumes inside a reshape.
    """

    class Model(torch.nn.Module):
        def forward(self, x, targets):
            return x.reshape(targets.size(0), -1)

    model = Model().eval().cuda()

    x = torch.randn(16, 64).cuda()
    targets = torch.randint(0, 10, (16, 1), dtype=torch.int64).cuda()

    # Make the batch dimension symbolic on both inputs.
    for dyn_input in (x, targets):
        torch._dynamo.mark_dynamic(dyn_input, 0, min=1, max=2048)

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float},
            "min_block_size": 1,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    output_ref = model(x, targets)
    output_trt = trt_model(x, targets)

    cos_sim = cosine_similarity(output_ref, output_trt)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"SymInt reshape test (python_runtime={use_python_runtime}) failed. Cosine sim: {cos_sim}",
    )

    torch._dynamo.reset()


@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_scalar_tensor_input(use_python_runtime):
    """
    Verify that a 0-dim scalar tensor input (e.g., cache_length) survives
    symbolic shape extraction and compiles cleanly through TRT.
    """

    class Model(torch.nn.Module):
        def forward(self, x, offset):
            return x + offset

    model = Model().eval().cuda()

    # A 0-dim (scalar) tensor alongside a regular 2-D input.
    offset = torch.tensor(5.0).cuda()
    x = torch.randn(16, 64).cuda()

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float},
            "min_block_size": 1,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    output_ref = model(x, offset)
    output_trt = trt_model(x, offset)

    cos_sim = cosine_similarity(output_ref, output_trt)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"Scalar tensor input test (python_runtime={use_python_runtime}) failed. Cosine sim: {cos_sim}",
    )

    torch._dynamo.reset()


@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_symint_with_index_and_reshape(use_python_runtime):
    """
    Full reproduction of the issue #4107 pattern: a symbolic size taken
    from an int64 tensor, combined with an index operation and a reshape.

    The model performs:
        1. B = targets.size(0)                      -> SymInt
        2. idx = cache_length + arange(1)           -> int64 index tensor
        3. y = x[:, idx, :]                         -> gather with int64 index
        4. z = y.reshape(B, 1, -1, 2)               -> reshape using the SymInt
    """

    class TestModule(torch.nn.Module):
        def forward(self, x, targets, cache_length):
            B = targets.size(0)
            idx = cache_length + torch.arange(1, device=x.device)
            y = x[:, idx, :]
            z = y.reshape(B, 1, -1, 2)
            return z

    model = TestModule().eval().cuda()

    B, S, D = 16, 128, 1024
    x = torch.randn(B, S, D).cuda()
    targets = torch.randint(0, 10, (B, 1), dtype=torch.int64).cuda()
    cache_length = torch.tensor(0, dtype=torch.int64).cuda()

    # Batch dimension is symbolic on both tensor inputs.
    torch._dynamo.mark_dynamic(targets, 0, min=1, max=2048)
    torch._dynamo.mark_dynamic(x, 0, min=1, max=2048)

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float, torch.half},
            "min_block_size": 1,
            "truncate_double": True,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    output_ref = model(x, targets, cache_length)
    output_trt = trt_model(x, targets, cache_length)

    cos_sim = cosine_similarity(output_ref, output_trt)
    assertions.assertTrue(
        cos_sim > COSINE_THRESHOLD,
        msg=f"Issue 4107 repro test (python_runtime={use_python_runtime}) failed. Cosine sim: {cos_sim}",
    )

    torch._dynamo.reset()


@pytest.mark.unit
@pytest.mark.parametrize("use_python_runtime", [True, False])
def test_symint_with_different_batch_sizes(use_python_runtime):
    """
    Verify that a model compiled with a SymInt scalar input stays correct
    when invoked afterwards at several different batch sizes.
    """

    class Model(torch.nn.Module):
        def forward(self, x, targets):
            return x.reshape(targets.size(0), 2, -1)

    model = Model().eval().cuda()

    x = torch.randn(8, 64).cuda()
    targets = torch.randint(0, 10, (8, 1), dtype=torch.int64).cuda()

    # Make the batch dimension symbolic on both inputs.
    for dyn_input in (x, targets):
        torch._dynamo.mark_dynamic(dyn_input, 0, min=1, max=2048)

    trt_model = torch.compile(
        model,
        backend="tensorrt",
        options={
            "enabled_precisions": {torch.float},
            "min_block_size": 1,
            "pass_through_build_failures": True,
            "use_python_runtime": use_python_runtime,
        },
    )

    # Exercise the compiled model across smaller, equal, and larger batches.
    for batch_size in [4, 8, 16]:
        x_test = torch.randn(batch_size, 64).cuda()
        targets_test = torch.randint(0, 10, (batch_size, 1), dtype=torch.int64).cuda()

        output_ref = model(x_test, targets_test)
        output_trt = trt_model(x_test, targets_test)

        cos_sim = cosine_similarity(output_ref, output_trt)
        assertions.assertTrue(
            cos_sim > COSINE_THRESHOLD,
            msg=f"Varying batch size test (python_runtime={use_python_runtime}) failed at B={batch_size}. Cosine sim: {cos_sim}",
        )

    torch._dynamo.reset()
Loading