❓ [Question] Excessive GPU memory usage during constant folding pass causes OOM (Torch-TensorRT + Qwen-Image) #4128

@Tongkaio

Description

❓ Question

When compiling and running the Qwen-Image model with Torch-TensorRT, the peak GPU memory usage exceeds 95GB, resulting in CUDA out-of-memory (OOM) errors.

For comparison, the Inductor backend only requires around 64GB peak VRAM for the same model. After debugging, the issue appears to originate primarily from the constant folding pass:

@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Adapted from:
    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

    Folds constants in the graph module, not skipping constructors.
    Modifies the graph in-place and replaces nodes with constants.
    """
    cf = _TorchTensorRTConstantFolder(gm, skip_constructors=False)
    cf.run()

    # The constants are created on CPU to save GPU memory for TensorRT compilation.
    # For TRT INetwork construction the constants are moved to CPU in the get_attr call.
    for node, constant in cf.node_replacements.items():
        if settings.offload_module_to_cpu:
            replace_node_with_constant(
                gm,
                node,
                torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
            )
        else:
            replace_node_with_constant(
                gm, node, torch.nn.Parameter(constant, requires_grad=False)
            )

    erased_params = []
    for node in gm.graph.nodes:
        # If a get_attr node has no users, mark it for deletion
        if node.op == "get_attr" and len(node.users) == 0:
            erased_params.append(node)

    # Remove unused nodes from the graph
    for node in erased_params:
        gm.graph.erase_node(node)

    gm = clean_up_graph_after_modifications(gm)

    # Delete the constant folder instance, which holds GPU memory
    del cf

    logger.debug(f"Graph after constant folding:\n{gm.graph}")
    return gm
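As a rough illustration of where the peak comes from: cf.run() materializes every folded constant on the device, and all of them stay resident in cf.node_replacements until the replacement loop (and the final del cf) completes. The toy model below (hypothetical function names, arbitrary per-constant sizes, no torch required) contrasts that with a scheme that frees each device copy as soon as it is offloaded to host:

```python
def peak_deferred(sizes):
    """Peak device bytes when every folded constant stays resident until
    the replacement loop finishes (all of them live at once)."""
    return sum(sizes)

def peak_eager(sizes):
    """Peak device bytes if each constant's device copy were freed right
    after being offloaded to host, before the next one is materialized."""
    live = peak = 0
    for s in sizes:
        live += s                  # constant materialized on device
        peak = max(peak, live)
        live -= s                  # device copy released after host offload
    return peak

weights = [40, 25, 30]             # arbitrary per-constant sizes, in GB
print(peak_deferred(weights))      # 95
print(peak_eager(weights))         # 40
```

Under this model the peak drops from the sum of all folded constants to the largest single one, which is why releasing each device copy inside the folding loop (rather than only after the whole pass) could reduce the spike, at least in principle.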

What you have already tried

Setting offload_module_to_cpu=True bypasses the problematic constant folding behavior and avoids the GPU OOM. However, this workaround shifts the cost to the host: system RAM consumption rises to an unacceptable level.

Environment

  • PyTorch Version: 2.9
  • CPU Architecture: x86
  • OS (e.g., Linux): Linux
  • Python version: 3.12
  • CUDA version: 12.9

Are there any better / recommended ways to reduce peak GPU memory (or host memory) usage during the lowering / compilation phase?

Is there any plan in future releases to optimize the constant folding pass for lower memory footprint?

Any suggestions or insights would be greatly appreciated!
