RuntimeError: Invalid tensor shape when setting ulysses_degree=2 and ring_degree=2 with multi-GPU inference (RTX 5090) #6

@ChnWuyue

Thanks for your work, it is amazing. However, I hit a critical tensor shape mismatch error during multi-GPU inference when adjusting the two key parameters, ulysses_degree and ring_degree. Could you help me?
✅ Working config: ulysses_degree = 1, ring_degree = 2. Inference runs successfully with torchrun --nproc_per_node=2 infer.py
❌ Error config: ulysses_degree = 2, ring_degree = 2, launched with torchrun --nproc_per_node=4 infer.py. All GPU ranks crash with the same RuntimeError about an invalid tensor view shape, and the inference process terminates completely.

GPU: NVIDIA RTX 5090 (2 or 4 cards used for inference, depending on the config)

🚨 Full Error Log
[rank1]: Traceback (most recent call last):
[rank1]:   File "/home/wsx/FlashPortrait/infer.py", line 372, in <module>
[rank1]:     sample = pipeline(
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 124, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/home/wsx/FlashPortrait/wan/pipeline/pipeline_wan_long.py", line 760, in __call__
[rank1]:     noise_pred_posi = self.transformer(
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 858, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/wsx/FlashPortrait/wan/utils/cfg_optimization.py", line 32, in wrapper
[rank1]:     result = func(self, new_x, *new_args, **new_kwargs)
[rank1]:   File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 1047, in forward
[rank1]:     x = block(x, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 858, in forward
[rank1]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 514, in forward
[rank1]:     x = cross_attn_ffn(x, context, context_lens, e, emo_proj=emo_proj, emo_context_lens=emo_context_lens, latents_num_frames=latents_num_frames, ip_scale=ip_scale, emo_attn_mask=emo_attn_mask)
[rank1]:   File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 494, in cross_attn_ffn
[rank1]:     x = x + self.cross_attn(self.norm3(x),
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1783, in _wrapped_call_impl
[rank1]:     return self._call_impl(*args, **kwargs)
[rank1]:   File "/home/shared/miniconda3/envs/wsx-lyra2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1794, in _call_impl
[rank1]:     return forward_call(*args, **kwargs)
[rank1]:   File "/home/wsx/FlashPortrait/wan/models/wan_transformer3d.py", line 349, in forward
[rank1]:     emo_q = q.view(b * latents_num_frames, -1, n, d)
[rank1]: RuntimeError: shape '[51, -1, 40, 128]' is invalid for input of size 76380160
Note: The exact same error occurs on rank0, rank2, rank3 (all GPU processes).

The arithmetic mismatch is clear:
Target shape: 51 × ? × 40 × 128 must hold 76380160 elements
Calculation: 51 × 40 × 128 = 261120, and 76380160 ÷ 261120 ≈ 292.51 (not an integer)
This suggests that the tensor dimension logic in wan_transformer3d.py does not account for ulysses_degree=2, which leads to an invalid reshape operation.
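The divisibility check above can be reproduced without PyTorch. The following is a minimal sketch, not code from the repository: `check_view` is a hypothetical helper, and its `leading` argument stands for the value of `b * latents_num_frames` (51 in the error message).

```python
def check_view(numel, leading, heads, head_dim):
    """Check whether a tensor with `numel` elements can be viewed as
    (leading, -1, heads, head_dim).

    Returns (ok, tokens, leftover):
      tokens   = number of (heads x head_dim) token slices in the tensor
      leftover = tokens that do not fit evenly into `leading` groups
    """
    per_token = heads * head_dim
    if numel % per_token != 0:
        raise ValueError("numel is not a multiple of heads * head_dim")
    tokens = numel // per_token
    leftover = tokens % leading
    return leftover == 0, tokens, leftover


# Values taken from the RuntimeError in the log above.
ok, tokens, leftover = check_view(76380160, leading=51, heads=40, head_dim=128)
print(ok, tokens, leftover)  # False 14918 26
```

Under these numbers, the tensor on each rank holds 14918 tokens, which is not a multiple of 51, so the view cannot succeed regardless of the value chosen for the -1 dimension.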

In addition, inference does not work when I use 3 GPUs either.
My settings are:

GPU_memory_mode     = "model_full_load_and_qfloat8"
# Multi GPUs config
# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used.
# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4.
# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1.
ulysses_degree      = 2 # set to 1 when using 3 GPUs
ring_degree         = 2 # set to 3 when using 3 GPUs
# Use FSDP to save more GPU memory in multi gpus.
fsdp_dit            = True
fsdp_text_encoder   = True
# Compile will give a speedup in fixed resolution and need a little GPU memory.
# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
compile_dit         = False

# Support TeaCache.
enable_teacache     = True
# Recommended to be set between 0.05 and 0.30. A larger threshold can cache more steps, speeding up the inference process,
# but it may cause slight differences between the generated content and the original content.
# # --------------------------------------------------------------------------------------------------- #
# | Model Name          | threshold | Model Name          | threshold | Model Name          | threshold |
# | Wan2.1-T2V-1.3B     | 0.05~0.10 | Wan2.1-T2V-14B      | 0.10~0.15 | Wan2.1-I2V-14B-720P | 0.20~0.30 |
# | Wan2.1-I2V-14B-480P | 0.20~0.25 | Wan2.1-Fun-*-1.3B-* | 0.05~0.10 | Wan2.1-Fun-*-14B-*  | 0.20~0.30 |
# # --------------------------------------------------------------------------------------------------- #
teacache_threshold  = 0.10
# The number of steps to skip TeaCache at the beginning of the inference process, which can
# reduce the impact of TeaCache on generated video quality.
num_skip_start_steps = 5
# Whether to offload TeaCache tensors to cpu to save a little bit of GPU memory.
teacache_offload    = True
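
The constraint stated in the config comments (the product of ulysses_degree and ring_degree must equal the number of GPUs) can be checked before launching. This is a small sketch with hypothetical helper names that are not part of FlashPortrait:

```python
def valid_degree_pairs(num_gpus):
    """List every (ulysses_degree, ring_degree) pair whose product is num_gpus."""
    return [(u, num_gpus // u) for u in range(1, num_gpus + 1) if num_gpus % u == 0]


def check_degrees(ulysses_degree, ring_degree, num_gpus):
    """Raise if the two degrees do not multiply to the GPU count."""
    if ulysses_degree * ring_degree != num_gpus:
        raise ValueError(
            f"ulysses_degree * ring_degree = {ulysses_degree * ring_degree}, "
            f"expected {num_gpus}; valid pairs: {valid_degree_pairs(num_gpus)}"
        )


print(valid_degree_pairs(4))  # [(1, 4), (2, 2), (4, 1)]
print(valid_degree_pairs(3))  # [(1, 3), (3, 1)]
check_degrees(2, 2, 4)        # passes: 2 * 2 == 4
```

Both configurations reported here (2 × 2 on 4 GPUs and 1 × 3 on 3 GPUs) satisfy this product rule, so the rule alone does not explain the failures.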
