Description
❓ Question
When compiling and running the Qwen-Image model with Torch-TensorRT, peak GPU memory usage exceeds 95 GB, causing CUDA out-of-memory (OOM) errors.
For comparison, the Inductor backend peaks at around 64 GB of VRAM on the same model. After debugging, the issue appears to originate primarily from the constant folding pass:
TensorRT/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
Lines 22 to 65 in 08bfca2
```python
@torch.utils._python_dispatch._disable_current_modes()  # type: ignore
def constant_fold(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Adapted from:
    https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197

    Folds constants in the graph module, not skipping constructors

    Modifies the graph in-place and replaces node with constants
    """
    cf = _TorchTensorRTConstantFolder(gm, skip_constructors=False)
    cf.run()

    # The constants are created on CPU to save GPU memory for TensorRT compilation.
    # For TRT INetwork construction the constants are moved to CPU in get_attr call.
    for node, constant in cf.node_replacements.items():
        if settings.offload_module_to_cpu:
            replace_node_with_constant(
                gm,
                node,
                torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
            )
        else:
            replace_node_with_constant(
                gm, node, torch.nn.Parameter(constant, requires_grad=False)
            )

    erased_params = []
    for node in gm.graph.nodes:
        # If get_attr node has no users, mark it for deletion
        if node.op == "get_attr" and len(node.users) == 0:
            erased_params.append(node)

    # Remove unused nodes from the graph
    for node in erased_params:
        gm.graph.erase_node(node)

    gm = clean_up_graph_after_modifications(gm)

    # Delete the constant folder instance which holds GPU memory
    del cf

    logger.debug(f"Graph after constant folding:\n{gm.graph}")

    return gm
```
What you have already tried
Setting `offload_module_to_cpu=True` avoids the GPU OOM by copying the folded constants to CPU instead of keeping them resident on the GPU. However, this workaround increases host (system RAM) memory consumption to an unacceptable level.
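To make the tradeoff concrete, here is a toy simulation (plain Python, with made-up tensor sizes) of how peak device memory behaves in the two branches of the folding loop above. It only models memory as a byte counter and is not an exact account of the real allocator, but it shows why the no-offload branch roughly doubles GPU residency while the offload branch shifts the cost to host RAM:

```python
# Toy model of the two branches in constant_fold (hypothetical sizes).
# Device memory is a simple byte counter; the goal is only to illustrate
# where the peak comes from, not to reproduce real allocator behavior.

def fold_constants(weight_bytes, folded_sizes, offload_to_cpu):
    """Return (peak_device_bytes, host_bytes) for a simulated folding pass."""
    device = weight_bytes          # original weights already resident on GPU
    host = 0
    peak = device
    for size in folded_sizes:
        device += size             # folded constant is materialized on GPU
        peak = max(peak, device)
        if offload_to_cpu:
            host += size           # models constant.cpu().contiguous()
            device -= size         # GPU copy can be freed after the move
    return peak, host

weights = 40_000_000_000           # ~40 GB of model weights (made up)
folded = [5_000_000_000] * 10      # ten 5 GB folded constants (made up)

peak_keep, host_keep = fold_constants(weights, folded, offload_to_cpu=False)
peak_off, host_off = fold_constants(weights, folded, offload_to_cpu=True)

print(peak_keep)  # 90_000_000_000: weights plus all folded constants stay on GPU
print(peak_off)   # 45_000_000_000: at most one folded constant in flight on GPU
print(host_off)   # 50_000_000_000: the folded constants land in system RAM instead
```

This matches the symptoms described above: without offloading, the folded constants accumulate on the GPU alongside the original weights; with offloading, the GPU peak drops but host memory grows by the total size of the folded constants.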
Environment
- PyTorch Version: 2.9
- CPU Architecture: x86
- OS (e.g., Linux): Linux
- Python version: 3.12
- CUDA version: 12.9
Are there any better / recommended ways to reduce peak GPU memory (or host memory) usage during the lowering / compilation phase?
Is there any plan in future releases to optimize the constant folding pass for a lower memory footprint?
Any suggestions or insights would be greatly appreciated!