🐛 [Bug] vgg16_ptq doesn't run correctly. #3419

Open
dragoneye-alex opened this issue Feb 27, 2025 · 0 comments

Bug Description

Running the vgg16_ptq example fails with the error: torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)

To Reproduce

I copied the code from the docs page into test.py. Instead of training the model directly, I used the default vgg16_bn checkpoint from PyTorch (https://download.pytorch.org/models/vgg16_bn-6c64b313.pth).
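
For reference, one way to produce that checkpoint locally (a convenience sketch using torchvision's published weights; not part of the original report):

import torch
import torchvision

# Download the stock ImageNet vgg16_bn weights and save the raw state dict,
# so the script below can load it via --ckpt.
weights = torchvision.models.VGG16_BN_Weights.IMAGENET1K_V1
model = torchvision.models.vgg16_bn(weights=weights)
torch.save(model.state_dict(), "vgg16_bn-6c64b313.pth")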

Changes from original code are:

  1. The module sizes were slightly different, so I updated the values in the VGG module definition.
  2. The checkpoint-loading code is slightly different, since the torchvision checkpoint isn't nested anymore (see the sketch below).

I don't think these changes should have affected anything.
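
For concreteness, a sketch of change 2 (the nested key name is from the original example as I remember it, so treat "model_state_dict" as an assumption):

# Original example: the training script saves a wrapper dict, so the weights
# sit under a key (assumed: "model_state_dict"):
#   ckpt = torch.load(args.ckpt)
#   weights = ckpt["model_state_dict"]
#
# This script: the torchvision checkpoint is the state dict itself.
ckpt = torch.load(args.ckpt)
weights = ckpt
model.load_state_dict(weights)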

test.py (pastebin link if easier)

import argparse

import modelopt.torch.quantization as mtq
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt as torchtrt
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from modelopt.torch.quantization.utils import export_torch_mode


class VGG(nn.Module):
    def __init__(self, layer_spec, num_classes=1000, init_weights=False):
        super(VGG, self).__init__()

        layers = []
        in_channels = 3
        for l in layer_spec:
            if l == "pool":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels, l, kernel_size=3, padding=1),
                    nn.BatchNorm2d(l),
                    nn.ReLU(),
                ]
                in_channels = l

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


def vgg16(num_classes=1000, init_weights=False):
    vgg16_cfg = [
        64,
        64,
        "pool",
        128,
        128,
        "pool",
        256,
        256,
        256,
        "pool",
        512,
        512,
        512,
        "pool",
        512,
        512,
        512,
        "pool",
    ]
    return VGG(vgg16_cfg, num_classes, init_weights)


PARSER = argparse.ArgumentParser(
    description="Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
)
PARSER.add_argument(
    "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"
)
PARSER.add_argument(
    "--batch-size",
    default=128,
    type=int,
    help="Batch size for tuning the model with PTQ and FP8",
)
PARSER.add_argument(
    "--quantize-type",
    default="int8",
    type=str,
    help="quantization type, currently supported int8 or fp8 for PTQ",
)
args = PARSER.parse_args()

model = vgg16(num_classes=1000, init_weights=False)
model = model.cuda()


ckpt = torch.load(args.ckpt)
weights = ckpt


model.load_state_dict(weights)
# Don't forget to set the model to evaluation mode!
model.eval()

training_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)
training_dataloader = torch.utils.data.DataLoader(
    training_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

data = iter(training_dataloader)
images, _ = next(data)

crit = nn.CrossEntropyLoss()

def calibrate_loop(model):
    # calibrate over the training dataset
    total = 0
    correct = 0
    loss = 0.0
    for data, labels in training_dataloader:
        data, labels = data.cuda(), labels.cuda(non_blocking=True)
        out = model(data)
        loss += crit(out, labels)
        preds = torch.max(out, 1)[1]
        total += labels.size(0)
        correct += (preds == labels).sum().item()

    print("PTQ Loss: {:.5f} Acc: {:.2f}%".format(loss / total, 100 * correct / total))

if args.quantize_type == "int8":
    quant_cfg = mtq.INT8_DEFAULT_CFG
elif args.quantize_type == "fp8":
    quant_cfg = mtq.FP8_DEFAULT_CFG
# PTQ with in-place replacement to quantized modules
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
# model has FP8 qdq nodes at this point


# Load the testing dataset
testing_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    ),
)

testing_dataloader = torch.utils.data.DataLoader(
    testing_dataset,
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=True,
)  # set drop_last=True to drop the last incomplete batch for static shape `torchtrt.dynamo.compile()`

with torch.no_grad():
    with export_torch_mode():
        # Compile the model with Torch-TensorRT Dynamo backend
        input_tensor = images.cuda()
        # torch.export.export() failed due to RuntimeError: Attempting to use FunctionalTensor on its own. Instead, please use it with a corresponding FunctionalTensorMode()
        from torch.export._trace import _export

        exp_program = _export(model, (input_tensor,))
        if args.quantize_type == "int8":
            enabled_precisions = {torch.int8}
        elif args.quantize_type == "fp8":
            enabled_precisions = {torch.float8_e4m3fn}
        trt_model = torchtrt.dynamo.compile(
            exp_program,
            inputs=[input_tensor],
            enabled_precisions=enabled_precisions,
            min_block_size=1,
            debug=True,
        )
        # You can also use torch compile path to compile the model with Torch-TensorRT:
        # trt_model = torch.compile(model, backend="tensorrt")

        # Inference compiled Torch-TensorRT model over the testing dataset
        total = 0
        correct = 0
        loss = 0.0
        class_probs = []
        class_preds = []
        for data, labels in testing_dataloader:
            data, labels = data.cuda(), labels.cuda(non_blocking=True)
            out = trt_model(data)
            loss += crit(out, labels)
            preds = torch.max(out, 1)[1]
            class_probs.append([F.softmax(i, dim=0) for i in out])
            class_preds.append(preds)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

        test_probs = torch.cat([torch.stack(batch) for batch in class_probs])
        test_preds = torch.cat(class_preds)
        test_loss = loss / total
        test_acc = correct / total
        print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

I then run (vgg16_bn-6c64b313.pth is the locally downloaded checkpoint):

python test.py --ckpt vgg16_bn-6c64b313.pth

And get the error:

torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)

from user code:
   File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
    return self._dm_attribute_manager

Full trace

[WARNING  | root               ]: Supported flash-attn versions are >= 2.1.1, <= 2.6.3. Found flash-attn 2.7.4.post1.
[WARNING  | torch_tensorrt.dynamo.conversion.converter_utils]: TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops
[02/27/2025-05:44:05] [TRT] [W] Functionality provided through tensorrt.plugin module is experimental.
Inserted 86 quantizers
PTQ Loss: 0.18583 Acc: 0.00%
Loading extension modelopt_cuda_ext...
Traceback (most recent call last):
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 263, in __call__
    self.call_reconstruct(value)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 90, in call_reconstruct
    res = value.reconstruct(self)
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/variables/base.py", line 358, in reconstruct
    raise NotImplementedError
NotImplementedError

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ubuntu/test.py", line 199, in <module>
    exp_program = _export(model, (input_tensor,))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1035, in wrapper
    raise e
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1008, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/exported_program.py", line 128, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1990, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1255, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 1283, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/export/_trace.py", line 662, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1569, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 574, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/test.py", line 57, in forward
    def forward(self, x):
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 786, in __getattr__
    manager = self._get_dm_attribute_manager(use_default=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1380, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 547, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 986, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 715, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_utils_internal.py", line 95, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 750, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1361, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 231, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 662, in transform
    tracer.run()
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2868, in run
    super().run()
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 1052, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 962, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3048, in RETURN_VALUE
    self._return(inst)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 3033, in _return
    self.output.compile_subgraph(
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1120, in compile_subgraph
    self.codegen_suffix(tx, stack_values, pass1)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1193, in codegen_suffix
    cg.restore_stack(stack_values, value_from_source=not tx.export)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 82, in restore_stack
    self.foreach(stack_values)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 293, in foreach
    self(i)
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/codegen.py", line 265, in __call__
    unimplemented(f"reconstruct: {value}")
  File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 317, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: reconstruct: UserDefinedObjectVariable(_DMAttributeManager)

from user code:
   File "/opt/conda/envs/pytorch/lib/python3.11/site-packages/modelopt/torch/opt/dynamic.py", line 376, in _get_dm_attribute_manager
    return self._dm_attribute_manager

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
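
As a possible workaround (untested; an assumption on my part, not something verified in this report): torch.export has a non-strict mode that traces with fake tensors instead of running TorchDynamo over the Python code, so it should never need to reconstruct the _DMAttributeManager object. Swapping the private _export call for the public API would look like:

# Hedged sketch: strict=False switches torch.export to non-strict tracing,
# which skips Dynamo bytecode analysis entirely. Whether the resulting
# ExportedProgram still carries the Q/DQ nodes that torchtrt.dynamo.compile
# expects is not verified here.
exp_program = torch.export.export(model, (input_tensor,), strict=False)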

Expected behavior

Should run correctly.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.6.0
  • PyTorch Version (e.g. 1.0): 2.6.0
  • CPU Architecture: x86
  • OS (e.g., Linux): Ubuntu 22.04.5 LTS
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source): pip install torch_tensorrt
  • Are you using local sources or building from archives:
  • Python version: 3.11.10
  • CUDA version: 550.127.05
  • GPU models and configuration: Nvidia L4
  • Any other relevant information:
dragoneye-alex added the bug label on Feb 27, 2025