Thanks for your work, it is amazing! However, I hit a critical tensor-shape mismatch during multi-GPU inference when adjusting two key parameters. Could you help me?
✅ Working config: ulysses_degree = 1, ring_degree = 2 → inference runs successfully with torchrun --nproc_per_node=2 infer.py
❌ Failing config: ulysses_degree = 2, ring_degree = 2 → every GPU rank crashes with the same RuntimeError about an invalid tensor view shape, and the inference process terminates completely (torchrun --nproc_per_node=4 infer.py)
GPU: NVIDIA RTX 5090 (2/4 cards used for inference)
🚨 Full Error Log
[rank1]: Traceback (most recent call last):
[rank1]: File "/home/wsx/FlashPortrait/infer.py", line 372, in <module>
[rank1]: sample = pipeline(
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: File "/home/wsx/FlashPortrait/wan/pipeline/pipeline_wan_long.py", line 760, in __call__
[rank1]: noise_pred_posi = self.transformer(
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 858, in forward
[rank1]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/wsx/FlashPortrait/wan/utils/cfg_optimization.py", line 32, in wrapper
[rank1]: result = func(self, new_x, *new_args, **new_kwargs)
[rank1]: File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 1047, in forward
[rank1]: x = block(x, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 858, in forward
[rank1]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 514, in forward
[rank1]: x = cross_attn_ffn(x, context, context_lens, e, emo_proj=emo_proj, emo_context_lens=emo_context_lens, latents_num_frames=latents_num_frames, ip_scale=ip_scale, emo_attn_mask=emo_attn_mask)
[rank1]: File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 494, in cross_attn_ffn
[rank1]: x = x + self.cross_attn(self.norm3(x),
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]: return self._call_impl(*args, **kwargs)
[rank1]: File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]: return forward_call(*args, **kwargs)
[rank1]: File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 349, in forward
[rank1]: emo_q = q.view(b * latents_num_frames, -1, n, d)
[rank1]: RuntimeError: shape '[51, -1, 40, 128]' is invalid for input of size 76380160
Note: the exact same error occurs on rank0, rank2, and rank3 (all GPU processes).
The mathematical mismatch is clear:
Target shape: 51 × ? × 40 × 128 must cover 76380160 elements
Calculation: 51 × 40 × 128 = 261120, and 76380160 ÷ 261120 = 292.509... (non-integer)
This shows that the tensor-dimension logic in wan_transformer3d.py does not account for ulysses_degree = 2, which leads to an invalid reshape operation.
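The divisibility argument can be verified with a few lines of plain arithmetic (the numbers are taken directly from the error log above; this is not FlashPortrait code):

```python
# Shape values from the failing view() call in wan_transformer3d.py, line 349:
#   emo_q = q.view(b * latents_num_frames, -1, n, d)
b, latents_num_frames, n, d = 1, 51, 40, 128
numel = 76380160  # tensor size reported in the RuntimeError

# view() can only infer the -1 dimension if numel divides evenly
# by the product of the fixed dimensions.
per_frame = b * latents_num_frames * n * d
print(per_frame)           # 261120
print(numel % per_frame)   # 133120 -> not divisible, so view() must raise
print(numel / per_frame)   # 292.509..., matching the calculation above
```

With ulysses_degree = 2, the sequence axis of q is sharded across ranks, so each rank holds only a fraction of the tokens; the frame count used in the view would need to reflect that shard size for the reshape to be valid.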
The same error also occurs when I use 3 GPUs. My settings:
GPU_memory_mode = "model_full_load_and_qfloat8"
# Multi GPUs config
# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used.
# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4.
# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1.
ulysses_degree = 2 # when using 3 GPUs, I set this to 1
ring_degree = 2 # when using 3 GPUs, I set this to 3
# Use FSDP to save more GPU memory in multi gpus.
fsdp_dit = True
fsdp_text_encoder = True
# Compile will give a speedup in fixed resolution and need a little GPU memory.
# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
compile_dit = False
# Support TeaCache.
enable_teacache = True
# Recommended to be set between 0.05 and 0.30. A larger threshold can cache more steps, speeding up the inference process,
# but it may cause slight differences between the generated content and the original content.
# # --------------------------------------------------------------------------------------------------- #
# | Model Name | threshold | Model Name | threshold | Model Name | threshold |
# | Wan2.1-T2V-1.3B | 0.05~0.10 | Wan2.1-T2V-14B | 0.10~0.15 | Wan2.1-I2V-14B-720P | 0.20~0.30 |
# | Wan2.1-I2V-14B-480P | 0.20~0.25 | Wan2.1-Fun-*-1.3B-* | 0.05~0.10 | Wan2.1-Fun-*-14B-* | 0.20~0.30 |
# # --------------------------------------------------------------------------------------------------- #
teacache_threshold = 0.10
# The number of steps to skip TeaCache at the beginning of the inference process, which can
# reduce the impact of TeaCache on generated video quality.
num_skip_start_steps = 5
# Whether to offload TeaCache tensors to cpu to save a little bit of GPU memory.
teacache_offload = True
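For what it's worth, a pre-flight sanity check along these lines would catch a bad degree combination before launch. This is only a sketch; `validate_parallel_config` and the `seq_len` value are hypothetical, not part of the repo:

```python
def validate_parallel_config(ulysses_degree, ring_degree, world_size,
                             seq_len, num_heads=40):
    """Sanity-check sequence-parallel settings before starting inference."""
    # The product of the two degrees must equal the number of GPU processes.
    assert ulysses_degree * ring_degree == world_size, \
        "ulysses_degree * ring_degree must equal the number of GPUs"
    # Ulysses splits attention heads across ranks, so heads must divide evenly.
    assert num_heads % ulysses_degree == 0, \
        "attention heads must split evenly across ulysses ranks"
    # Each rank must receive a whole shard of the token sequence.
    assert seq_len % (ulysses_degree * ring_degree) == 0, \
        "sequence length must split evenly across all ranks"

# Working config from the report (seq_len is a made-up placeholder):
validate_parallel_config(1, 2, world_size=2, seq_len=28672)
```

A check like this would also flag the 3-GPU case immediately whenever ulysses_degree * ring_degree no longer matches nproc_per_node.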