Bug Description
Running the vgg16_ptq example fails with the error:
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)
To Reproduce
I copied the code from the site into test.py. Instead of training the model directly, I used the default vgg16_bn checkpoint from PyTorch (here).
Changes from the original code are:
The module sizes were slightly different, so I updated the values in the VGG module definition.
The checkpoint loading code is slightly different, since the state dict isn't nested anymore (sketched below).
I don't think these changes should have affected anything.
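For reference, here is a minimal sketch of the checkpoint handling in this reproduction versus the original example. The torchvision download URL and the nesting key used by the original example are assumptions on my part:

# Sketch only; not part of test.py. The URL and the "model_state_dict" key are assumptions.
import torch

# Default vgg16_bn checkpoint from torchvision (assumed standard URL):
state_dict = torch.hub.load_state_dict_from_url(
    "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth"
)

# Original example: weights nested inside the training checkpoint, e.g.
# ckpt = torch.load(args.ckpt)
# weights = ckpt["model_state_dict"]

# This reproduction: the checkpoint is already a flat state dict
ckpt = torch.load("vgg16_bn-6c64b313.pth")
weights = ckpt

The full test.py follows: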
import argparse
import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode
class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()
        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l
        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64, 64, "pool",
        128, 128, "pool",
        256, 256, 256, "pool",
        512, 512, 512, "pool",
        512, 512, 512, "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)
PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)
PARSER.add_argument(
    "--quantize-type",
    default="int8",
    type=str,
    help="quantization type, currently supported int8 or fp8 for PTQ",
)
args = PARSER.parse_args()
model = vgg16(num_classes=1000, init_weights=False)
model = model.cuda()
ckpt = torch.load(args.ckpt)
weights = ckpt
model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()
training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
data = iter(training_dataloader)
images, _ = next(data)
crit = nn.CrossEntropyLoss()

def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        loss += crit(out, labels)
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()
    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))
if args.quantize_type == "int8":
    quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
    quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point
# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)  # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`
with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        # torch.export.export() failed due to RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
        from torch.export._trace import _export

        exp_program = _export(model, (input_tensor,))
        if args.quantize_type == "int8":
            enabled_precisions = {torch.int8}
        elif args.quantize_type == "fp8":
            enabled_precisions = {torch.float8_e4m3fn}
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions=enabled_precisions,
            min_block_size=1,
            debug=True,
        )
        # You can also use torch compile path to compile the model with Torch-TensorRT:
        # trt_model = torch.compile(model, backend="tensorrt")

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()
        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
I then run (vgg16_bn-6c64b313.pth is the locally downloaded checkpoint):
python test.py --ckpt vgg16_bn-6c64b313.pth
And get the error:
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)
from user code:
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
return self._dm_attribute_manager
Full trace
[WARNING | root ]: Supported flash-attn versions are >= 2.1.1, <= 2.6.3. Found flash-attn 2.7.4.post1.
[WARNING | torch_tensorrt.dynamo.conversion.converter_utils]: TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/27/2025-05:44:05] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
Inserted 86 quantizers
PTQ Loss: 0.18583 Acc: 0.00%
Loading extension modelopt_cuda_ext...
Traceback (most recent call last):
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 263, in __call__
self.call_reconstruct(value)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 90, in call_reconstruct
res = value.reconstruct(self)
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 358, in reconstruct
raise NotImplementedError
NotImplementedError
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/ubuntu/test.py", line 199, in <module>
exp_program = _export(model, (input_tensor,))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1035, in wrapper
raise e
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1008, in wrapper
ep = fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/exported_program.py", line 128, in wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1990, in _export
export_artifact = export_func( # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1255, in _strict_export
return _strict_export_lower_to_aten_ir(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
gm_torch_level = _export_to_torch_ir(
^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
gm_torch_level, _ = torch._dynamo.export(
^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
result_traced = opt_f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/test.py", line 57, in forward
def forward(self, x):
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 786, in __getattr__
manager = self._get_dm_attribute_manager(use_default=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
return self._torchdynamo_orig_callable(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
return _compile(
^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
return _compile_inner(code, one_graph, hooks, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
return function(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
out_code = transform_code_object(code, transform)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
transformations(instructions, code_options)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
tracer.run()
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
super().run()
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
while self.step():
^^^^^^^^^^^
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
self.dispatch_table[inst.opcode](self, inst)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
self._return(inst)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
self.output.compile_subgraph(
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1120, in compile_subgraph
self.codegen_suffix(tx, stack_values, pass1)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1193, in codegen_suffix
cg.restore_stack(stack_values, value_from_source=not tx.export)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 82, in restore_stack
self.foreach(stack_values)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 293, in foreach
self(i)
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 265, in __call__
unimplemented(f"reconstruct: {value}")
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)
from user code:
File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
return self._dm_attribute_manager
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
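As the last line of the trace suggests, the failing export can be re-run with extra Dynamo logging for more detail (same command as above, just with the suggested environment variables):

TORCH_LOGS="+dynamo" TORCHDYNAMO_VERBOSE=1 python test.py --ckpt vgg16_bn-6c64b313.pth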
Expected behavior
The example should run to completion and report the test loss and accuracy.
Environment
Build information about Torch-TensorRT can be found by turning on debug messages
Torch-TensorRT Version (e.g. 1.0.0): 2.6.0
PyTorch Version (e.g. 1.0): 2.6.0
CPU Architecture: x86
OS (e.g., Linux): Ubuntu 22.04.5 LTS
How you installed PyTorch (conda, pip, libtorch, source): pip
Build command you used (if compiling from source): pip install torch_tensorrt
Are you using local sources or building from archives:
Python version: 3.11.10
CUDA version: 550.127.05
GPU models and configuration: Nvidia L4
Any other relevant information:
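The versions above can be double-checked with a quick snippet (a sketch; it assumes torch_tensorrt exposes __version__ like torch does):

# Print the installed versions and the visible GPU for reference.
import torch
import torch_tensorrt

print("torch:", torch.__version__)
print("torch_tensorrt:", torch_tensorrt.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))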