78 changes: 77 additions & 1 deletion examples/omnigen2/README.md
@@ -148,7 +148,76 @@ case.

## Training

Coming soon
### 1. Preparation

Before launching the training, you need to prepare the following configuration files.

#### Step 1: Set Up the Training Configuration

This is a YAML file that specifies crucial parameters for your training job, including the model architecture,
optimizer, dataset paths, and validation settings.

We provide two templates to get you started:

* **Full-Parameter Fine-Tuning:** `configs/finetune/ft.yml`
* **LoRA Fine-Tuning:** `configs/finetune/ft_lora.yml`

Copy one of these templates and modify it to suit your needs. Below are the most important parameters you may want to
adjust (a minimal excerpt follows the list):

- `name`: The experiment name. This is used to create a directory for logs and saved model weights (e.g.,
`experiments/your_exp_name`).
- `data.config_path`: Path to the data configuration file that defines your training data sources and mixing ratios.
- `data.max_output_pixels`: The maximum number of pixels for an output image. Larger images will be downsampled while
maintaining their aspect ratio.
- `data.max_input_pixels`: A list specifying the maximum pixel count for input images when a sample has one, two,
three, or more inputs, respectively.
- `data.max_side_length`: The maximum side length for any image (input or output). Images exceeding this will be
downsampled while maintaining their aspect ratio.
- `dataloader.batch_size`: The batch size per NPU.
- `train.steps`: The total number of training steps to run.
- `train.lr_scheduler.lr`: The learning rate for the optimizer. **Note:** This often requires tuning based on your
dataset size and whether you are using LoRA. We recommend a lower learning rate for full-parameter fine-tuning.
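
For instance, a minimal set of overrides on top of `configs/finetune/ft.yml` might look like this (the values are
illustrative; all other keys keep the template defaults):

```yaml
# Illustrative excerpt: the fields most users change.
name: my_omnigen2_ft              # artifacts go to experiments/my_omnigen2_ft
data:
  config_path: configs/finetune/data/mix.yml
  max_output_pixels: 1048576      # 1024 * 1024
dataloader:
  batch_size: 1
train:
  steps: 4000
  lr_scheduler:
    lr: 8.0e-7                    # consider lowering for full-parameter fine-tuning
```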

#### Step 2: Configure Your Dataset

The data configuration consists of a set of `.yml` and `.jsonl` files.

* The `.yml` file defines the mixing ratios for different data sources.
* The `.jsonl` files contain the actual data entries, with each line representing a single data sample.

For a practical example, please refer to `configs/finetune/data/mix.yml`.
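It mixes three data sources evenly:

```yaml
data:
  - path: 'configs/finetune/data/t2i/t2i.yml'
    type: 't2i'
    ratio: 0.33
  - path: 'configs/finetune/data/edit/edit.yml'
    type: 'edit'
    ratio: 0.33
  - path: 'configs/finetune/data/ic/ic.yml'
    type: 'ic'
    ratio: 0.33
```
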
Each line in a `.jsonl` file describes a sample, generally following this format:

```json
{
"task_type": "edit",
"instruction": "add a hat to the person",
"input_images": [
"/path/to/your/data/edit/input1.png",
"/path/to/your/data/edit/input2.png"
],
"output_image": "/path/to/your/data/edit/output.png"
}
```

*Note: The `input_images` field can be omitted for text-to-image (T2I) tasks.*
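
A T2I entry (see `configs/finetune/data/t2i/jsonls/0.jsonl`) therefore omits `input_images`:

```json
{"task_type": "t2i", "instruction": "A big tree is in the forest", "output_image": "/path/to/your/data/t2i/0.png"}
```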

### 2. 🚀 Launching the Training

Once your configuration is ready, you can launch the training script. All experiment artifacts, including logs and
checkpoints, will be saved in `experiments/${experiment_name}`.

We provide convenient shell scripts that handle the complexities of launching distributed training jobs. You can use
them directly or adapt them to your environment (a sample invocation follows the list):

* **For Full-Parameter Fine-Tuning:** `scripts/run/ft.sh`
* **For LoRA Fine-Tuning:** `scripts/run/ft_lora.sh`
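
For example, assuming a default environment and that you launch from the `examples/omnigen2` directory:

```shell
# Sketch: full-parameter fine-tuning; adjust the config path inside the script as needed.
bash scripts/run/ft.sh
```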

> **⚠️ Note on LoRA Checkpoints:**
> Currently, when training with LoRA, the script saves the entire model's parameters (including the frozen base-model
> weights) in the checkpoint, because extracting only the LoRA-related parameters is not yet straightforward under
> FSDP.
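
If you need a LoRA-only file, one workaround is to filter the saved checkpoint by parameter name after training. A
minimal sketch, assuming a MindSpore-format checkpoint and PEFT-style parameter names containing `lora` (both the
naming convention and the path below are assumptions to verify against your checkpoint):

```python
import mindspore as ms

# Load the full checkpoint (base model + LoRA); returns a dict of name -> Parameter.
params = ms.load_checkpoint("experiments/ft_lora/ckpt/latest.ckpt")  # hypothetical path

# Keep only the LoRA-related parameters and save them to a separate, much smaller file.
lora_only = [{"name": k, "data": v} for k, v in params.items() if "lora" in k]
ms.save_checkpoint(lora_only, "lora_weights.ckpt")
```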

## Performance

@@ -161,6 +230,13 @@
| OmniGen2 | In-context Generation | 1 | BF16 | 1 | 768x1152 | Euler | 50 | 248 |
| OmniGen2 | In-context Generation | 1 | BF16 | 2 | 1024x1024 | Euler | 50 | 870 |

### Training

| Model | Fine-tuning | Cards | Batch size | Resolution | Precision | s/step | Recipe |
|:--------:|:-----------:|:-----:|:----------:|:----------:|:---------:|:------:|:-------------------------------------------:|
| OmniGen2 | Full | 8 | 1 | 720x720 | BF16 | 5.03 | [ft.yml](configs/finetune/ft.yml) |
| OmniGen2 | LoRA | 8 | 1 | 720x720 | BF16 | 3.78 | [ft_lora.yml](configs/finetune/ft_lora.yml) |

## Acknowledgement

If you find OmniGen2 useful, please cite the original work:
49 changes: 8 additions & 41 deletions examples/omnigen2/app.py
@@ -18,13 +18,12 @@
from omnigen2.utils.img_util import create_collage
from PIL import Image
from tqdm import tqdm
from transformers import Qwen2_5_VLProcessor

from mindspore import dtype
from mindspore.nn import no_init_parameters

from mindone.diffusers import AutoencoderKL
from mindone.transformers import Qwen2_5_VLForConditionalGeneration
from mindone.transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

NEGATIVE_PROMPT = (
"(((deformed))), blurry, over saturation, bad anatomy, disfigured, poorly drawn face, mutation, mutated,"
@@ -314,36 +313,16 @@ def main(args):
width_input = gr.Slider(label="Width", minimum=256, maximum=2048, value=1024, step=128)
with gr.Row(equal_height=True):
text_guidance_scale_input = gr.Slider(
label="Text Guidance Scale",
minimum=1.0,
maximum=8.0,
value=5.0,
step=0.1,
label="Text Guidance Scale", minimum=1.0, maximum=8.0, value=5.0, step=0.1
)

image_guidance_scale_input = gr.Slider(
label="Image Guidance Scale",
minimum=1.0,
maximum=3.0,
value=2.0,
step=0.1,
label="Image Guidance Scale", minimum=1.0, maximum=3.0, value=2.0, step=0.1
)
with gr.Row(equal_height=True):
cfg_range_start = gr.Slider(
label="CFG Range Start",
minimum=0.0,
maximum=1.0,
value=0.0,
step=0.1,
)
cfg_range_start = gr.Slider(label="CFG Range Start", minimum=0.0, maximum=1.0, value=0.0, step=0.1)

cfg_range_end = gr.Slider(
label="CFG Range End",
minimum=0.0,
maximum=1.0,
value=1.0,
step=0.1,
)
cfg_range_end = gr.Slider(label="CFG Range End", minimum=0.0, maximum=1.0, value=1.0, step=0.1)

def adjust_end_slider(start_val, end_val):
return max(start_val, end_val)
@@ -370,28 +349,16 @@ def adjust_start_slider(end_val, start_val):
num_inference_steps = gr.Slider(label="Inference Steps", minimum=20, maximum=100, value=50, step=1)
with gr.Row(equal_height=True):
num_images_per_prompt = gr.Slider(
label="Number of images per prompt",
minimum=1,
maximum=4,
value=1,
step=1,
label="Number of images per prompt", minimum=1, maximum=4, value=1, step=1
)

seed_input = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=0, step=1)
with gr.Row(equal_height=True):
max_input_image_side_length = gr.Slider(
label="max_input_image_side_length",
minimum=256,
maximum=2048,
value=2048,
step=256,
label="max_input_image_side_length", minimum=256, maximum=2048, value=2048, step=256
)
max_pixels = gr.Slider(
label="max_pixels",
minimum=256 * 256,
maximum=1536 * 1536,
value=1024 * 1024,
step=256 * 256,
label="max_pixels", minimum=256 * 256, maximum=1536 * 1536, value=1024 * 1024, step=256 * 256
)

with gr.Column():
4 changes: 4 additions & 0 deletions examples/omnigen2/configs/finetune/data/edit/edit.yml
@@ -0,0 +1,4 @@
data:
  - path: "configs/finetune/data/edit/jsonls/0.jsonl"
    type: "edit"
    ratio: 1.0
2 changes: 2 additions & 0 deletions examples/omnigen2/configs/finetune/data/edit/jsonls/0.jsonl
@@ -0,0 +1,2 @@
{"task_type": "edit", "instruction": "add a hat to the person", "input_images": ["/path/to/your/data/edit/0.png"], "output_image": "/path/to/your/data/edit/0.png"}
{"task_type": "edit", "instruction": "add a dog behind the person", "input_images": ["/path/to/your/data/edit/1.png"], "output_image": "/path/to/your/data/edit/1.png"}
4 changes: 4 additions & 0 deletions examples/omnigen2/configs/finetune/data/ic/ic.yml
@@ -0,0 +1,4 @@
data:
  - path: "configs/finetune/data/ic/jsonls/0.jsonl"
    type: "ic"
    ratio: 1.0
2 changes: 2 additions & 0 deletions examples/omnigen2/configs/finetune/data/ic/jsonls/0.jsonl
@@ -0,0 +1,2 @@
{"task_type": "ic", "instruction": "A big tree is in the forest", "input_images": ["/path/to/your/data/ic/0.png", "/path/to/your/data/ic/1.png"], "output_image": "/path/to/your/data/ic/0.png"}
{"task_type": "ic", "instruction": "a dog is running on grass", "input_images": ["/path/to/your/data/ic/2.png", "/path/to/your/data/ic/3.png"], "output_image": "/path/to/your/data/ic/1.png"}
10 changes: 10 additions & 0 deletions examples/omnigen2/configs/finetune/data/mix.yml
@@ -0,0 +1,10 @@
data:
  - path: 'configs/finetune/data/t2i/t2i.yml'
    type: 't2i'
    ratio: 0.33
  - path: 'configs/finetune/data/edit/edit.yml'
    type: 'edit'
    ratio: 0.33
  - path: 'configs/finetune/data/ic/ic.yml'
    type: 'ic'
    ratio: 0.33
2 changes: 2 additions & 0 deletions examples/omnigen2/configs/finetune/data/t2i/jsonls/0.jsonl
@@ -0,0 +1,2 @@
{"task_type": "t2i", "instruction": "A big tree is in the forest", "output_image": "/path/to/your/data/t2i/0.png"}
{"task_type": "t2i", "instruction": "a dog is running on grass", "output_image": "/path/to/your/data/t2i/1.png"}
4 changes: 4 additions & 0 deletions examples/omnigen2/configs/finetune/data/t2i/t2i.yml
@@ -0,0 +1,4 @@
data:
  - path: "configs/finetune/data/t2i/jsonls/0.jsonl"
    type: "t2i"
    ratio: 1.0
70 changes: 70 additions & 0 deletions examples/omnigen2/configs/finetune/ft.yml
@@ -0,0 +1,70 @@
name: ft

env:
  debug: False
  seed: 2233
  device_specific_seed: True

models:
  transformer:
    pretrained_model_name_or_path: OmniGen2/OmniGen2
    mindspore_dtype: bfloat16
  vae:
    pretrained_model_name_or_path: black-forest-labs/FLUX.1-dev
    mindspore_dtype: bfloat16
  text_encoder:
    pretrained_model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
    mindspore_dtype: bfloat16

data:
  config_path: configs/finetune/data/mix.yml
  use_chat_template: True
  max_input_pixels: [ 1048576, 1048576, 589824, 262144 ]  # [1024 * 1024, 1024 * 1024, 768 * 768, 512 * 512]
  max_output_pixels: 1048576  # 1024 * 1024
  max_side_length: 2048
  prompt_dropout_prob: 0.0001
  ref_img_dropout_prob: 0.5

dataloader:
  batch_size: 1
  shuffle: True
  num_workers: 6
  project_columns: [ "input_images", "output_image", "text_ids", "text_mask" ]

collator:
  maximum_text_tokens: 888

transport:
  path_type: Linear
  prediction: velocity
  snr_type: lognorm
  do_shift: True
  dynamic_time_shift: True
  time_shift_version: v1

train:
  steps: 4000
  gradient_checkpointing: True
  resume_from_checkpoint: latest

  settings:
    clip_grad: True
    clip_norm: 1.0
    gradient_accumulation_steps: 1
    zero_stage: 2

  lr_scheduler:
    name: constant
    lr: 8.0e-7
    warmup_steps: 500

  optimizer:
    name: adamw_bf16
    betas: [ 0.9, 0.95 ]
    weight_decay: 0.01
    eps: 1e-08

save:
  checkpointing_steps: 1000
  checkpoints_total_limit: null
  train_visualization_steps: 100
77 changes: 77 additions & 0 deletions examples/omnigen2/configs/finetune/ft_lora.yml
@@ -0,0 +1,77 @@
name: ft_lora

env:
  debug: False
  seed: 2233
  device_specific_seed: True

models:
  transformer:
    pretrained_model_name_or_path: OmniGen2/OmniGen2
    mindspore_dtype: bfloat16
  vae:
    pretrained_model_name_or_path: black-forest-labs/FLUX.1-dev
    mindspore_dtype: bfloat16
  text_encoder:
    pretrained_model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
    mindspore_dtype: bfloat16

data:
  config_path: configs/finetune/data/mix.yml
  use_chat_template: True
  max_input_pixels: [ 1048576, 1048576, 589824, 262144 ]  # [1024 * 1024, 1024 * 1024, 768 * 768, 512 * 512]
  max_output_pixels: 1048576  # 1024 * 1024
  max_side_length: 2048
  prompt_dropout_prob: 0.0001
  ref_img_dropout_prob: 0.5

dataloader:
  batch_size: 1
  shuffle: True
  num_workers: 6
  project_columns: [ "input_images", "output_image", "text_ids", "text_mask" ]

collator:
  maximum_text_tokens: 888

transport:
  path_type: Linear
  prediction: velocity
  snr_type: lognorm
  do_shift: True
  dynamic_time_shift: True
  time_shift_version: v1

train:
  steps: 4000
  gradient_checkpointing: True
  resume_from_checkpoint: latest

  settings:
    clip_grad: True
    clip_norm: 1.0
    gradient_accumulation_steps: 1
    zero_stage: 0

  lr_scheduler:
    name: constant
    lr: 8.0e-7
    warmup_steps: 500

  optimizer:
    name: adamw_bf16
    betas: [ 0.9, 0.95 ]
    weight_decay: 0.01
    eps: 1e-08

lora:
  target_modules: [ "to_k", "to_q", "to_v", "to_out.0" ]
  r: 8
  lora_alpha: 8
  lora_dropout: 0
  init_lora_weights: gaussian

save:
  checkpointing_steps: 1000
  checkpoints_total_limit: null
  train_visualization_steps: 100
3 changes: 3 additions & 0 deletions examples/omnigen2/omnigen2/dataset/__init__.py
@@ -0,0 +1,3 @@
from .collator import OmniGen2Collator
from .omnigen2_test_dataset import OmniGen2TestDataset
from .omnigen2_train_dataset import OmniGen2TrainDataset
20 changes: 20 additions & 0 deletions examples/omnigen2/omnigen2/dataset/collator.py
@@ -0,0 +1,20 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from numpy import ndarray


class OmniGen2Collator:
    """Tokenizes a batch of instruction strings into padded ids and attention masks."""

    def __init__(self, tokenizer, max_token_len: int):
        self.tokenizer = tokenizer
        self.max_token_len = max_token_len

    def __call__(self, instructions: "ndarray") -> tuple["ndarray", "ndarray"]:
        # Pad to the longest instruction in the batch, truncating to max_token_len.
        text_inputs = self.tokenizer(
            instructions.tolist(),
            padding="longest",
            max_length=self.max_token_len,
            truncation=True,
            return_tensors="np",
        )
        return text_inputs.input_ids, text_inputs.attention_mask
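
A minimal usage sketch for the collator above. It assumes the tokenizer of the Qwen2.5-VL text encoder configured in
`ft.yml`, loaded here via `AutoTokenizer`, and `max_token_len=888` to match `collator.maximum_text_tokens`; both are
assumptions, not part of this file:

```python
import numpy as np
from transformers import AutoTokenizer

from omnigen2.dataset import OmniGen2Collator

# Assumption: same tokenizer family as the configured text encoder.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
collator = OmniGen2Collator(tokenizer, max_token_len=888)

instructions = np.array(["add a hat to the person", "a dog is running on grass"])
text_ids, text_mask = collator(instructions)  # numpy arrays, padded to the longest instruction
print(text_ids.shape, text_mask.shape)
```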