Skip to content

🐛 [Bug] torch._subclasses.fake_tensor.MetadataMismatchError: Devices cpu and cuda:0 are not equal! (_scaled_dot_product_flash_attention) #3408

Closed
@chohk88

Description

@chohk88

Bug Description

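After the issues from pytorch/pytorch#147096 were resolved, compiling the exported model with torch_tensorrt.dynamo.compile now fails with a MetadataMismatchError in aten._scaled_dot_product_flash_attention. The fake tensor for output[6] (a 0-dim int64) is on cuda:0, while the corresponding real tensor is on cpu.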

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1854, in _maybe_infer_fake
    _check_fake_real_tensors(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_utils.py", line 196, in _check_fake_real_tensors
    torch._prims.utils.compare_tensor_meta(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 193, in compare_tensor_meta
    raise MetadataMismatchError(msg)
torch._subclasses.fake_tensor.MetadataMismatchError: Devices cpu and cuda:0 are not equal!

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/develop/TensorRT/examples/dynamo/torch_export_pg.py", line 151, in <module>
    trt_model = torch_tensorrt.dynamo.compile(
  File "/develop/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 670, in compile
    exported_program = exported_program.run_decompositions(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 121, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 1405, in run_decompositions
    return _decompose_exported_program(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 872, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/exported_program.py", line 491, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/export/_trace.py", line 771, in _export_to_aten_ir
    gm, graph_signature = transform(aot_export_module)(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1345, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 1584, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 570, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 671, in _create_aot_dispatcher_function
    fw_metadata = run_functionalized_fw_and_collect_metadata(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/collect_metadata_analysis.py", line 197, in inner
    flat_f_outs = f(*flat_f_args)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 184, in flat_fn
    tree_out = fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/traced_function_transforms.py", line 875, in functional_call
    out = PropagateUnbackedSymInts(mod).run(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 171, in run
    self.env[node] = self.run_node(node)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/experimental/symbolic_shapes.py", line 6955, in run_node
    result = super().run_node(n)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 236, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/fx/interpreter.py", line 316, in call_function
    return target(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_ops.py", line 756, in __call__
    return self._op(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 527, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/utils/_stats.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1269, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1812, in dispatch
    return self._dispatch_impl(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2388, in _dispatch_impl
    return maybe_propagate_real_tensors(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2220, in maybe_propagate_real_tensors
    self._maybe_infer_fake_kernel_from_pytree_out(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1959, in _maybe_infer_fake_kernel_from_pytree_out
    fake_leaves = [
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1960, in <listcomp>
    self._maybe_infer_fake(func, _fake_path, _fake_out, _real_out)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1873, in _maybe_infer_fake
    raise MetadataMismatchError(
torch._subclasses.fake_tensor.MetadataMismatchError: Real tensor propagation found a metadata mismatch between fake tensor FakeTensor(..., device='cuda:0', size=(), dtype=torch.int64) and real tensor 140486520641728,  at output[6], for func: aten._scaled_dot_product_flash_attention.default

While executing %_scaled_dot_product_flash_attention : [num_users=1] = call_function[target=torch.ops.aten._scaled_dot_product_flash_attention.default](args = (%transpose_1, %transpose_2, %transpose_3), kwargs = {scale: 0.11785113019775793})

....

Original traceback:
File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/paligemma/modeling_paligemma.py", line 504, in forward
    image_features = self.get_image_features(pixel_values)
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1190, in forward
    return self.vision_model(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 1091, in forward
    encoder_outputs = self.encoder(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 902, in forward
    layer_outputs = encoder_layer(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 643, in forward
    hidden_states, attn_weights = self.self_attn(
  File "/root/.pyenv/versions/3.10.16/lib/python3.10/site-packages/transformers/models/siglip/modeling_siglip.py", line 574, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(


To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image


# 1. Model
# Load PaliGemma2 (3B, 224px) in float16, put it in eval mode and move it to
# the target GPU. The example image is downloaded over the network.
DEVICE = torch.device("cuda:0")
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16
).eval().to(DEVICE)
processor = PaliGemmaProcessor.from_pretrained(model_id)

# Empty text prompt; every tensor returned by the processor is moved to DEVICE.
# The output is expected to contain input_ids, attention_mask and pixel_values
# (all three are read below).
prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(DEVICE)
# Prompt length, used to strip the prompt tokens from the generate() output.
input_len = model_inputs["input_ids"].shape[-1]

# 2. PyTorch
# Reference run with the eager PyTorch model (do_sample=False, up to 100 new
# tokens), so its text can be compared with the TensorRT output at the end.
with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False) #, use_cache=False)
    # Keep only the newly generated tokens of the first sequence.
    pyt_generation = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation, skip_special_tokens=True)
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

# (a) Dummy inputs: reuse the real processor outputs as example inputs for export.
# NOTE(review): batch_size is not referenced anywhere else in this script.
batch_size = 1
dummy_input_ids = model_inputs["input_ids"] 
dummy_attention_mask = model_inputs["attention_mask"] 
dummy_pixel_values = model_inputs["pixel_values"]

dummy_inputs = {
    "input_ids": dummy_input_ids,
    "attention_mask": dummy_attention_mask,
    "pixel_values": dummy_pixel_values,
}

# (b) Dynamic shapes: batch may vary in [1, 2] and sequence length in [1, 1024].
# input_ids and attention_mask share both dims; pixel_values only the batch dim.
BATCH = torch.export.Dim("batch", min=1, max=2)
SEQ_LEN = torch.export.Dim("seq_len", min=1, max=1024)
dynamic_shapes = {
    "input_ids": {0: BATCH, 1: SEQ_LEN},
    "attention_mask": {0: BATCH, 1: SEQ_LEN},
    "pixel_values": {0: BATCH},
}
# (c) ExportedProgram: plain torch.export.export, kept commented out for
# reference; the draft_export call further below is what actually runs.
# torch.export.export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
# )


import torch
import torch.utils._pytree as pytree
import transformers

def flatten_hybridcache(hc: transformers.cache_utils.HybridCache):
    """Split a HybridCache into pytree leaves and a static context.

    The leaves are laid out as ``[is_sliding, *key_cache, *value_cache]``.
    The scalar configuration attributes, plus the per-layer count needed to
    split the leaves again, go into the context dict.

    Args:
        hc: The cache to flatten.

    Returns:
        A ``(leaves, context)`` pair, the form this function is registered
        with via ``pytree.register_pytree_node``.
    """
    # is_sliding first, then every key tensor, then every value tensor.
    leaves = [hc.is_sliding, *hc.key_cache, *hc.value_cache]

    context = dict(
        max_cache_len=hc.max_cache_len,
        max_batch_size=hc.max_batch_size,
        head_dim=hc.head_dim,
        dtype=hc.dtype,
        num_key_value_heads=hc.num_key_value_heads,
        # Layer count taken from key_cache; value_cache is assumed to have the
        # same length (TODO confirm).
        num_layers=len(hc.key_cache),
    )

    return leaves, context


def unflatten_hybridcache(flat_tensors, context):
    """Rebuild a HybridCache from the output of ``flatten_hybridcache``.

    ``flat_tensors`` must follow the ``[is_sliding, *key_cache, *value_cache]``
    layout, with ``context["num_layers"]`` tensors in each cache list.

    The instance is created with ``HybridCache.__new__``, so ``__init__`` never
    runs and only the attributes assigned below are set.
    NOTE(review): any other attribute that HybridCache methods rely on will be
    missing on the rebuilt object -- confirm against the transformers version.

    Args:
        flat_tensors: Sequence of leaves in the layout above.
        context: The dict produced by ``flatten_hybridcache``.

    Returns:
        The reconstructed HybridCache.
    """
    n_layers = context["num_layers"]
    keys_end = 1 + n_layers
    values_end = keys_end + n_layers

    # Read all leaves before building the object.
    sliding_flags = flat_tensors[0]
    keys = flat_tensors[1:keys_end]
    values = flat_tensors[keys_end:values_end]

    cache_cls = transformers.cache_utils.HybridCache
    hc = cache_cls.__new__(cache_cls)

    # Restore the scalar configuration carried in the context.
    for attr in (
        "max_cache_len",
        "max_batch_size",
        "head_dim",
        "dtype",
        "num_key_value_heads",
    ):
        setattr(hc, attr, context[attr])

    hc.is_sliding = sliding_flags
    hc.key_cache = list(keys)
    hc.value_cache = list(values)

    return hc

# Register HybridCache as a pytree node, so pytree utilities (which
# torch.export relies on) can flatten it into leaves with flatten_hybridcache
# and rebuild it with unflatten_hybridcache.
# NOTE(review): presumably needed because the model passes or returns a
# HybridCache (e.g. as past_key_values) -- confirm. Newer torch versions may
# also expect a flatten_with_keys_fn here; check the torch version in use.
pytree.register_pytree_node(
    transformers.cache_utils.HybridCache,
    flatten_hybridcache,
    unflatten_hybridcache
)

# Alternative export path through the private torch.export._trace._export with
# allow_complex_guards_as_runtime_asserts; kept commented out for reference.
# from torch.export._trace import _export  
# exported_program = _export(
#     model,
#     args=(),
#     kwargs=dummy_inputs,
#     dynamic_shapes=dynamic_shapes,
#     strict=False,
#     allow_complex_guards_as_runtime_asserts=True,
# )

# Export with the private draft_export API (non-strict, same inputs and
# dynamic shapes as above).
# NOTE(review): presumably chosen because it tolerates export problems that
# torch.export.export would raise -- confirm for the torch nightly in use.
import torch.export._draft_export
exported_program = torch.export._draft_export.draft_export(
    model,
    args=(),
    kwargs=dummy_inputs,
    dynamic_shapes=dynamic_shapes,
    strict=False,
    # allow_complex_guards_as_runtime_asserts=True,
)


# Compile the exported graph with Torch-TensorRT. Per the reported traceback,
# the MetadataMismatchError is raised inside this call, when compile() runs
# exported_program.run_decompositions().
trt_model = torch_tensorrt.dynamo.compile(
    # draft_export's result is indexed with [0]; presumably it returns a
    # (program, report) pair in this torch version -- confirm.
    exported_program[0],
    inputs=dummy_inputs,
    # NOTE(review): the model was loaded in float16 while only float32 is
    # listed here; with use_explicit_typing=True the graph's own dtypes
    # presumably take precedence -- confirm.
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    disable_tf32=True,
    use_explicit_typing=True,
    use_fp32_acc=True,  
)

# TensorRT
# Repeat the generation with the compiled module and decode only the new tokens,
# mirroring the PyTorch reference run above.
# model_inputs were already moved to DEVICE when created, so the .to() below is
# effectively a safeguard; the comprehension also turns the processor output
# into a plain dict.
# NOTE(review): trt_model is whatever torch_tensorrt.dynamo.compile returned;
# confirm it actually exposes a Hugging Face-style .generate() method.
model_inputs = {k: v.to(DEVICE) for k, v in model_inputs.items()}
with torch.inference_mode():
    trt_generation = trt_model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    # Keep only the newly generated tokens of the first sequence.
    trt_generation = trt_generation[0][input_len:]
    trt_decoded = processor.decode(trt_generation, skip_special_tokens=True)
    print("TensorRT generated text:")
    print(trt_decoded)

Environment

pytorch-triton 3.2.0+git4b3bb1f8
torch 2.7.0.dev20250207+cu124
torch-tensorrt 2.7.0.dev0+5a4dd33ef /develop/TensorRT/py
torchvision 0.22.0.dev20250207+cu124

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions