After installation completes, run the training script.
## Wan 2.1 Training
In the first part, we'll run on a single-host VM to get familiar with the workflow, then move to xpk for large-scale training.
Although not required, attaching an external disk is recommended as weights take up a lot of disk space. [Follow these instructions if you would like to attach an external disk](https://cloud.google.com/tpu/docs/attach-durable-block-storage).
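Formatting and mounting the attached disk is covered in the linked guide; below is a minimal sketch, assuming the disk shows up as `/dev/sdb` (check with `lsblk`) and that `/mnt/disks/persist` is the mount point you want.

```bash
# Format and mount the attached disk (device name and mount point are assumptions).
sudo mkfs.ext4 -m 0 -E lazy_itable_init=0,lazy_journal_init=0,discard /dev/sdb
sudo mkdir -p /mnt/disks/persist
sudo mount -o discard,defaults /dev/sdb /mnt/disks/persist
sudo chmod a+w /mnt/disks/persist
```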
This workflow was tested on a v5p-8 with a 500GB disk attached.
### Dataset Preparation
For this example, we'll be using the [PusaV1 dataset](https://huggingface.co/datasets/RaphaelLiu/PusaV1_training).
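One way to fetch the dataset locally is the Hugging Face CLI; this is a sketch, and the local directory is a placeholder.

```bash
# Download the PusaV1 training dataset to the attached disk (path is a placeholder).
huggingface-cli download RaphaelLiu/PusaV1_training \
  --repo-type dataset \
  --local-dir /mnt/disks/persist/PusaV1_training
```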
Next, run the TFRecords conversion script. This step prepares the training and eval datasets. Validation is done as described in [Scaling Rectified Flow Transformers for High-Resolution Image Synthesis](https://arxiv.org/pdf/2403.03206); more details [here](https://github.com/mlcommons/training/tree/master/text_to_image#5-quality).
The evaluation dataset creation takes the first 420 samples of the dataset and adds a timestep field. We then need to manually delete those first 420 samples from the `train` folder so they are not also used in training.
Loading the data is supported both locally, from the disk created above, and from GCS. In this guide, we'll train from a GCS bucket, so first copy the data to the bucket.
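A minimal sketch of the copy step (bucket name and local path are placeholders):

```bash
# Copy the prepared TFRecords to a GCS bucket (paths are placeholders).
gsutil -m cp -r /mnt/disks/persist/pusa_tfrecords gs://YOUR_BUCKET/pusa_tfrecords
```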
- `per_device_batch_size` can be fractional, but it must produce a whole number when multiplied by the number of devices. In this example, 0.25 * 4 devices = an effective global batch size of 1.
- The step time on v5p-8 with a global batch size of 1 is large due to using `FULL` remat. On a larger number of chips we can run larger batch sizes, greatly increasing MFU, as we will see in the next section on deploying with xpk.
- To enable eval during training, set `eval_every` to a value > 0.
- In Wan2.1, the `ici_fsdp_parallelism` axis is used for sequence parallelism, and the `ici_tensor_parallelism` axis is used for head parallelism.
- You can enable both, keeping in mind that Wan2.1 has 40 heads, so 40 must be evenly divisible by `ici_tensor_parallelism`.
- For sequence parallelism, the code pads the sequence length so it divides evenly. Try different `ici_fsdp_parallelism` values; we find 2 and 4 to work best right now. A sketch of a launch command using these flags follows this list.
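As a rough sketch of how these flags fit together on a single host (the script path, config file name, run name, and bucket paths below are placeholders, not the exact command from this repo):

```bash
# Hypothetical single-host launch; script/config paths and GCS paths are placeholders.
# The key=value overrides correspond to the flags discussed in the list above.
python src/maxdiffusion/train.py src/maxdiffusion/configs/base_wan.yml \
  run_name=wan21-single-host \
  output_dir=gs://YOUR_BUCKET/wan21-output \
  per_device_batch_size=0.25 \
  eval_every=100 \
  ici_fsdp_parallelism=4 \
  ici_tensor_parallelism=1
```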
You should eventually see training output like:
```bash
***** Running training *****
Instantaneous batch size per device = 0.25
Total train batch size (w. parallel & distributed) = 1
Total optimization steps = 1000
Calculated TFLOPs per pass: 4893.2719
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
Warning, batch dimension should be shardable among the devices in data and fsdp axis, batch dimension: 1, devices_in_data_fsdp: 4
```
This assumes the user has already created an xpk cluster, installed all dependencies, and created the dataset from the steps above. For getting started with MaxDiffusion and xpk, see [this guide](docs/getting_started/run_maxdiffusion_via_xpk.md).
Using a v5p-256, the command to run on xpk is as follows:
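A minimal sketch of such an invocation (cluster, docker image, and workload names are placeholders; the inner training command follows the same pattern as the single-host run, with a larger batch size now possible):

```bash
# Hypothetical xpk launch; cluster, docker image, and workload names are placeholders.
xpk workload create \
  --cluster YOUR_CLUSTER \
  --docker-image YOUR_MAXDIFFUSION_IMAGE \
  --workload wan21-train \
  --tpu-type v5p-256 \
  --num-slices 1 \
  --command "python src/maxdiffusion/train.py src/maxdiffusion/configs/base_wan.yml \
      run_name=wan21-xpk output_dir=gs://YOUR_BUCKET/wan21-output per_device_batch_size=1"
```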
enable_generate_video_for_eval: False # Generating videos during eval increases TPU memory usage.
eval_max_number_of_samples_in_bucket: 60 # The number of samples per bucket for evaluation, calculated as num_eval_samples / len(timesteps_list) (here, the 420 eval samples are split evenly across the considered timesteps).