## Bug Description
After resolving the issues from pytorch/pytorch#147096, a `MetadataMismatchError` is raised at `_scaled_dot_product_flash_attention`.
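The check that fires is `torch._prims_common.compare_tensor_meta`. A minimal sketch of the device check in isolation (assuming a CUDA build; this is not the exact failing call path):

```python
import torch
from torch._prims_common import compare_tensor_meta

# 0-d int64 tensors differing only in device, mirroring the fake/real pair
# in the traceback below (fake kernel output on cuda:0, real value on cpu).
fake_out = torch.zeros((), dtype=torch.int64, device="cuda:0")
real_out = torch.zeros((), dtype=torch.int64)

try:
    compare_tensor_meta(fake_out, real_out)
except Exception as e:  # MetadataMismatchError on recent nightlies
    print(type(e).__name__, e)  # e.g. "Devices ... are not equal!"
```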
```
Traceback (most recent call last):
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1854, in _maybe_infer_fake
    _check_fake_real_tensors(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_utils.py", line 196, in _check_fake_real_tensors
    torch._prims.utils.compare_tensor_meta(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 193, in compare_tensor_meta
    raise MetadataMismatchError(msg)
torch._subclasses.fake_tensor.MetadataMismatchError: Devices cpu and cuda:0 are not equal!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/develop/TensorRT/examples/dynamo/torch_export_pg.py", line 151, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
  File "/develop/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 670, in compile
    exported_program = exported_program.run_decompositions(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
    return _decompose_exported_program(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/_trace.py", line 771, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1345, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1584, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6955, in run_node
    result = super().run_node(n)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 236, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 316, in call_function
    return target(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 527, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1269, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1812, in dispatch
    return self._dispatch_impl(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2388, in _dispatch_impl
    return maybe_propagate_real_tensors(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2220, in maybe_propagate_real_tensors
    self._maybe_infer_fake_kernel_from_pytree_out(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1959, in _maybe_infer_fake_kernel_from_pytree_out
    fake_leaves = [
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1960, in <listcomp>
    self._maybe_infer_fake(func, _fake_path, _fake_out, _real_out)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1873, in _maybe_infer_fake
    raise MetadataMismatchError(
torch._subclasses.fake_tensor.MetadataMismatchError: Real tensor propagation found a metadata mismatch between fake tensor FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64) and real tensor 140486520641728, at output[6], for func: aten._scaled_dot_product_flash_attention.default

While executing %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%transpose_1, %transpose_2, %transpose_3), kwargs = {scale: 0.11785113019775793})

....

Original traceback:
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/paligemma/modeling_paligemma.py", line 504, in forward
    image_features = self.get_image_features(pixel_values)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1190, in forward
    return self.vision_model(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1091, in forward
    encoder_outputs = self.encoder(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 902, in forward
    layer_outputs = encoder_layer(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 643, in forward
    hidden_states, attn_weights = self.self_attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 574, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
```
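The mismatch is reported at `output[6]` of `aten._scaled_dot_product_flash_attention.default`. Going by the op's schema, that slot is the philox RNG seed, which the fake kernel models here as a 0-d `int64` tensor on `cuda:0`. A hedged illustration of the output layout (requires a GPU with flash-attention support; shapes are arbitrary):

```python
import torch

q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
outs = torch.ops.aten._scaled_dot_product_flash_attention.default(q, k, v)

# Schema: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k,
#          philox_seed, philox_offset, debug_attn_mask)
print(len(outs))                      # 9
print(outs[6].dtype, outs[6].device)  # philox seed: the slot flagged above
```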
## To Reproduce
Steps to reproduce the behavior:
```python
import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

# 1. Model
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]

# 2. PyTorch
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)  # , use_cache=False
pyt_generation = pyt_generation[0][input_len:]
pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
print("=============================")
print("PyTorch generated text:")
print(pyt_decoded)
print("=============================")

# (a) Dummy inputs
batch_size = 1
dummy_input_ids = model_inputs["input_ids"]
dummy_attention_mask = model_inputs["attention_mask"]
dummy_pixel_values = model_inputs["pixel_values"]

dummy_inputs = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask,
    "pixel_values": dummy_pixel_values,
}

# (b) Dynamic shape
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}
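# Note (reading of the checkpoint name, not stated in the original repro): only
# the batch dim of pixel_values is dynamic because paligemma2-3b-pt-224 fixes
# the SigLIP input resolution at 224x224, so H/W can stay static.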
# (c) ExportedProgram
# torch.export.export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
# )

import torch.utils._pytree as pytree
import transformers


def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    flat_tensors = []
    flat_tensors.append(hc.is_sliding)   # shape: [num_hidden_layers], bool
    flat_tensors.extend(hc.key_cache)    # List[Tensor]
    flat_tensors.extend(hc.value_cache)  # List[Tensor]
    context = {
        "max_cache_len": hc.max_cache_len,
        "max_batch_size": hc.max_batch_size,
        "head_dim": hc.head_dim,
        "dtype": hc.dtype,
        "num_key_value_heads": hc.num_key_value_heads,
        "num_layers": len(hc.key_cache),  # = len(hc.value_cache) = config.num_hidden_layers
    }
    return flat_tensors, context


def unflatten_hybridcache(flat_tensors, context):
    num_layers = context["num_layers"]
    is_sliding = flat_tensors[0]
    key_cache = flat_tensors[1 : 1 + num_layers]
    value_cache = flat_tensors[1 + num_layers : 1 + 2 * num_layers]

    hc = transformers.cache_utils.HybridCache.__new__(transformers.cache_utils.HybridCache)
    hc.max_cache_len = context["max_cache_len"]
    hc.max_batch_size = context["max_batch_size"]
    hc.head_dim = context["head_dim"]
    hc.dtype = context["dtype"]
    hc.num_key_value_heads = context["num_key_value_heads"]
    hc.is_sliding = is_sliding
    hc.key_cache = list(key_cache)
    hc.value_cache = list(value_cache)
    return hc


# pytree register
pytree.register_pytree_node(
    transformers.cache_utils.HybridCache,
    flatten_hybridcache,
    unflatten_hybridcache,
)
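# Hedged sanity check (not part of the original repro): after registration,
# pytree should round-trip a HybridCache, e.g. one from a warm-up forward
# pass with use_cache=True:
#   cache = model(**dummy_inputs, use_cache=True).past_key_values
#   leaves, spec = pytree.tree_flatten(cache)
#   assert len(leaves) == 1 + 2 * len(cache.key_cache)  # is_sliding + K/V per layer
#   restored = pytree.tree_unflatten(leaves, spec)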
# from torch.export._trace import _export
# exported_program = _export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
#     allow_complex_guards_as_runtime_asserts=True,
# )

# torch.export._draft_export.draft_export
import torch.export._draft_export

exported_program = torch.export._draft_export.draft_export(
    model,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    # allow_complex_guards_as_runtime_asserts=True,
)
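# Note (assumption about this nightly): draft_export returns an
# (ExportedProgram, report) pair here, hence exported_program[0] below.
# The report lists the fake/real mismatches hit during tracing:
#   print(exported_program[1])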
trt_model = torch_tensorrt.dynamo.compile(
    exported_program[0],
    inputs=dummy_inputs,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,
)

# TensorRT
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
with torch.inference_mode():
    trt_generation = trt_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
trt_generation = trt_generation[0][input_len:]
trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
print("TensorRT generated text:")
print(trt_decoded)
```
## Environment
```
pytorch-triton  3.2.0+git4b3bb1f8
torch           2.7.0.dev20250207+cu124
torch-tensorrt  2.7.0.dev0+5a4dd33ef  /develop/TensorRT/py
torchvision     0.22.0.dev20250207+cu124
```