Description
I am trying to run the ActionVideo2WorldInference pipeline locally with the dreamdojo_2b_480_640_gr1 checkpoint and LAM_400k.ckpt. To save memory I bypassed the 14B text encoder entirely, loading only the DiT, LAM, and VAE; the static weights fit comfortably in ~4.28 GB of VRAM.
However, the generation loop hits a hard 24 GB VRAM ceiling in the Wan2.1 VAE's 3D convolutions, and every workaround I have tried runs into a conflicting geometric constraint.
Environment:
GPU: 1x 24GB VRAM (e.g., RTX 3090/4090)
Pipeline: ActionVideo2WorldInference (Text encoding bypassed)
Precision: bfloat16 with sdp_kernel (FlashAttention) enabled.
Steps Taken & Architectural Trap:
Attempt 1: Native Dimensions (12 frames @ 480x640)
Running the standard script raises a CUDA out-of-memory error the moment the 12-frame tensor reaches the 3D convolution (F.conv3d) inside the Wan2.1 VAE encoder (wan2pt1.py); the activation scratchpad spikes VRAM usage past 25 GB.
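For scale, a rough back-of-envelope for a single full-resolution feature map (the 128-channel width is my guess, not a confirmed Wan2.1 number; real peak usage is several such maps at once plus cuDNN conv workspace):

```python
# Size of one bf16 feature map at full spatio-temporal resolution,
# assuming (hypothetically) a 128-channel first conv stage.
def feature_bytes(channels, frames, h, w, bytes_per_el=2):
    return channels * frames * h * w * bytes_per_el

gb = feature_bytes(128, 12, 480, 640) / 1024**3
print(f"{gb:.2f} GiB per 128-channel feature map")  # ~0.88 GiB
```

A handful of such maps held live across the encoder's down-blocks, plus convolution workspace buffers, is enough to account for the 25 GB spike.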
Attempt 2: Reduce Temporal Length (<4 frames @ 480x640)
To save VRAM, I reduced chunk_size to 2 and 3 frames. This successfully bypassed the OOM, but crashed the VAE's hardcoded temporal downsampling math:
```
RuntimeError: Calculated padded input size per channel: (2 x 60 x 80). Kernel size: (3 x 1 x 1). Kernel size can't be greater than actual input size
```
Conclusion: the VAE mathematically requires at least 4-5 frames so that its temporal 3D convolutions (kernel size 3 along time) have enough input to scan.
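The constraint is easy to reproduce in isolation with a bare F.conv3d (the channel count here is arbitrary, not Wan2.1's):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 16, 2, 60, 80)   # only 2 frames along the temporal dim
w = torch.randn(16, 16, 3, 1, 1)    # temporal kernel size 3
err = None
try:
    F.conv3d(x, w)                  # 2 < 3 along time -> RuntimeError
except RuntimeError as e:
    err = str(e)
print(err)  # "... Kernel size can't be greater than actual input size"
```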
Attempt 3: Reduce Spatial Canvas (12 frames @ 256x384)
To satisfy the VAE's 5+ frame requirement while keeping the tensor small enough for 24GB, I resized the seed_frame to 256x384 (a safe DiT multiple of 16). The generation loop completed all 35 steps successfully!
However, because the pre-trained DiT weights use absolute 2D positional embeddings mapped strictly to the 480x640 grid (1,200 patches), the activations collapsed into NaN and the output was a pure black video (video_clamped.max() == 0).
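A bicubic resample of the learned grid is the kind of fix I was hoping for; a sketch (the function name, patch size 16, and embedding dim 1152 are my assumptions, not the repo's actual API):

```python
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed, old_hw, new_hw):
    """Bicubically resample absolute 2D positional embeddings.

    pos_embed: (1, old_h * old_w, dim) learned embeddings for the native grid.
    Returns:   (1, new_h * new_w, dim) embeddings for the smaller grid.
    """
    old_h, old_w = old_hw
    new_h, new_w = new_hw
    dim = pos_embed.shape[-1]
    grid = pos_embed.reshape(1, old_h, old_w, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_h, new_w),
                         mode="bicubic", align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)

# 480x640 at patch size 16 -> 30x40 = 1,200 patches; 256x384 -> 16x24 = 384
pe = torch.randn(1, 1200, 1152)
pe_small = interpolate_pos_embed(pe, (30, 40), (16, 24))
print(pe_small.shape)  # torch.Size([1, 384, 1152])
```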
Attempt 4: CPU Offloading (The Hack)
I wrapped the pass in torch.inference_mode() and intercepted encode() and decode() in text2world_model_rectified_flow.py to forcefully move self.tokenizer.model.encoder to the CPU right before the VAE pass. This avoided the GPU OOM, but the 12-frame 3D convolution on CPU caused massive memory/swap thrashing and an effectively infinite hang.
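For reference, the offload wrapper was essentially this sketch (encode_on_cpu is my own helper, not a repo API; the toy Conv3d stands in for the real encoder):

```python
import torch
import torch.nn as nn

@torch.inference_mode()
def encode_on_cpu(encoder, frames):
    """Hypothetical offload wrapper: run the VAE's 3D-conv encoder on CPU.

    Avoids the GPU OOM, but fp32 conv3d on CPU over 12 frames at 480x640
    is what produced the swap thrashing described above.
    """
    device = frames.device
    encoder = encoder.to("cpu").float()
    latents = encoder(frames.detach().float().cpu())
    return latents.to(device)

# toy stand-in for the real encoder
enc = nn.Conv3d(3, 4, kernel_size=3, padding=1)
z = encode_on_cpu(enc, torch.randn(1, 3, 4, 32, 32))
print(z.shape)  # torch.Size([1, 4, 4, 32, 32])
```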
Questions for the Maintainers:
It appears that generating 5+ frames at 480x640 simply requires more than 24 GB of VRAM for the intermediate attention/convolution activations alone.
VRAM Slicing: Are there any recommended deep-level PyTorch tricks (like DeepSpeed ZeRO-Offload, VAE spatial-tiling, or forced sequential encoding) to slice the Wan2.1 VAE activations so the 480x640 tensor fits in 24GB?
Positional Interpolation: Is there a supported parameter to dynamically interpolate the DiT's positional embeddings at runtime so it can accept a 256x384 input without outputting pure black pixels?
Future Checkpoints: Are there any plans to release a native 256x256 Cosmos/DreamDojo checkpoint for local consumer research?
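To make the spatial-tiling question concrete, this is the kind of approach I mean — a hypothetical sketch where a toy encode_fn stands in for the real Wan2.1 encoder (which also downsamples temporally, and would need overlapping tiles with blending to hide seams):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def encode_tiled(encode_fn, frames, tile=160):
    """Hypothetical spatial tiling for a video VAE encoder.

    frames: (B, C, T, H, W); encode_fn maps a tile to its latent at 1/8
    spatial resolution. Tiles are encoded one at a time so peak activation
    memory scales with the tile, not the full 480x640 canvas.
    """
    B, C, T, H, W = frames.shape
    rows = []
    for y in range(0, H, tile):
        cols = [encode_fn(frames[..., y:y + tile, x:x + tile])
                for x in range(0, W, tile)]
        rows.append(torch.cat(cols, dim=-1))   # stitch along width
    return torch.cat(rows, dim=-2)             # stitch along height

# toy stand-in for the encoder: 8x spatial average pooling
fake_encode = lambda x: F.avg_pool3d(x, kernel_size=(1, 8, 8))
z = encode_tiled(fake_encode, torch.randn(1, 3, 12, 480, 640))
print(z.shape)  # torch.Size([1, 3, 12, 60, 80])
```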
Thank you for open-sourcing this incredible framework! Any guidance on fitting this into a single 24GB node would be greatly appreciated.