
Commit f5f212f

adds wan2.1 training readme guide. (#272)
* adds wan2.1 training readme guide.
* update xpk command.
* update xpk parallelism.
* resolve sanbao's comments.
1 parent 9ffde26 commit f5f212f

File tree

2 files changed: +266 -3 lines changed

README.md

Lines changed: 264 additions & 1 deletion
@@ -100,7 +100,270 @@ After installation completes, run the training script.
## Wan 2.1 Training

In the first part, we'll run on a single-host VM to get familiar with the workflow, then run on XPK for large-scale training.

Although not required, attaching an external disk is recommended because the model weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).

This workflow was tested on a v5p-8 with a 500GB disk attached.

### Dataset Preparation

For this example, we'll be using the [PusaV1 dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV1_training).

First, download the dataset.

```bash
export HF_DATASET_DIR=/mnt/disks/external_disk/PusaV1_training/
export TFRECORDS_DATASET_DIR=/mnt/disks/external_disk/wan_tfr_dataset_pusa_v1
huggingface-cli download RaphaelLiu/PusaV1_training --repo-type dataset --local-dir $HF_DATASET_DIR
```

Next, run the TFRecords conversion script. This step prepares the training and eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206). More details are available [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality).

Training dataset:

```bash
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/train no_records_per_shard=10 enable_eval_timesteps=False
```

The script produces no console output, but you can check its progress with:

```bash
ls -ll $TFRECORDS_DATASET_DIR/train
```

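For a quicker sense of progress, you can also count the shard files and check their total size (this only inspects filenames and sizes, not the TFRecord contents). Since the conversion command uses `no_records_per_shard=10`, each completed shard should hold 10 records, so the shard count times 10 approximates the number of samples converted so far.

```bash
# Number of shards written so far and their total size.
ls $TFRECORDS_DATASET_DIR/train/*.tfrec 2>/dev/null | wc -l
du -sh $TFRECORDS_DATASET_DIR/train
```
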
Evaluation dataset:

```bash
python src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py src/maxdiffusion/configs/base_wan_14b.yml train_data_dir=$HF_DATASET_DIR tfrecords_dir=$TFRECORDS_DATASET_DIR/eval no_records_per_shard=10 enable_eval_timesteps=True
```

The evaluation dataset creation takes the first 420 samples of the dataset and adds a timestep field. We then need to manually delete the first 420 samples from the `train` folder so they are not used in training. The `awk` filter below splits each filename on `-` and `.`, so for the paths used in this guide `$2` is a shard's ending record index; every shard ending at or before record 420 is removed.

```bash
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' rm
```

Then verify that they no longer exist (the command should list no files).

```bash
printf "%s\n" $TFRECORDS_DATASET_DIR/train/file_*-*.tfrec | awk -F '[-.]' '$2+0 <= 420' | xargs -d '\n' echo
```

After the script is done running, you should see the following directory structure inside `$TFRECORDS_DATASET_DIR`:

```
train
eval_timesteps
```

In some instances an empty file `file_42-430.tfrec` is created inside `eval_timesteps`. As a sanity check, delete it:

```bash
rm $TFRECORDS_DATASET_DIR/eval_timesteps/file_42-430.tfrec
```

### Training on a Single VM

Loading the data is supported both locally, from the disk created above, and from GCS. In this guide, we'll use a GCS bucket to train. First, copy the data to the bucket.

```bash
BUCKET_NAME=my-bucket
gsutil -m cp -r $TFRECORDS_DATASET_DIR gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}
```

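To confirm the copy, you can list the bucket paths that the training variables below will point at (this is just a spot check; adjust if you used different paths):

```bash
# Spot-check that both the train and eval splits landed in the bucket.
gsutil ls gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/ | head
gsutil ls gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/ | head
```
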
Now set up the run variables:

```bash
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
```

Export the XLA/libtpu flags (these are performance-tuning knobs for asynchronous collectives and the latency-hiding scheduler):

```bash
export LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
```

Then run the training command:

```bash
HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
  src/maxdiffusion/configs/base_wan_14b.yml \
  attention='flash' \
  weights_dtype=bfloat16 \
  activations_dtype=bfloat16 \
  guidance_scale=5.0 \
  flow_shift=5.0 \
  fps=16 \
  skip_jax_distributed_system=False \
  run_name=${RUN_NAME} \
  output_dir=${OUTPUT_DIR} \
  train_data_dir=${DATASET_DIR} \
  load_tfrecord_cached=True \
  height=1280 \
  width=720 \
  num_frames=81 \
  num_inference_steps=50 \
  jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
  max_train_steps=1000 \
  enable_profiler=True \
  dataset_save_location=${SAVE_DATASET_DIR} \
  remat_policy='FULL' \
  flash_min_seq_length=0 \
  seed=$RANDOM \
  skip_first_n_steps_for_profiler=3 \
  profiler_steps=3 \
  per_device_batch_size=0.25 \
  ici_data_parallelism=1 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1
```

It is important to note a couple of things:

- `per_device_batch_size` can be fractional, but it must multiply with the number of devices to a whole number. In this example, 0.25 * 4 devices = an effective global batch size of 1.
- The step time on a v5p-8 with global batch size 1 is large because of the `FULL` remat policy. On a larger number of chips we can run larger batch sizes, greatly increasing MFU, as we will see in the next section on deploying with XPK.
- To enable eval during training, set `eval_every` to a value > 0.
- In Wan2.1, the `ici_fsdp_parallelism` axis is used for sequence parallelism and the `ici_tensor_parallelism` axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 attention heads and 40 must be evenly divisible by `ici_tensor_parallelism`.
- For sequence parallelism, the code pads the sequence length so it divides evenly. Try out different `ici_fsdp_parallelism` values; we currently find 2 and 4 to work best. A worked example of the batch-size and head-count arithmetic follows this list.

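To make the arithmetic concrete: a v5p-8 has 4 chips, so `per_device_batch_size=0.25` gives 0.25 * 4 = 1 as the global batch size, which matches the "Total train batch size" printed in the log below. The v5p-256 used in the XPK section has 128 chips (`ici_data_parallelism=32` * `ici_fsdp_parallelism=4`), so the same `per_device_batch_size=0.25` yields 0.25 * 128 = 32. For head parallelism, `ici_tensor_parallelism` must divide the 40 attention heads evenly, so values like 2, 4, 5, or 8 are valid while 3 is not.
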
You should eventually see training output like the following:

```bash
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
completed step: 0, seconds: 142.395, TFLOP/s/device: 34.364, loss: 0.270
To see full metrics 'tensorboard --logdir=gs://jfacevedo-maxdiffusion-v5p/wan/jfacevedo-wan-v5p-8-17263/tensorboard/'
completed step: 1, seconds: 137.207, TFLOP/s/device: 35.664, loss: 0.144
completed step: 2, seconds: 36.014, TFLOP/s/device: 135.871, loss: 0.210
completed step: 3, seconds: 36.016, TFLOP/s/device: 135.864, loss: 0.120
completed step: 4, seconds: 36.008, TFLOP/s/device: 135.894, loss: 0.107
completed step: 5, seconds: 36.008, TFLOP/s/device: 135.895, loss: 0.346
completed step: 6, seconds: 36.006, TFLOP/s/device: 135.900, loss: 0.169
```

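As a rough utilization check (assuming the commonly cited ~459 bf16 TFLOP/s peak per v5p chip, which is our assumption rather than something stated in this guide), the steady-state ~135.9 TFLOP/s/device above corresponds to roughly 135.9 / 459 ≈ 30% MFU. The first couple of steps are much slower, largely due to one-time compilation and warmup.
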
### Deploying with XPK

This assumes you have already created an XPK cluster, installed all dependencies, and created the dataset from the steps above. For getting started with MaxDiffusion and XPK, see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).

Using a v5p-256, set up the run variables and submit the workload as follows.

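The workload command below also references `CLUSTER_NAME`, `PROJECT`, `ZONE`, `DEVICE_TYPE`, `IMAGE_DIR`, and `BUCKET_NAME`, which are not defined elsewhere in this guide. A minimal sketch with hypothetical placeholder values (substitute your own):

```bash
# Hypothetical placeholders -- replace with your own cluster, project, zone, image, and bucket.
export CLUSTER_NAME=my-xpk-cluster
export PROJECT=my-gcp-project
export ZONE=us-east5-a
export DEVICE_TYPE=v5p-256
export IMAGE_DIR=gcr.io/my-gcp-project/my-maxdiffusion-image:latest
export BUCKET_NAME=my-bucket
```
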
```bash
RUN_NAME=jfacevedo-wan-v5p-8-${RANDOM}
OUTPUT_DIR=gs://$BUCKET_NAME/wan/
DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/train/
EVAL_DATA_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/eval_timesteps/
SAVE_DATASET_DIR=gs://$BUCKET_NAME/${TFRECORDS_DATASET_DIR##*/}/save/
```

```bash
# Same XLA/libtpu performance flags as in the single-VM run.
LIBTPU_INIT_ARGS='--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true \
--xla_tpu_megacore_fusion_allow_ags=false \
--xla_enable_async_collective_permute=true \
--xla_tpu_enable_ag_backward_pipelining=true \
--xla_tpu_enable_data_parallel_all_reduce_opt=true \
--xla_tpu_data_parallel_opt_different_sized_ops=true \
--xla_tpu_enable_async_collective_fusion=true \
--xla_tpu_enable_async_collective_fusion_multiple_steps=true \
--xla_tpu_overlap_compute_collective_tc=true \
--xla_enable_async_all_gather=true \
--xla_tpu_scoped_vmem_limit_kib=65536 \
--xla_tpu_enable_async_all_to_all=true \
--xla_tpu_enable_all_experimental_scheduler_features=true \
--xla_tpu_enable_scheduler_memory_pressure_tracking=true \
--xla_tpu_host_transfer_overlap_limit=24 \
--xla_tpu_aggressive_opt_barrier_removal=ENABLED \
--xla_lhs_prioritize_async_depth_over_stall=ENABLED \
--xla_should_allow_loop_variant_parameter_in_chain=ENABLED \
--xla_should_add_loop_invariant_op_in_chain=ENABLED \
--xla_max_concurrent_host_send_recv=100 \
--xla_tpu_scheduler_percent_shared_memory_limit=100 \
--xla_latency_hiding_scheduler_rerun=2 \
--xla_tpu_use_minor_sharding_for_major_trivial_input=true \
--xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 \
--xla_tpu_assign_all_reduce_scatter_layout=true'
```

```bash
python3 ~/xpk/xpk.py workload create \
  --cluster=$CLUSTER_NAME \
  --project=$PROJECT \
  --zone=$ZONE \
  --device-type=$DEVICE_TYPE \
  --num-slices=1 \
  --command=" \
    HF_HUB_CACHE=/mnt/disks/external_disk/maxdiffusion_hf_cache/ python src/maxdiffusion/train_wan.py \
    src/maxdiffusion/configs/base_wan_14b.yml \
    attention='flash' \
    weights_dtype=bfloat16 \
    activations_dtype=bfloat16 \
    guidance_scale=5.0 \
    flow_shift=5.0 \
    fps=16 \
    skip_jax_distributed_system=False \
    run_name=${RUN_NAME} \
    output_dir=${OUTPUT_DIR} \
    train_data_dir=${DATASET_DIR} \
    load_tfrecord_cached=True \
    height=1280 \
    width=720 \
    num_frames=81 \
    num_inference_steps=50 \
    jax_cache_dir=${OUTPUT_DIR}/jax_cache/ \
    enable_profiler=True \
    dataset_save_location=${SAVE_DATASET_DIR} \
    remat_policy='HIDDEN_STATE_WITH_OFFLOAD' \
    flash_min_seq_length=0 \
    seed=$RANDOM \
    skip_first_n_steps_for_profiler=3 \
    profiler_steps=3 \
    per_device_batch_size=0.25 \
    ici_data_parallelism=32 \
    ici_fsdp_parallelism=4 \
    ici_tensor_parallelism=1 \
    max_train_steps=5000 \
    eval_every=100 \
    eval_data_dir=${EVAL_DATA_DIR} \
    enable_generate_video_for_eval=True \
    warmup_steps_fraction=0.025" \
  --base-docker-image=${IMAGE_DIR} \
  --enable-debug-logs \
  --workload=${RUN_NAME} \
  --priority=medium \
  --max-restarts=0
```

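Once submitted, you can monitor the run from the same machine. A minimal sketch, assuming the installed xpk version supports the `workload list` subcommand:

```bash
# List workloads on the cluster to check the status of this run.
python3 ~/xpk/xpk.py workload list --cluster=$CLUSTER_NAME
```
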
## Flux Training

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 2 deletions
@@ -237,7 +237,7 @@ global_batch_size: 0
tfrecords_dir: ''
no_records_per_shard: 0
enable_eval_timesteps: False
- considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875]
+ timesteps_list: [125, 250, 375, 500, 625, 750, 875]
num_eval_samples: 420

warmup_steps_fraction: 0.1
@@ -321,6 +321,6 @@ qwix_module_path: ".*"
eval_every: -1
eval_data_dir: ""
enable_generate_video_for_eval: False # This will increase the used TPU memory.
- eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(considered_timesteps_list).
+ eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation. This is calculated by num_eval_samples / len(timesteps_list).

enable_ssim: False

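With the values in this config, `eval_max_number_of_samples_in_bucket` works out to `num_eval_samples / len(timesteps_list)` = 420 / 7 = 60.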
