
Commit 9f963e7

[Community Pipelines] Accelerate inference of AnimateDiff by IPEX on CPU (#8643)
* add animatediff_ipex community pipeline
* address the 1st round review comments

2 files changed: +1114, −0

examples/community/README.md

Lines changed: 112 additions & 0 deletions
@@ -70,6 +70,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |

To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.

@@ -4099,6 +4100,117 @@ output_frames[0].save(output_video_path, save_all=True,
                      append_images=output_frames[1:], duration=100, loop=0)
```

### AnimateDiff on IPEX

This diffusion pipeline accelerates the inference of AnimateDiff on Intel Xeon CPUs with BF16/FP32 precision using [IPEX](https://github.com/intel/intel-extension-for-pytorch).

To use this pipeline, you need to:
1. Install [IPEX](https://github.com/intel/intel-extension-for-pytorch)

**Note:** Each PyTorch release has a corresponding IPEX release, mapped below. It is recommended to install PyTorch/IPEX 2.3 to get the best performance.

|PyTorch Version|IPEX Version|
|--|--|
|[v2.3.\*](https://github.com/pytorch/pytorch/tree/v2.3.0 "v2.3.0")|[v2.3.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v2.3.0+cpu)|
|[v1.13.\*](https://github.com/pytorch/pytorch/tree/v1.13.0 "v1.13.0")|[v1.13.\*](https://github.com/intel/intel-extension-for-pytorch/tree/v1.13.100+cpu)|

You can simply use pip to install the latest version of IPEX:
```sh
python -m pip install intel_extension_for_pytorch
```
**Note:** To install a specific version, run:
```sh
python -m pip install intel_extension_for_pytorch==<version_name> -f https://developer.intel.com/ipex-whl-stable-cpu
```
A quick way to verify the installed PyTorch/IPEX pairing is shown right after step 2 below.

2. After pipeline initialization, call `prepare_for_ipex()` to enable IPEX acceleration. The supported inference datatypes are Float32 and BFloat16.

```python
pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
# For Float32
pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
# For BFloat16
pipe.prepare_for_ipex(torch.bfloat16, prompt="A girl smiling")
```
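
With both steps done, you can confirm that the installed PyTorch and IPEX versions actually match the table above (a minimal sanity check, assuming both packages import cleanly):

```python
import torch
import intel_extension_for_pytorch as ipex

# The major.minor versions should match according to the mapping table above.
print("PyTorch:", torch.__version__)
print("IPEX:", ipex.__version__)
```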

You can then use the IPEX pipeline in the same way as the default AnimateDiff pipeline:

```python
# For Float32
output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
# For BFloat16
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    output = pipe(prompt="A girl smiling", guidance_scale=1.0, num_inference_steps=step)
```
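
Assuming the IPEX pipeline returns the same output object as the stock `AnimateDiffPipeline` (where `output.frames[0]` holds the generated PIL frames), you can save the result with the usual diffusers utility:

```python
from diffusers.utils import export_to_gif

# Save the generated frames as an animated GIF.
frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```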

The following code compares the performance of the original AnimateDiff pipeline against the IPEX-optimized pipeline. With BFloat16 on fifth-generation Intel Xeon CPUs (code-named Emerald Rapids), the optimized pipeline delivers roughly a 1.5-2.2x speedup.

```python
import time

import torch
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
from huggingface_hub import hf_hub_download  # fetches the motion-adapter checkpoint
from safetensors.torch import load_file

from pipeline_animatediff_ipex import AnimateDiffPipelineIpex

device = "cpu"
dtype = torch.float32

prompt = "A girl smiling"
step = 8  # Options: [1, 2, 4, 8]
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
base = "emilianJR/epiCRealism"  # Choose your favorite base model.

adapter = MotionAdapter().to(device, dtype)
adapter.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))

# Helper function for time evaluation
def elapsed_time(pipeline, nb_pass=3, num_inference_steps=1):
    # warmup runs
    for _ in range(2):
        output = pipeline(prompt=prompt, guidance_scale=1.0, num_inference_steps=num_inference_steps)
    # timed runs
    start = time.time()
    for _ in range(nb_pass):
        pipeline(prompt=prompt, guidance_scale=1.0, num_inference_steps=num_inference_steps)
    end = time.time()
    return (end - start) / nb_pass

############## bf16 inference performance ###############

# 1. IPEX Pipeline initialization
pipe = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
pipe.prepare_for_ipex(torch.bfloat16, prompt=prompt)

# 2. Original Pipeline initialization
pipe2 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe2.scheduler = EulerDiscreteScheduler.from_config(pipe2.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")

# 3. Compare performance between Original Pipeline and IPEX Pipeline
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    latency = elapsed_time(pipe, num_inference_steps=step)
    print("Latency of AnimateDiffPipelineIpex--bf16", latency, "s for total", step, "steps")
    latency = elapsed_time(pipe2, num_inference_steps=step)
    print("Latency of AnimateDiffPipeline--bf16", latency, "s for total", step, "steps")

############## fp32 inference performance ###############

# 1. IPEX Pipeline initialization
pipe3 = AnimateDiffPipelineIpex.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe3.scheduler = EulerDiscreteScheduler.from_config(pipe3.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
pipe3.prepare_for_ipex(torch.float32, prompt=prompt)

# 2. Original Pipeline initialization
pipe4 = AnimateDiffPipeline.from_pretrained(base, motion_adapter=adapter, torch_dtype=dtype).to(device)
pipe4.scheduler = EulerDiscreteScheduler.from_config(pipe4.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")

# 3. Compare performance between Original Pipeline and IPEX Pipeline
latency = elapsed_time(pipe3, num_inference_steps=step)
print("Latency of AnimateDiffPipelineIpex--fp32", latency, "s for total", step, "steps")
latency = elapsed_time(pipe4, num_inference_steps=step)
print("Latency of AnimateDiffPipeline--fp32", latency, "s for total", step, "steps")
```
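
The benchmark above imports `pipeline_animatediff_ipex` from a local copy of the file. Since the file lives in `diffusers/examples/community`, it should also be loadable through the `custom_pipeline` mechanism described at the top of this README; the following is a sketch under that assumption, reusing the model IDs from the benchmark:

```python
import torch
from diffusers import DiffusionPipeline, MotionAdapter

# Sketch: load the community pipeline by file name via `custom_pipeline`,
# then prepare it for IPEX as shown earlier.
adapter = MotionAdapter().to("cpu", torch.float32)
pipe = DiffusionPipeline.from_pretrained(
    "emilianJR/epiCRealism",
    motion_adapter=adapter,
    custom_pipeline="pipeline_animatediff_ipex",
    torch_dtype=torch.float32,
).to("cpu")
pipe.prepare_for_ipex(torch.float32, prompt="A girl smiling")
```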
# Perturbed-Attention Guidance

[Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
