
Commit 617c208

[Docs] Update Wan Docs with memory optimizations (#11089)

* update
* update

1 parent 5d970a4 commit 617c208

1 file changed: +354 -11 lines changed

docs/source/en/api/pipelines/wan.md

@@ -22,18 +22,357 @@

<!-- TODO(aryan): update abstract once paper is out -->

## Generating Videos with Wan 2.1

We will first need to install some additional dependencies.

```shell
pip install -U ftfy imageio-ffmpeg imageio
```

### Text to Video Generation

The following example requires 11GB of VRAM to run and uses the smaller `Wan-AI/Wan2.1-T2V-1.3B-Diffusers` model. You can switch it out for the larger `Wan-AI/Wan2.1-T2V-14B-Diffusers` model if you have at least 35GB of VRAM available.

```python
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

# Available models: Wan-AI/Wan2.1-T2V-1.3B-Diffusers, Wan-AI/Wan2.1-T2V-14B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

pipe = WanPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```

<Tip>
You can improve the quality of the generated video by running the decoding step in full precision.
</Tip>

```python
import torch
from diffusers import WanPipeline, AutoencoderKLWan
from diffusers.utils import export_to_video

model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# keep the VAE in float32 so the decoding step runs in full precision
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()

prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

frames = pipe(prompt=prompt, negative_prompt=negative_prompt, num_frames=num_frames).frames[0]
export_to_video(frames, "wan-t2v.mp4", fps=16)
```

### Image to Video Generation

The Image to Video pipeline requires loading the `AutoencoderKLWan` and the `CLIPVisionModel` components in full precision. The following example will need at least 35GB of VRAM to run.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
)

# replace this with pipe.to("cuda") if you have sufficient VRAM
pipe.enable_model_cpu_offload()

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

max_area = 480 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```

## Memory Optimizations for Wan 2.1

Base inference with the large 14B Wan 2.1 models can take up to 35GB of VRAM when generating videos at 720p resolution. We'll outline a few memory optimizations we can apply to reduce the VRAM required to run the model.

We'll use the `Wan-AI/Wan2.1-I2V-14B-720P-Diffusers` model in these examples to demonstrate the memory savings, but the techniques are applicable to all model checkpoints.

### Group Offloading the Transformer and UMT5 Text Encoder

Find more information about group offloading [here](../optimization/memory.md).

#### Block Level Group Offloading

We can reduce our VRAM requirements by applying group offloading to the larger model components of the pipeline: the `WanTransformer3DModel` and `UMT5EncoderModel`. Group offloading breaks up the individual modules of a model and offloads/onloads them to the GPU as needed during inference. In this example, we'll apply `block_level` offloading, which groups the modules of a model into blocks of size `num_blocks_per_group` and offloads/onloads them to the GPU. Moving weights between the CPU and GPU adds latency to the inference process. You can trade off between latency and memory savings by increasing or decreasing `num_blocks_per_group`.

The following example requires only 14GB of VRAM to run, but takes approximately 30 minutes to generate a video.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)

text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)

transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16
)
# Since we've already offloaded the larger models, we can move the rest of the model components to the GPU
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "wan-i2v.mp4", fps=16)
```

#### Block Level Group Offloading with CUDA Streams

We can speed up group offloading inference by enabling the use of [CUDA streams](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html). However, using CUDA streams requires moving the model parameters into pinned memory. This allocation is handled by PyTorch under the hood and can result in a significant spike in CPU RAM usage. Please consider this option only if your CPU RAM is at least 2x the size of the model you are group offloading.

In the following example, we will use CUDA streams when group offloading the `WanTransformer3DModel`. When tested on an A100, this example requires 14GB of VRAM and 52GB of CPU RAM, but generates a video in approximately 9 minutes.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel

# Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)

text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)

onload_device = torch.device("cuda")
offload_device = torch.device("cpu")

apply_group_offloading(
    text_encoder,
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="block_level",
    num_blocks_per_group=4,
)

# offload the transformer at the leaf level and overlap transfers with compute via CUDA streams
transformer.enable_group_offload(
    onload_device=onload_device,
    offload_device=offload_device,
    offload_type="leaf_level",
    use_stream=True,
)
pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16
)
# Since we've already offloaded the larger models, we can move the rest of the model components to the GPU
pipe.to("cuda")

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"

num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]

export_to_video(output, "wan-i2v.mp4", fps=16)
```

### Applying Layerwise Casting to the Transformer

Find more information about layerwise casting [here](../optimization/memory.md).

In this example, we will use model offloading along with layerwise casting. Layerwise casting will downcast each layer's weights to `torch.float8_e4m3fn`, temporarily upcast to `torch.bfloat16` during the forward pass of the layer, then revert to `torch.float8_e4m3fn` afterward. This approach reduces memory requirements by approximately 50% while introducing a minor quality reduction in the generated video due to the precision trade-off.

This example will require 20GB of VRAM.

```python
import torch
import numpy as np
from diffusers import AutoencoderKLWan, WanTransformer3DModel, WanImageToVideoPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video, load_image
from transformers import UMT5EncoderModel, CLIPVisionModel

model_id = "Wan-AI/Wan2.1-I2V-14B-720P-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(
    model_id, subfolder="image_encoder", torch_dtype=torch.float32
)
text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)

transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
# store weights in float8_e4m3fn and upcast to bfloat16 only during each layer's forward pass
transformer.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)

pipe = WanImageToVideoPipeline.from_pretrained(
    model_id,
    vae=vae,
    transformer=transformer,
    text_encoder=text_encoder,
    image_encoder=image_encoder,
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg")

max_area = 720 * 832
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
image = image.resize((width, height))

prompt = (
    "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
    "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
)
negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
num_frames = 33

output = pipe(
    image=image,
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=height,
    width=width,
    num_frames=num_frames,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]
export_to_video(output, "wan-i2v.mp4", fps=16)
```

### Using a Custom Scheduler

Wan can be used with many different schedulers, each with its own benefits regarding speed and generation quality. By default, Wan uses the `UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)` scheduler. You can use a different scheduler as follows:

@@ -49,11 +388,10 @@ pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", scheduler
pipe.scheduler = <CUSTOM_SCHEDULER_HERE>
```
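
For instance, here is a minimal sketch of what the `<CUSTOM_SCHEDULER_HERE>` placeholder could be replaced with, assuming you want the flow-matching Euler scheduler referenced in the recommendations at the end of this page; the `shift=5.0` value is only illustrative:

```python
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
# build the new scheduler from the pipeline's existing scheduler config, overriding only the shift value
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=5.0)
```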

## Using Single File Loading with Wan 2.1

The `WanTransformer3DModel` and `AutoencoderKLWan` models support loading checkpoints in their original format via the `from_single_file` loading method.

```python
import torch
@@ -65,6 +403,11 @@ transformer = WanTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torc
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", transformer=transformer)
```
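
The VAE can be loaded the same way; a hedged sketch, where the checkpoint path is a placeholder for whichever original-format Wan VAE file you want to load:

```python
import torch
from diffusers import AutoencoderKLWan

# the path below is illustrative; point it at a real original-format Wan VAE checkpoint
vae = AutoencoderKLWan.from_single_file("<path/to/wan_vae.safetensors>", torch_dtype=torch.float32)
```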

## Recommendations for Inference
- Keep `AutoencoderKLWan` in `torch.float32` for better decoding quality.
- `num_frames` should satisfy the following constraint: `(num_frames - 1) % 4 == 0`, for example `33`, `49`, or `81`.
- For smaller resolution videos, try lower values of `shift` (between `2.0` to `5.0`) in the [Scheduler](https://huggingface.co/docs/diffusers/main/en/api/schedulers/flow_match_euler_discrete#diffusers.FlowMatchEulerDiscreteScheduler.shift). For larger resolution videos, try higher values (between `7.0` and `12.0`). The default value is `3.0` for Wan (see the sketch after this list for how to override it).

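A minimal sketch of overriding the flow shift on Wan's default `UniPCMultistepScheduler`, described in the custom scheduler section above; `flow_shift=7.0` is only an illustrative value for a higher-resolution run:

```python
import torch
from diffusers import UniPCMultistepScheduler, WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
# keep the rest of the scheduler config, overriding only flow_shift
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=7.0)
```
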
## WanPipeline

[[autodoc]] WanPipeline
