Bug Description

When attempting to compile the google/paligemma2-3b-pt-224 model with torch.compile (backend = torch_tensorrt), the graph contains an op that Torch-TensorRT cannot convert (torch.ops.aten.masked_scatter.default), as shown in the debug log below:
DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %image_features : [num_users=2] = call_method[target=to](args = (%l_image_features_, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%l_inputs_embeds_, %l_special_image_mask_, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.lowering.passes.repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_special_image_mask_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_inputs_embeds_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_image_features_,), kwargs = {})
    %image_features : [num_users=2] = call_method[target=to](args = (%clone_default, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%clone_default_1, %clone_default_2, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_special_image_mask_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_inputs_embeds_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_image_features_,), kwargs = {})
    %image_features : [num_users=2] = call_method[target=to](args = (%clone_default, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%clone_default_1, %clone_default_2, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %l_image_features_ : torch.Tensor [num_users=1] = placeholder[target=L_image_features_]
    %l_inputs_embeds_ : torch.Tensor [num_users=1] = placeholder[target=L_inputs_embeds_]
    %l_special_image_mask_ : torch.Tensor [num_users=1] = placeholder[target=L_special_image_mask_]
    %clone_default_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_special_image_mask_,), kwargs = {})
    %clone_default_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_inputs_embeds_,), kwargs = {})
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_image_features_,), kwargs = {})
    %image_features : [num_users=2] = call_method[target=to](args = (%clone_default, cuda:0, torch.float16), kwargs = {})
    %inputs_embeds : [num_users=1] = call_method[target=masked_scatter](args = (%clone_default_1, %clone_default_2, %image_features), kwargs = {})
    return (inputs_embeds, image_features)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg2_1,), kwargs = {})
    %clone_1 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg1_1,), kwargs = {})
    %clone_2 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%clone_2,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%clone_1, %clone, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone_2 from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone_1 from graph, since it is a clone node which is the only user of placeholder arg1_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg2_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %_to_copy : [num_users=2] = call_function[target=torch.ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {dtype: torch.float16, device: cuda:0})
    %masked_scatter : [num_users=1] = call_function[target=torch.ops.aten.masked_scatter.default](args = (%arg1_1, %arg2_1, %_to_copy), kwargs = {})
    return (masked_scatter, _to_copy)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten._to_copy.default: 2
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Skipping option 0 for aten._to_copy.default: (validator: False, supports dynamic shapes: True)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 1 for converting aten._to_copy.default
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.masked_scatter.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 2 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten._to_copy.default: 2
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Skipping option 0 for aten._to_copy.default: (validator: False, supports dynamic shapes: True)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 1 for converting aten._to_copy.default
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten._to_copy.default + Operator Count: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.masked_scatter.default + Operator Count: 1
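From the Pre-AOT graph above, the failure reduces to the image-feature merge step: a .to(device, dtype) cast followed by masked_scatter. The snippet below is a minimal sketch of that same pattern, not taken from the actual model code (module name, shapes, and dtypes are illustrative), which should hit the same partitioning result with masked_scatter excluded:

import torch
import torch_tensorrt  # noqa: F401  -- registers the "tensorrt" backend

# Hypothetical stand-in for the merge step seen in the Pre-AOT graph:
# cast the image features, then scatter them into the text embeddings.
class MergeImageFeatures(torch.nn.Module):
    def forward(self, image_features, inputs_embeds, special_image_mask):
        image_features = image_features.to("cuda:0", torch.float16)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
        return inputs_embeds, image_features

mod = MergeImageFeatures().eval().cuda()
inputs_embeds = torch.randn(1, 256, 2048, dtype=torch.float16, device="cuda")
image_features = torch.randn(1, 16, 2048, dtype=torch.float32, device="cuda")
special_image_mask = torch.zeros(1, 256, 2048, dtype=torch.bool, device="cuda")
special_image_mask[:, :16, :] = True  # pretend the first 16 positions are image tokens

compiled = torch.compile(
    mod,
    backend="tensorrt",
    options={"enabled_precisions": {torch.float16}, "min_block_size": 1, "debug": True},
)
with torch.inference_mode():
    compiled(image_features, inputs_embeds, special_image_mask)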
To Reproduce

Steps to reproduce the behavior:
import torch
import torch_tensorrt
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
from transformers.image_utils import load_image

DEVICE = "cuda:0"
model_id = "google/paligemma2-3b-pt-224"
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg"
image = load_image(url)

model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16).eval()
model.to(DEVICE).to(torch.float16)
# model.forward = model.forward.to(torch.float16).eval()
processor = PaliGemmaProcessor.from_pretrained(model_id)

prompt = ""
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.float16).to(DEVICE)  # to(DEVICE)  # .to(torch.float16).to(DEVICE)
input_len = model_inputs["input_ids"].shape[-1]
# model.config.token_healing = False

with torch.inference_mode():
    pyt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    pyt_generation_out = pyt_generation[0][input_len:]
    pyt_decoded = processor.decode(pyt_generation_out, skip_special_tokens=True)
    print("=============================")
    print("pyt_generation whole text:")
    print(pyt_generation)
    print("=============================")
    print("=============================")
    print("PyTorch generated text:")
    print(pyt_decoded)
    print("=============================")

with torch_tensorrt.logging.debug():
    torch._dynamo.mark_dynamic(model_inputs["input_ids"], 1, min=2, max=1023)
    model.forward = torch.compile(
        model.forward,
        backend="tensorrt",
        dynamic=None,
        options={
            "enabled_precisions": {torch.float16},
            "disable_tf32": True,
            "min_block_size": 1,
            # "use_explicit_typing": True,
            # "use_fp32_acc": True,
            "debug": True,
            # "use_aot_joint_export": False,
        },
    )

    with torch.inference_mode():
        trt_generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
        trt_generation_out = trt_generation[0][input_len:]
        trt_decoded = processor.decode(trt_generation_out, skip_special_tokens=True)
        print(trt_generation)
        print("TensorRT generated text:")
        print(trt_decoded)
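As a side note on the op itself, in case it helps whoever looks at converter support: masked_scatter fills the True positions of the mask, in row-major order, with consecutive leading elements of the source tensor. A tiny eager-mode illustration of that semantics (purely illustrative, unrelated to the model) is below.

import torch

x = torch.zeros(2, 4)
mask = torch.tensor([[True, False, True, False],
                     [False, True, False, True]])
source = torch.arange(1., 9.)  # may hold more elements than mask has True entries

ref = x.masked_scatter(mask, source)

# Equivalent eager formulation: boolean-mask assignment with the leading
# mask.sum() elements of the flattened source, in the same row-major order.
manual = x.clone()
manual[mask] = source.reshape(-1)[: int(mask.sum())]

assert torch.equal(ref, manual)
print(ref)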