
Bypassing the 24GB VRAM Wall for ActionVideo2WorldInference (Cosmos 2B + GR1) #6

@prateek4robotics

Description


I am trying to run the ActionVideo2WorldInference pipeline locally using the dreamdojo_2b_480_640_gr1 checkpoint and LAM_400k.ckpt. I specifically bypassed the 14B Text Encoder to save memory, loading only the DiT, LAM, and VAE. The static weights fit perfectly into ~4.28 GB of VRAM.

However, when running the generation loop, I am hitting a hard 24GB VRAM bottleneck caused by the Wan2.1 VAE 3D Convolutions, and I am trapped in a geometric constraint loop trying to bypass it.

Environment:

GPU: 1x 24GB VRAM (e.g., RTX 3090/4090)

Pipeline: ActionVideo2WorldInference (Text encoding bypassed)

Precision: bfloat16 with sdp_kernel (FlashAttention) enabled.

Steps Taken & Architectural Trap:

Attempt 1: Native Dimensions (12 frames @ 480x640)
Running the standard script causes an immediate CUDA out of memory error exactly when the 12-frame tensor hits the 3D Convolution (F.conv3d) inside the Wan2.1 VAE encoder (wan2pt1.py). The activation scratchpad spikes VRAM usage past 25GB+.

Attempt 2: Reduce Temporal Length (<4 frames @ 480x640)
To save VRAM, I reduced chunk_size to 2 and 3 frames. This successfully bypassed the OOM, but crashed the VAE's hardcoded temporal downsampling math:

```
RuntimeError: Calculated padded input size per channel: (2 x 60 x 80). Kernel size: (3 x 1 x 1). Kernel size can't be greater than actual input size
```
Conclusion: the VAE's temporal 3D convolutions mathematically require at least 4-5 input frames before the kernel can scan the time axis.
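The RuntimeError follows directly from the standard conv output-size formula: the (possibly padded) input length along an axis must be at least the kernel size. A minimal sketch, assuming stride 1 and no temporal padding at the failing layer (consistent with the "(2 x 60 x 80)" vs "(3 x 1 x 1)" message):

```python
# A conv over an axis of length T with kernel k, padding p, stride s
# requires T + 2p >= k; the output length is then (T + 2p - k) // s + 1.
# PyTorch raises the "Kernel size can't be greater than actual input size"
# error when that precondition fails.

def conv_out_len(T, k=3, p=0, s=1):
    padded = T + 2 * p
    if padded < k:
        raise ValueError(f"padded input {padded} < kernel {k}")
    return (padded - k) // s + 1

for T in (2, 3, 5, 12):
    try:
        print(f"{T} frames -> {conv_out_len(T)} output steps")
    except ValueError as e:
        print(f"{T} frames -> {e}")
```

Note that interior layers see a temporal axis already shrunk by earlier downsampling stages, so an input that looks long enough at the surface can still fall below the kernel size deep in the stack, which is presumably why 2-3 input frames fail.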

Attempt 3: Reduce Spatial Canvas (12 frames @ 256x384)
To satisfy the VAE's 5+ frame requirement while keeping the tensor small enough for 24GB, I resized the seed_frame to 256x384 (a safe DiT multiple of 16). The generation loop completed all 35 steps successfully!
However, because the pre-trained DiT uses absolute 2D positional embeddings mapped strictly to the 480x640 grid (1,200 patches), the denoising math collapsed into NaN and the output was a pure black video (video_clamped.max() == 0).
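The usual remedy for this mismatch is to bilinearly interpolate the pre-trained positional-embedding grid down to the new patch grid before the forward pass. A pure-Python sketch of the interpolation itself (in practice this would be torch.nn.functional.interpolate on the (D, H, W) tensor; whether the DiT tolerates resized embeddings is a separate question for the maintainers):

```python
# Resize an H x W grid of positional-embedding vectors to new_h x new_w
# by separable bilinear interpolation (align_corners=True convention).

def interp_pos_emb(grid, new_h, new_w):
    """grid: list[H][W] of embedding vectors (lists of floats)."""
    H, W, D = len(grid), len(grid[0]), len(grid[0][0])
    out = []
    for i in range(new_h):
        y = i * (H - 1) / (new_h - 1) if new_h > 1 else 0.0
        y0, ty = int(y), y - int(y)
        y1 = min(y0 + 1, H - 1)
        row = []
        for j in range(new_w):
            x = j * (W - 1) / (new_w - 1) if new_w > 1 else 0.0
            x0, tx = int(x), x - int(x)
            x1 = min(x0 + 1, W - 1)
            row.append([
                grid[y0][x0][d] * (1 - ty) * (1 - tx)
                + grid[y0][x1][d] * (1 - ty) * tx
                + grid[y1][x0][d] * ty * (1 - tx)
                + grid[y1][x1][d] * ty * tx
                for d in range(D)
            ])
        out.append(row)
    return out

# 2x2 grid of 1-D embeddings resized to 3x3: the center becomes the
# mean of the four corners.
g = [[[0.0], [1.0]], [[2.0], [3.0]]]
print(interp_pos_emb(g, 3, 3)[1][1])  # [1.5]
```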

Attempt 4: CPU Offloading (The Hack)
I used torch.inference_mode() and intercepted encode() and decode() in text2world_model_rectified_flow.py to forcefully move self.tokenizer.model.encoder to the CPU right before the VAE pass. While this prevented the GPU OOM, the 12-frame 3D convolution resulted in massive CPU/Swap thrashing, causing an infinite hang.

Questions for the Maintainers:

It appears that generating 5+ frames at 480x640 physically requires more than 24GB of VRAM just for the intermediate attention/convolution activations.

VRAM Slicing: Are there any recommended deep-level PyTorch tricks (like DeepSpeed ZeRO-Offload, VAE spatial-tiling, or forced sequential encoding) to slice the Wan2.1 VAE activations so the 480x640 tensor fits in 24GB?
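To make the spatial-tiling suggestion concrete: the idea (same as diffusers' `enable_vae_tiling`) is to encode overlapping crops independently and blend the seams, so peak activation memory scales with the tile rather than the full 480x640 frame. A sketch of the tile layout only, with illustrative tile/overlap sizes:

```python
# Compute overlapping (start, stop) windows that cover one spatial axis.
# Each window would be encoded separately; the `overlap` region is later
# cross-faded to hide seams in the latents.

def tile_windows(size, tile, overlap):
    stride = tile - overlap
    starts = list(range(0, max(size - tile, 0) + 1, stride))
    if starts[-1] + tile < size:      # ensure the trailing edge is covered
        starts.append(size - tile)
    return [(s, min(s + tile, size)) for s in starts]

rows = tile_windows(480, 256, 32)
cols = tile_windows(640, 256, 32)
print(rows)  # [(0, 256), (224, 480)]
print(cols)  # [(0, 256), (224, 480), (384, 640)]
```

With 256-pixel tiles, the largest tensor the VAE ever sees is 256x256 rather than 480x640, at the cost of encoding 6 tiles per frame plus the seam blending.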

Positional Interpolation: Is there a supported parameter to dynamically interpolate the DiT's positional embeddings at runtime so it can accept a 256x384 input without outputting pure black pixels?

Future Checkpoints: Are there any plans to release a native 256x256 Cosmos/DreamDojo checkpoint for local consumer research?

Thank you for open-sourcing this incredible framework! Any guidance on fitting this into a single 24GB node would be greatly appreciated.
