diff --git a/README.md b/README.md index 20414533..2deb8ba9 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ [![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml) # What's new? +- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported - **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported. - **`2025/10/14`**: NVIDIA DGX Spark Flux support. - **`2025/8/14`**: LTX-Video img2vid generation is now supported. diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer.py b/src/maxdiffusion/checkpointing/wan_checkpointer.py index 0dd493a3..12151bff 100644 --- a/src/maxdiffusion/checkpointing/wan_checkpointer.py +++ b/src/maxdiffusion/checkpointing/wan_checkpointer.py @@ -14,34 +14,37 @@ limitations under the License. """ -from abc import ABC +from abc import ABC, abstractmethod import json import jax import numpy as np from typing import Optional, Tuple from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline import WanPipeline +from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2 from .. import max_logging, max_utils import orbax.checkpoint as ocp from etils import epath + WAN_CHECKPOINT = "WAN_CHECKPOINT" class WanCheckpointer(ABC): - def __init__(self, config, checkpoint_type): + def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT): self.config = config self.checkpoint_type = checkpoint_type self.opt_state = None - self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, + self.checkpoint_manager: ocp.CheckpointManager = ( + create_orbax_checkpoint_manager( + self.config.checkpoint_dir, + enable_checkpointing=True, + save_interval_steps=1, + checkpoint_type=checkpoint_type, + dataset_type=config.dataset_type, + ) ) def _create_optimizer(self, model, config, learning_rate): @@ -51,6 +54,25 @@ def _create_optimizer(self, model, config, learning_rate): tx = max_utils.create_optimizer(config, learning_rate_scheduler) return tx, learning_rate_scheduler + @abstractmethod + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + raise NotImplementedError + + @abstractmethod + def load_diffusers_checkpoint(self): + raise NotImplementedError + + @abstractmethod + def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]: + raise NotImplementedError + + @abstractmethod + def save_checkpoint(self, train_step, pipeline, train_states: dict): + raise NotImplementedError + + +class WanCheckpointer2_1(WanCheckpointer): + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: if step is None: step = self.checkpoint_manager.latest_step() @@ -85,24 +107,24 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic return restored_checkpoint, step def load_diffusers_checkpoint(self): - pipeline = WanPipeline.from_pretrained(self.config) + pipeline = WanPipeline2_1.from_pretrained(self.config) return pipeline - def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]: restored_checkpoint, step = self.load_wan_configs_from_orbax(step) opt_state = None if restored_checkpoint: max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) - if "opt_state" in restored_checkpoint["wan_state"].keys(): - opt_state = restored_checkpoint["wan_state"]["opt_state"] + pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint) + if "opt_state" in restored_checkpoint.wan_state.keys(): + opt_state = restored_checkpoint.wan_state["opt_state"] else: max_logging.log("No checkpoint found, loading default pipeline.") pipeline = self.load_diffusers_checkpoint() return pipeline, opt_state, step - def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): + def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict): """Saves the training state and model configurations.""" def config_to_json(model_or_config): @@ -120,7 +142,96 @@ def config_to_json(model_or_config): max_logging.log(f"Checkpoint for step {train_step} saved.") -def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): +class WanCheckpointer2_2(WanCheckpointer): + + def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: + if step is None: + step = self.checkpoint_manager.latest_step() + max_logging.log(f"Latest WAN checkpoint step: {step}") + if step is None: + max_logging.log("No WAN checkpoint found.") + return None, None + max_logging.log(f"Loading WAN checkpoint from step {step}") + metadatas = self.checkpoint_manager.item_metadata(step) + + # Handle low_noise_transformer + low_noise_transformer_metadata = metadatas.low_noise_transformer_state + abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) + low_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_low_params, + ) + ) + + # Handle high_noise_transformer + high_noise_transformer_metadata = metadatas.high_noise_transformer_state + abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) + high_params_restore = ocp.args.PyTreeRestore( + restore_args=jax.tree.map( + lambda _: ocp.RestoreArgs(restore_type=np.ndarray), + abstract_tree_structure_high_params, + ) + ) + + max_logging.log("Restoring WAN 2.2 checkpoint") + restored_checkpoint = self.checkpoint_manager.restore( + directory=epath.Path(self.config.checkpoint_dir), + step=step, + args=ocp.args.Composite( + low_noise_transformer_state=low_params_restore, + high_noise_transformer_state=high_params_restore, + wan_config=ocp.args.JsonRestore(), + ), + ) + max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") + max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") + max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") + max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") + return restored_checkpoint, step + + def load_diffusers_checkpoint(self): + pipeline = WanPipeline2_2.from_pretrained(self.config) + return pipeline + + def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]: + restored_checkpoint, step = self.load_wan_configs_from_orbax(step) + opt_state = None + if restored_checkpoint: + max_logging.log("Loading WAN pipeline from checkpoint") + pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint) + # Check for optimizer state in either transformer + if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): + opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] + elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): + opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] + else: + max_logging.log("No checkpoint found, loading default pipeline.") + pipeline = self.load_diffusers_checkpoint() + + return pipeline, opt_state, step + + def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict): + """Saves the training state and model configurations.""" + + def config_to_json(model_or_config): + return json.loads(model_or_config.to_json_string()) + + max_logging.log(f"Saving checkpoint for step {train_step}") + items = { + "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), + } + + items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) + items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) + + # Save the checkpoint + self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) + max_logging.log(f"Checkpoint for step {train_step} saved.") + +def save_checkpoint_orig(self, train_step, pipeline, train_states: dict): """Saves the training state and model configurations.""" def config_to_json(model_or_config): diff --git a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py b/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py deleted file mode 100644 index de8bb35d..00000000 --- a/src/maxdiffusion/checkpointing/wan_checkpointer2_2.py +++ /dev/null @@ -1,207 +0,0 @@ -""" - Copyright 2025 Google LLC - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - https://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -""" - -from abc import ABC -import json - -import jax -import numpy as np -from typing import Optional, Tuple -from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager) -from ..pipelines.wan.wan_pipeline2_2 import WanPipeline -from .. import max_logging, max_utils -import orbax.checkpoint as ocp -from etils import epath - -WAN_CHECKPOINT = "WAN_CHECKPOINT" - - -class WanCheckpointer(ABC): - - def __init__(self, config, checkpoint_type): - self.config = config - self.checkpoint_type = checkpoint_type - self.opt_state = None - - self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager( - self.config.checkpoint_dir, - enable_checkpointing=True, - save_interval_steps=1, - checkpoint_type=checkpoint_type, - dataset_type=config.dataset_type, - ) - - def _create_optimizer(self, model, config, learning_rate): - learning_rate_scheduler = max_utils.create_learning_rate_schedule( - learning_rate, config.learning_rate_schedule_steps, config.warmup_steps_fraction, config.max_train_steps - ) - tx = max_utils.create_optimizer(config, learning_rate_scheduler) - return tx, learning_rate_scheduler - - def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]: - if step is None: - step = self.checkpoint_manager.latest_step() - max_logging.log(f"Latest WAN checkpoint step: {step}") - if step is None: - max_logging.log("No WAN checkpoint found.") - return None, None - max_logging.log(f"Loading WAN checkpoint from step {step}") - metadatas = self.checkpoint_manager.item_metadata(step) - - low_noise_transformer_metadata = metadatas.low_noise_transformer_state - abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata) - low_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_low_params, - ) - ) - - high_noise_transformer_metadata = metadatas.high_noise_transformer_state - abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata) - high_params_restore = ocp.args.PyTreeRestore( - restore_args=jax.tree.map( - lambda _: ocp.RestoreArgs(restore_type=np.ndarray), - abstract_tree_structure_high_params, - ) - ) - - max_logging.log("Restoring WAN checkpoint") - restored_checkpoint = self.checkpoint_manager.restore( - directory=epath.Path(self.config.checkpoint_dir), - step=step, - args=ocp.args.Composite( - low_noise_transformer_state=low_params_restore, - high_noise_transformer_state=high_params_restore, - wan_config=ocp.args.JsonRestore(), - ), - ) - max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}") - max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}") - max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}") - max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}") - return restored_checkpoint, step - - def load_diffusers_checkpoint(self): - pipeline = WanPipeline.from_pretrained(self.config) - return pipeline - - def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]: - restored_checkpoint, step = self.load_wan_configs_from_orbax(step) - opt_state = None - if restored_checkpoint: - max_logging.log("Loading WAN pipeline from checkpoint") - pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint) - # Check for optimizer state in either transformer - if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys(): - opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"] - elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys(): - opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"] - else: - max_logging.log("No checkpoint found, loading default pipeline.") - pipeline = self.load_diffusers_checkpoint() - - return pipeline, opt_state, step - - def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - return json.loads(model_or_config.to_json_string()) - - max_logging.log(f"Saving checkpoint for step {train_step}") - items = { - "wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)), - } - - items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"]) - items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"]) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items)) - max_logging.log(f"Checkpoint for step {train_step} saved.") - - -def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict): - """Saves the training state and model configurations.""" - - def config_to_json(model_or_config): - """ - only save the config that is needed and can be serialized to JSON. - """ - if not hasattr(model_or_config, "config"): - return None - source_config = dict(model_or_config.config) - - # 1. configs that can be serialized to JSON - SAFE_KEYS = [ - "_class_name", - "_diffusers_version", - "model_type", - "patch_size", - "num_attention_heads", - "attention_head_dim", - "in_channels", - "out_channels", - "text_dim", - "freq_dim", - "ffn_dim", - "num_layers", - "cross_attn_norm", - "qk_norm", - "eps", - "image_dim", - "added_kv_proj_dim", - "rope_max_seq_len", - "pos_embed_seq_len", - "flash_min_seq_length", - "flash_block_sizes", - "attention", - "_use_default_values", - ] - - # 2. save the config that are in the SAFE_KEYS list - clean_config = {} - for key in SAFE_KEYS: - if key in source_config: - clean_config[key] = source_config[key] - - # 3. deal with special data type and precision - if "dtype" in source_config and hasattr(source_config["dtype"], "name"): - clean_config["dtype"] = source_config["dtype"].name # e.g 'bfloat16' - - if "weights_dtype" in source_config and hasattr(source_config["weights_dtype"], "name"): - clean_config["weights_dtype"] = source_config["weights_dtype"].name - - if "precision" in source_config and isinstance(source_config["precision"]): - clean_config["precision"] = source_config["precision"].name # e.g. 'HIGHEST' - - return clean_config - - items_to_save = { - "transformer_config": ocp.args.JsonSave(config_to_json(pipeline.transformer)), - } - - items_to_save["transformer_states"] = ocp.args.PyTreeSave(train_states) - - # Create CompositeArgs for Orbax - save_args = ocp.args.Composite(**items_to_save) - - # Save the checkpoint - self.checkpoint_manager.save(train_step, args=save_args) - max_logging.log(f"Checkpoint for step {train_step} saved.") diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index eb4895e9..2a998d0c 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -281,9 +281,9 @@ width: 832 num_frames: 81 flow_shift: 3.0 -guidance_scale_low: 5.0 -guidance_scale_high: 8.0 -boundary_timestep: 15 +guidance_scale_low: 3.0 +guidance_scale_high: 4.0 +boundary_timestep: 875 # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf guidance_rescale: 0.0 diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index 442d7887..dabbf4b9 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -16,9 +16,9 @@ import jax import time import os +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 from maxdiffusion import pyconfig, max_logging, max_utils from absl import app -import importlib from maxdiffusion.utils import export_to_video from google.cloud import storage import flax @@ -72,22 +72,6 @@ def delete_file(file_path: str): os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) -def get_pipeline(model_name: str): - if model_name == "wan2.1": - return importlib.import_module("maxdiffusion.pipelines.wan.wan_pipeline") - elif model_name == "wan2.2": - return importlib.import_module("maxdiffusion.pipelines.wan.wan_pipeline2_2") - else: - raise ValueError(f"Unsupported model_name in config: {model_name}") - -def get_checkpointer(model_name: str): - if model_name == "wan2.1": - return importlib.import_module("maxdiffusion.checkpointing.wan_checkpointer") - elif model_name == "wan2.2": - return importlib.import_module("maxdiffusion.checkpointing.wan_checkpointer2_2") - else: - raise ValueError(f"Unsupported model_name in config: {model_name}") - def call_pipeline(config, pipeline, prompt, negative_prompt): model_key = config.model_name if model_key == "wan2.1": @@ -140,21 +124,18 @@ def inference_generate_video(config, pipeline, filename_prefix=""): def run(config, pipeline=None, filename_prefix=""): model_key = config.model_name - # Initialize TensorBoard writer writer = max_utils.initialize_summary_writer(config) if jax.process_index() == 0 and writer: max_logging.log(f"TensorBoard logs will be written to: {config.tensorboard_dir}") - checkpointer_lib = get_checkpointer(model_key) - WanCheckpointer = checkpointer_lib.WanCheckpointer - - checkpoint_loader = WanCheckpointer(config, "WAN_CHECKPOINT") - pipeline, _, _ = checkpoint_loader.load_checkpoint() - if pipeline is None: - pipeline_lib = get_pipeline(model_key) - WanPipeline = pipeline_lib.WanPipeline - pipeline = WanPipeline.from_pretrained(config) + if model_key == "wan2.1": + checkpoint_loader = WanCheckpointer2_1(config=config) + elif model_key == "wan2.2": + checkpoint_loader = WanCheckpointer2_2(config=config) + else: + raise ValueError(f"Unsupported model_name for checkpointer: {model_key}") + pipeline, _, _ = checkpoint_loader.load_checkpoint() s0 = time.perf_counter() # Using global_batch_size_to_train_on so not to create more config variables @@ -164,8 +145,8 @@ def run(config, pipeline=None, filename_prefix=""): max_logging.log( f"Num steps: {config.num_inference_steps}, height: {config.height}, width: {config.width}, frames: {config.num_frames}" ) - videos = call_pipeline(config, pipeline, prompt, negative_prompt) + max_logging.log("===================== Model details =======================") max_logging.log(f"model name: {config.model_name}") max_logging.log(f"model path: {config.pretrained_model_name_or_path}") @@ -201,8 +182,6 @@ def run(config, pipeline=None, filename_prefix=""): max_logging.log(f"generation time per video: {generation_time_per_video}") else: max_logging.log("Warning: Number of videos is zero, cannot calculate generation_time_per_video.") - - s0 = time.perf_counter() if config.enable_profiler: max_utils.activate_profiler(config) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 7ed8007b..23ff46b8 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import abstractmethod from typing import List, Union, Optional from functools import partial import numpy as np @@ -100,7 +101,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): if restored_checkpoint: wan_config = restored_checkpoint["wan_config"] else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder="transformer") + wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) wan_config["mesh"] = mesh wan_config["dtype"] = config.activations_dtype wan_config["weights_dtype"] = config.weights_dtype @@ -188,12 +189,10 @@ class WanPipeline: vae ([`AutoencoderKLWan`]): Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. """ - def __init__( self, tokenizer: AutoTokenizer, text_encoder: UMT5EncoderModel, - transformer: WanModel, vae: AutoencoderKLWan, vae_cache: AutoencoderKLWanCache, scheduler: FlaxUniPCMultistepScheduler, @@ -204,7 +203,6 @@ def __init__( ): self.tokenizer = tokenizer self.text_encoder = text_encoder - self.transformer = transformer self.vae = vae self.vae_cache = vae_cache self.scheduler = scheduler @@ -212,6 +210,7 @@ def __init__( self.devices_array = devices_array self.mesh = mesh self.config = config + self.model_name = config.model_name self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 @@ -371,84 +370,6 @@ def load_scheduler(cls, config): ) return scheduler, scheduler_state - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - transformer = cls.load_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer" - ) - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - pipeline = WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - transformer=transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - pipeline.transformer = cls.quantize_transformer(config, pipeline.transformer, pipeline, mesh) - return pipeline def _get_t5_prompt_embeds( self, @@ -538,22 +459,66 @@ def prepare_latents( return latents - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale: float = 5.0, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): + def _denormalize_latents(self, latents: jax.Array) -> jax.Array: + """Denormalizes latents using VAE statistics.""" + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(jnp.float32) + return latents + + def _decode_latents_to_video(self, latents: jax.Array) -> np.ndarray: + """Decodes latents to video frames and postprocesses.""" + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] + + video = jnp.transpose(video, (0, 4, 1, 2, 3)) + video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) + video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) + return self.video_processor.postprocess_video(video, output_type="np") + + @classmethod + def _create_common_components(cls, config, vae_only=False): + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + rng = jax.random.key(config.seed) + rngs = nnx.Rngs(rng) + + with mesh: + wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) + + components = { + "vae": wan_vae, "vae_cache": vae_cache, + "devices_array": devices_array, "rngs": rngs, "mesh": mesh, + "tokenizer": None, "text_encoder": None, "scheduler": None, "scheduler_state": None + } + + if not vae_only: + components["tokenizer"] = cls.load_tokenizer(config=config) + components["text_encoder"] = cls.load_text_encoder(config=config) + components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) + return components + + @abstractmethod + def _get_num_channel_latents(self) -> int: + """Returns the number of input channels for the transformer.""" + pass + + def _prepare_call_inputs( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): if not vae_only: if num_frames % self.vae_scale_factor_temporal != 1: max_logging.log( @@ -577,7 +542,7 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - num_channel_latents = self.transformer.config.in_channels + num_channel_latents = self._get_num_channel_latents() if latents is None: latents = self.prepare_latents( batch_size=batch_size, @@ -602,40 +567,235 @@ def __call__( self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape ) - graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames - p_run_inference = partial( - run_inference, - guidance_scale=guidance_scale, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - num_transformer_layers=self.transformer.config.num_layers, - ) + @abstractmethod + def __call__(self, **kwargs): + """Runs the inference pipeline.""" + pass - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - graphdef=graphdef, - sharded_state=state, - rest_of_state=rest_of_state, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, +class WanPipeline2_1(WanPipeline): + """Pipeline for WAN 2.1 with a single transformer.""" + def __init__(self, config: HyperParameters, transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.transformer = transformer + + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + transformer = None + if not vae_only: + if load_transformer: + transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + transformer=transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) + + return pipeline, transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer) + transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def _get_num_channel_latents(self) -> int: + return self.transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale: float = 5.0, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: Optional[jax.Array] = None, + prompt_embeds: Optional[jax.Array] = None, + negative_prompt_embeds: Optional[jax.Array] = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_1, + guidance_scale=guidance_scale, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] + latents = p_run_inference( + graphdef=graphdef, + sharded_state=state, + rest_of_state=rest_of_state, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) + +class WanPipeline2_2(WanPipeline): + """Pipeline for WAN 2.2 with dual transformers.""" + def __init__(self, config: HyperParameters, low_noise_transformer: Optional[WanModel], high_noise_transformer: Optional[WanModel], **kwargs): + super().__init__(config=config, **kwargs) + self.low_noise_transformer = low_noise_transformer + self.high_noise_transformer = high_noise_transformer - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") - return video + @classmethod + def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_transformer=True): + common_components = cls._create_common_components(config, vae_only) + low_noise_transformer, high_noise_transformer = None, None + if not vae_only and load_transformer: + low_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer" + ) + high_noise_transformer = super().load_transformer( + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + rngs=common_components["rngs"], + config=config, + restored_checkpoint=restored_checkpoint, + subfolder="transformer_2" + ) + pipeline = cls( + tokenizer=common_components["tokenizer"], + text_encoder=common_components["text_encoder"], + low_noise_transformer=low_noise_transformer, + high_noise_transformer=high_noise_transformer, + vae=common_components["vae"], + vae_cache=common_components["vae_cache"], + scheduler=common_components["scheduler"], + scheduler_state=common_components["scheduler_state"], + devices_array=common_components["devices_array"], + mesh=common_components["mesh"], + config=config, + ) + return pipeline, low_noise_transformer, high_noise_transformer + + @classmethod + def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer) + low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh) + high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh) + return pipeline + + @classmethod + def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) + return pipeline + + def _get_num_channel_latents(self) -> int: + return self.low_noise_transformer.config.in_channels + + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: Union[str, List[str]] = None, + height: int = 480, + width: int = 832, + num_frames: int = 81, + num_inference_steps: int = 50, + guidance_scale_low: float = 3.0, + guidance_scale_high: float = 4.0, + boundary: int = 875, + num_videos_per_prompt: Optional[int] = 1, + max_sequence_length: int = 512, + latents: jax.Array = None, + prompt_embeds: jax.Array = None, + negative_prompt_embeds: jax.Array = None, + vae_only: bool = False, + ): + latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_call_inputs( + prompt, + negative_prompt, + height, + width, + num_frames, + num_inference_steps, + num_videos_per_prompt, + max_sequence_length, + latents, + prompt_embeds, + negative_prompt_embeds, + vae_only, + ) + + low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) + high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) + + p_run_inference = partial( + run_inference_2_2, + guidance_scale_low=guidance_scale_low, + guidance_scale_high=guidance_scale_high, + boundary=boundary, + num_inference_steps=num_inference_steps, + scheduler=self.scheduler, + scheduler_state=scheduler_state, + ) + + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + latents = p_run_inference( + low_noise_graphdef=low_noise_graphdef, + low_noise_state=low_noise_state, + low_noise_rest=low_noise_rest, + high_noise_graphdef=high_noise_graphdef, + high_noise_state=high_noise_state, + high_noise_rest=high_noise_rest, + latents=latents, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + latents = self._denormalize_latents(latents) + return self._decode_latents_to_video(latents) @partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) def transformer_forward_pass( @@ -659,8 +819,7 @@ def transformer_forward_pass( return noise_pred, latents - -def run_inference( +def run_inference_2_1( graphdef, sharded_state, rest_of_state, @@ -670,7 +829,6 @@ def run_inference( guidance_scale: float, num_inference_steps: int, scheduler: FlaxUniPCMultistepScheduler, - num_transformer_layers: int, scheduler_state, ): do_classifier_free_guidance = guidance_scale > 1.0 @@ -695,3 +853,58 @@ def run_inference( latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents + +def run_inference_2_2( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + high_noise_graphdef, + high_noise_state, + high_noise_rest, + latents: jnp.array, + prompt_embeds: jnp.array, + negative_prompt_embeds: jnp.array, + guidance_scale_low: float, + guidance_scale_high: float, + boundary: int, + num_inference_steps: int, + scheduler: FlaxUniPCMultistepScheduler, + scheduler_state, +): + do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) + + def low_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + low_noise_graphdef, low_noise_state, low_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_low + ) + + def high_noise_branch(operands): + latents, timestep, prompt_embeds = operands + return transformer_forward_pass( + high_noise_graphdef, high_noise_state, high_noise_rest, + latents, timestep, prompt_embeds, + do_classifier_free_guidance, guidance_scale_high + ) + + for step in range(num_inference_steps): + t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) + timestep = jnp.broadcast_to(t, latents.shape[0]) + + use_high_noise = jnp.greater_equal(t, boundary) + + noise_pred, latents = jax.lax.cond( + use_high_noise, + high_noise_branch, + low_noise_branch, + (latents, timestep, prompt_embeds) + ) + + latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() + return latents diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py deleted file mode 100644 index 0645aeeb..00000000 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline2_2.py +++ /dev/null @@ -1,725 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Union, Optional -from functools import partial -import numpy as np -import jax -import jax.numpy as jnp -from jax.sharding import Mesh, NamedSharding, PartitionSpec as P -import flax -import flax.linen as nn -from flax import nnx -from flax.linen import partitioning as nn_partitioning -from ...pyconfig import HyperParameters -from ... import max_logging -from ... import max_utils -from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated -from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae -from ...models.wan.transformers.transformer_wan import WanModel -from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache -from maxdiffusion.video_processor import VideoProcessor -from ...schedulers.scheduling_unipc_multistep_flax import FlaxUniPCMultistepScheduler, UniPCMultistepSchedulerState -from transformers import AutoTokenizer, UMT5EncoderModel -from maxdiffusion.utils.import_utils import is_ftfy_available -from maxdiffusion.maxdiffusion_utils import get_dummy_wan_inputs -import html -import re -import torch -import qwix - - -def cast_with_exclusion(path, x, dtype_to_cast): - """ - Casts arrays to dtype_to_cast, but keeps params from any 'norm' layer in float32. - """ - - exclusion_keywords = [ - "norm", # For all LayerNorm/GroupNorm layers - "condition_embedder", # The entire time/text conditioning module - "scale_shift_table", # Catches both the final and the AdaLN tables - ] - - path_str = ".".join(str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path) - - if any(keyword in path_str.lower() for keyword in exclusion_keywords): - print("is_norm_path: ", path) - # Keep LayerNorm/GroupNorm weights and biases in full precision - return x.astype(jnp.float32) - else: - # Cast everything else to dtype_to_cast - return x.astype(dtype_to_cast) - - -def basic_clean(text): - if is_ftfy_available(): - import ftfy - - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r"\s+", " ", text) - text = text.strip() - return text - - -def prompt_clean(text): - text = whitespace_clean(basic_clean(text)) - return text - - -def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.VariableState: - vs.sharding_rules = logical_axis_rules - return vs - - -# For some reason, jitting this function increases the memory significantly, so instead manually move weights to device. -def create_sharded_logical_transformer( - devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder: str = "" -): - - def create_model(rngs: nnx.Rngs, wan_config: dict): - wan_transformer = WanModel(**wan_config, rngs=rngs) - return wan_transformer - - # 1. Load config. - if restored_checkpoint: - wan_config = restored_checkpoint["wan_config"] - else: - wan_config = WanModel.load_config(config.pretrained_model_name_or_path, subfolder=subfolder) - wan_config["mesh"] = mesh - wan_config["dtype"] = config.activations_dtype - wan_config["weights_dtype"] = config.weights_dtype - wan_config["attention"] = config.attention - wan_config["precision"] = get_precision(config) - wan_config["flash_block_sizes"] = get_flash_block_sizes(config) - wan_config["remat_policy"] = config.remat_policy - wan_config["names_which_can_be_saved"] = config.names_which_can_be_saved - wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded - wan_config["flash_min_seq_length"] = config.flash_min_seq_length - wan_config["dropout"] = config.dropout - wan_config["scan_layers"] = config.scan_layers - - # 2. eval_shape - will not use flops or create weights on device - # thus not using HBM memory. - p_model_factory = partial(create_model, wan_config=wan_config) - wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) - graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) - - # 3. retrieve the state shardings, mapping logical names to mesh axis names. - logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) - logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) - params = state.to_pure_dict() - state = dict(nnx.to_flat_state(state)) - - # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them - # all to one device and then distributing them, thus using low HBM memory. - if restored_checkpoint: - if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer - params = restored_checkpoint["wan_state"]["params"] - else: # if not checkpointed with optimizer - params = restored_checkpoint["wan_state"] - else: - params = load_wan_transformer( - config.wan_transformer_pretrained_model_name_or_path, - params, - "cpu", - num_layers=wan_config["num_layers"], - scan_layers=config.scan_layers, - subfolder=subfolder, - ) - - params = jax.tree_util.tree_map_with_path( - lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params - ) - for path, val in flax.traverse_util.flatten_dict(params).items(): - if restored_checkpoint: - path = path[:-1] - sharding = logical_state_sharding[path].value - state[path].value = device_put_replicated(val, sharding) - state = nnx.from_flat_state(state) - - wan_transformer = nnx.merge(graphdef, state, rest_of_state) - return wan_transformer - - -@nnx.jit(static_argnums=(1,), donate_argnums=(0,)) -def create_sharded_logical_model(model, logical_axis_rules): - graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) - p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=logical_axis_rules) - state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) - pspecs = nnx.get_partition_spec(state) - sharded_state = jax.lax.with_sharding_constraint(state, pspecs) - model = nnx.merge(graphdef, sharded_state, rest_of_state) - return model - - -class WanPipeline: - r""" - Pipeline for text-to-video generation using Wan. - - tokenizer ([`T5Tokenizer`]): - Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer), - specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - text_encoder ([`T5EncoderModel`]): - [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically - the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant. - transformer ([`WanModel`]): - Conditional Transformer to denoise the input latents. - scheduler ([`FlaxUniPCMultistepScheduler`]): - A scheduler to be used in combination with `transformer` to denoise the encoded image latents. - vae ([`AutoencoderKLWan`]): - Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations. - """ - - def __init__( - self, - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - low_noise_transformer: WanModel, - high_noise_transformer: WanModel, - vae: AutoencoderKLWan, - vae_cache: AutoencoderKLWanCache, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state: UniPCMultistepSchedulerState, - devices_array: np.array, - mesh: Mesh, - config: HyperParameters, - ): - self.tokenizer = tokenizer - self.text_encoder = text_encoder - self.low_noise_transformer = low_noise_transformer - self.high_noise_transformer = high_noise_transformer - self.vae = vae - self.vae_cache = vae_cache - self.scheduler = scheduler - self.scheduler_state = scheduler_state - self.devices_array = devices_array - self.mesh = mesh - self.config = config - - self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 - self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) - - self.p_run_inference = None - - @classmethod - def load_text_encoder(cls, config: HyperParameters): - text_encoder = UMT5EncoderModel.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="text_encoder", - ) - return text_encoder - - @classmethod - def load_tokenizer(cls, config: HyperParameters): - tokenizer = AutoTokenizer.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="tokenizer", - ) - return tokenizer - - @classmethod - def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - - def create_model(rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=jnp.float32, - weights_dtype=jnp.float32, - ) - return wan_vae - - # 1. eval shape - p_model_factory = partial(create_model, config=config) - wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) - graphdef, state = nnx.split(wan_vae, nnx.Param) - - # 2. retrieve the state shardings, mapping logical names to mesh axis names. - logical_state_spec = nnx.get_partition_spec(state) - logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) - logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) - params = state.to_pure_dict() - state = dict(nnx.to_flat_state(state)) - - # 4. Load pretrained weights and move them to device using the state shardings from (3) above. - # This helps with loading sharded weights directly into the accelerators without fist copying them - # all to one device and then distributing them, thus using low HBM memory. - params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") - params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - for path, val in flax.traverse_util.flatten_dict(params).items(): - sharding = logical_state_sharding[path].value - if config.replicate_vae: - sharding = NamedSharding(mesh, P()) - state[path].value = device_put_replicated(val, sharding) - state = nnx.from_flat_state(state) - - wan_vae = nnx.merge(graphdef, state) - vae_cache = AutoencoderKLWanCache(wan_vae) - return wan_vae, vae_cache - - @classmethod - def get_basic_config(cls, dtype, config: HyperParameters): - rules = [ - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=dtype, - act_qtype=dtype, - op_names=("dot_general", "einsum", "conv_general_dilated"), - ) - ] - return rules - - @classmethod - def get_fp8_config(cls, config: HyperParameters): - """ - fp8 config rules with per-tensor calibration. - FLAX API (https://flax-linen.readthedocs.io/en/v0.10.6/guides/quantization/fp8_basics.html#flax-low-level-api): - The autodiff does not automatically use E5M2 for gradients and E4M3 for activations/weights during training, which is the recommended practice. - """ - rules = [ - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=jnp.float8_e4m3fn, - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, - bwd_calibration_method=config.quantization_calibration_method, - op_names=("dot_general", "einsum"), - ), - qwix.QtRule( - module_path=config.qwix_module_path, - weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes - act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e4m3fn, - disable_channelwise_axes=True, # per_tensor calibration - weight_calibration_method=config.quantization_calibration_method, - act_calibration_method=config.quantization_calibration_method, - bwd_calibration_method=config.quantization_calibration_method, - op_names=("conv_general_dilated"), - ), - ] - return rules - - @classmethod - def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: - """Get quantization rules based on the config.""" - if not getattr(config, "use_qwix_quantization", False): - return None - - match config.quantization: - case "int8": - return qwix.QtProvider(cls.get_basic_config(jnp.int8, config)) - case "fp8": - return qwix.QtProvider(cls.get_basic_config(jnp.float8_e4m3fn, config)) - case "fp8_full": - return qwix.QtProvider(cls.get_fp8_config(config)) - return None - - @classmethod - def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): - """Quantizes the transformer model.""" - q_rules = cls.get_qt_provider(config) - if not q_rules: - return model - max_logging.log("Quantizing transformer with Qwix.") - - batch_size = jnp.ceil(config.per_device_batch_size * jax.local_device_count()).astype(jnp.int32) - latents, prompt_embeds, timesteps = get_dummy_wan_inputs(config, pipeline, batch_size) - model_inputs = (latents, timesteps, prompt_embeds) - with mesh: - quantized_model = qwix.quantize_model(model, q_rules, *model_inputs) - max_logging.log("Qwix Quantization complete.") - return quantized_model - - @classmethod - def load_transformer( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, restored_checkpoint=None, subfolder="transformer"): - with mesh: - wan_transformer = create_sharded_logical_transformer( - devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder=subfolder - ) - return wan_transformer - - @classmethod - def load_scheduler(cls, config): - scheduler, scheduler_state = FlaxUniPCMultistepScheduler.from_pretrained( - config.pretrained_model_name_or_path, - subfolder="scheduler", - flow_shift=config.flow_shift, # 5.0 for 720p, 3.0 for 480p - ) - return scheduler, scheduler_state - - @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer") - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, restored_checkpoint=restored_checkpoint, subfolder="transformer_2") - - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - return WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - @classmethod - def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True): - devices_array = max_utils.create_device_mesh(config) - mesh = Mesh(devices_array, config.mesh_axes) - rng = jax.random.key(config.seed) - rngs = nnx.Rngs(rng) - low_noise_transformer = None - high_noise_transformer = None - tokenizer = None - scheduler = None - scheduler_state = None - text_encoder = None - if not vae_only: - if load_transformer: - with mesh: - low_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer") - high_noise_transformer = cls.load_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config, subfolder="transformer_2") - text_encoder = cls.load_text_encoder(config=config) - tokenizer = cls.load_tokenizer(config=config) - - scheduler, scheduler_state = cls.load_scheduler(config=config) - - with mesh: - wan_vae, vae_cache = cls.load_vae(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) - - pipeline = WanPipeline( - tokenizer=tokenizer, - text_encoder=text_encoder, - low_noise_transformer=low_noise_transformer, - high_noise_transformer=high_noise_transformer, - vae=wan_vae, - vae_cache=vae_cache, - scheduler=scheduler, - scheduler_state=scheduler_state, - devices_array=devices_array, - mesh=mesh, - config=config, - ) - - pipeline.low_noise_transformer = cls.quantize_transformer(config, pipeline.low_noise_transformer, pipeline, mesh) - pipeline.high_noise_transformer = cls.quantize_transformer(config, pipeline.high_noise_transformer, pipeline, mesh) - return pipeline - - def _get_t5_prompt_embeds( - self, - prompt: Union[str, List[str]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - prompt = [prompt_clean(u) for u in prompt] - batch_size = len(prompt) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=max_sequence_length, - truncation=True, - add_special_tokens=True, - return_attention_mask=True, - return_tensors="pt", - ) - text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask - seq_lens = mask.gt(0).sum(dim=1).long() - prompt_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state - prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] - prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 - ) - - # duplicate text embeddings for each generation per prompt, using mps friendly method - _, seq_len, _ = prompt_embeds.shape - prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) - - return prompt_embeds - - def encode_prompt( - self, - prompt: Union[str, List[str]], - negative_prompt: Optional[Union[str, List[str]]] = None, - num_videos_per_prompt: int = 1, - max_sequence_length: int = 226, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - ): - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) - if prompt_embeds is None: - prompt_embeds = self._get_t5_prompt_embeds( - prompt=prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32) - - if negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - negative_prompt_embeds = self._get_t5_prompt_embeds( - prompt=negative_prompt, - num_videos_per_prompt=num_videos_per_prompt, - max_sequence_length=max_sequence_length, - ) - negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32) - - return prompt_embeds, negative_prompt_embeds - - def prepare_latents( - self, - batch_size: int, - vae_scale_factor_temporal: int, - vae_scale_factor_spatial: int, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_channels_latents: int = 16, - ): - rng = jax.random.key(self.config.seed) - num_latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1 - shape = ( - batch_size, - num_channels_latents, - num_latent_frames, - int(height) // vae_scale_factor_spatial, - int(width) // vae_scale_factor_spatial, - ) - latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) - - return latents - - def __call__( - self, - prompt: Union[str, List[str]] = None, - negative_prompt: Union[str, List[str]] = None, - height: int = 480, - width: int = 832, - num_frames: int = 81, - num_inference_steps: int = 50, - guidance_scale_low: float = 3.0, - guidance_scale_high: float = 4.0, - boundary: int = 875, - num_videos_per_prompt: Optional[int] = 1, - max_sequence_length: int = 512, - latents: jax.Array = None, - prompt_embeds: jax.Array = None, - negative_prompt_embeds: jax.Array = None, - vae_only: bool = False, - ): - if not vae_only: - if num_frames % self.vae_scale_factor_temporal != 1: - max_logging.log( - f"`num_frames -1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number." - ) - num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 - num_frames = max(num_frames, 1) - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - prompt = [prompt] - - batch_size = len(prompt) - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt=prompt, - negative_prompt=negative_prompt, - max_sequence_length=max_sequence_length, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - - num_channel_latents = self.low_noise_transformer.config.in_channels - if latents is None: - latents = self.prepare_latents( - batch_size=batch_size, - vae_scale_factor_temporal=self.vae_scale_factor_temporal, - vae_scale_factor_spatial=self.vae_scale_factor_spatial, - height=height, - width=width, - num_frames=num_frames, - num_channels_latents=num_channel_latents, - ) - - data_sharding = NamedSharding(self.mesh, P()) - # Using global_batch_size_to_train_on so not to create more config variables - if self.config.global_batch_size_to_train_on // self.config.per_device_batch_size == 0: - data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding)) - - latents = jax.device_put(latents, data_sharding) - prompt_embeds = jax.device_put(prompt_embeds, data_sharding) - negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) - - scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape - ) - - low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...) - high_noise_graphdef, high_noise_state, high_noise_rest = nnx.split(self.high_noise_transformer, nnx.Param, ...) - - p_run_inference = partial( - run_inference, - guidance_scale_low=guidance_scale_low, - guidance_scale_high=guidance_scale_high, - boundary=boundary, - num_inference_steps=num_inference_steps, - scheduler=self.scheduler, - scheduler_state=scheduler_state, - ) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - latents = p_run_inference( - low_noise_graphdef=low_noise_graphdef, - low_noise_state=low_noise_state, - low_noise_rest=low_noise_rest, - high_noise_graphdef=high_noise_graphdef, - high_noise_state=high_noise_state, - high_noise_rest=high_noise_rest, - latents=latents, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / latents_std + latents_mean - latents = latents.astype(jnp.float32) - - with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): - video = self.vae.decode(latents, self.vae_cache)[0] - - video = jnp.transpose(video, (0, 4, 1, 2, 3)) - video = jax.experimental.multihost_utils.process_allgather(video, tiled=True) - video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) - video = self.video_processor.postprocess_video(video, output_type="np") - return video - - -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) -def transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - prompt_embeds, - do_classifier_free_guidance, - guidance_scale, -): - wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - noise_pred = wan_transformer(hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds) - if do_classifier_free_guidance: - bsz = latents.shape[0] // 2 - noise_uncond = noise_pred[bsz:] - noise_pred = noise_pred[:bsz] - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) - latents = latents[:bsz] - - return noise_pred, latents - -def run_inference( - low_noise_graphdef, - low_noise_state, - low_noise_rest, - high_noise_graphdef, - high_noise_state, - high_noise_rest, - latents: jnp.array, - prompt_embeds: jnp.array, - negative_prompt_embeds: jnp.array, - guidance_scale_low: float, - guidance_scale_high: float, - boundary: int, - num_inference_steps: int, - scheduler: FlaxUniPCMultistepScheduler, - scheduler_state, -): - do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 - if do_classifier_free_guidance: - prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - - def low_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - low_noise_graphdef, low_noise_state, low_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_low - ) - - def high_noise_branch(operands): - latents, timestep, prompt_embeds = operands - return transformer_forward_pass( - high_noise_graphdef, high_noise_state, high_noise_rest, - latents, timestep, prompt_embeds, - do_classifier_free_guidance, guidance_scale_high - ) - - for step in range(num_inference_steps): - t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] - if do_classifier_free_guidance: - latents = jnp.concatenate([latents] * 2) - timestep = jnp.broadcast_to(t, latents.shape[0]) - use_high_noise = jnp.greater_equal(t, boundary) - - noise_pred, latents = jax.lax.cond( - use_high_noise, - high_noise_branch, - low_noise_branch, - (latents, timestep, prompt_embeds) - ) - - latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() - return latents diff --git a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py b/src/maxdiffusion/tests/wan_checkpointer2_2_test.py deleted file mode 100644 index 8e1fa0be..00000000 --- a/src/maxdiffusion/tests/wan_checkpointer2_2_test.py +++ /dev/null @@ -1,113 +0,0 @@ -""" - Copyright 2025 Google LLC - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - """ - -import unittest -from unittest.mock import patch, MagicMock - -from maxdiffusion.checkpointing.wan_checkpointer2_2 import WanCheckpointer, WAN_CHECKPOINT - - -class WanCheckpointerTest(unittest.TestCase): - - def setUp(self): - self.config = MagicMock() - self.config.checkpoint_dir = "/tmp/wan_checkpoint_test" - self.config.dataset_type = "test_dataset" - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = None - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) - - mock_manager.latest_step.assert_called_once() - mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertIsNone(step) - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - metadata_mock.low_noise_transformer_state = {} - metadata_mock.high_noise_transformer_state = {} - mock_manager.item_metadata.return_value = metadata_mock - - restored_mock = MagicMock() - restored_mock.low_noise_transformer_state = {"params": {}} - restored_mock.high_noise_transformer_state = {"params": {}} - restored_mock.wan_config = {} - restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - - mock_manager.restore.return_value = restored_mock - - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNone(opt_state) - self.assertEqual(step, 1) - - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer2_2.WanPipeline") - def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): - mock_manager = MagicMock() - mock_manager.latest_step.return_value = 1 - metadata_mock = MagicMock() - metadata_mock.low_noise_transformer_state = {} - metadata_mock.high_noise_transformer_state = {} - mock_manager.item_metadata.return_value = metadata_mock - - restored_mock = MagicMock() - restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} - restored_mock.high_noise_transformer_state = {"params": {}} - restored_mock.wan_config = {} - restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - - mock_manager.restore.return_value = restored_mock - - mock_create_manager.return_value = mock_manager - - mock_pipeline_instance = MagicMock() - mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) - pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) - - mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) - mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) - self.assertEqual(pipeline, mock_pipeline_instance) - self.assertIsNotNone(opt_state) - self.assertEqual(opt_state["learning_rate"], 0.001) - self.assertEqual(step, 1) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/maxdiffusion/tests/wan_checkpointer_test.py b/src/maxdiffusion/tests/wan_checkpointer_test.py index ab5b5ca3..79f050c0 100644 --- a/src/maxdiffusion/tests/wan_checkpointer_test.py +++ b/src/maxdiffusion/tests/wan_checkpointer_test.py @@ -14,10 +14,10 @@ import unittest from unittest.mock import patch, MagicMock -from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1, WanCheckpointer2_2 - -class WanCheckpointerTest(unittest.TestCase): +class WanCheckpointer2_1Test(unittest.TestCase): + """Tests for WAN 2.1 checkpointer.""" def setUp(self): self.config = MagicMock() @@ -25,7 +25,7 @@ def setUp(self): self.config.dataset_type = "test_dataset" @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = None @@ -34,7 +34,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) mock_manager.latest_step.assert_called_once() @@ -44,7 +44,7 @@ def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): self.assertIsNone(step) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -57,12 +57,6 @@ def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manag restored_mock.wan_config = {} restored_mock.keys.return_value = ["wan_state", "wan_config"] - def getitem_side_effect(key): - if key == "wan_state": - return restored_mock.wan_state - raise KeyError(key) - - restored_mock.__getitem__.side_effect = getitem_side_effect mock_manager.restore.return_value = restored_mock mock_create_manager.return_value = mock_manager @@ -70,7 +64,7 @@ def getitem_side_effect(key): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_1(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -80,7 +74,7 @@ def getitem_side_effect(key): self.assertEqual(step, 1) @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") - @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager): mock_manager = MagicMock() mock_manager.latest_step.return_value = 1 @@ -93,12 +87,102 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man restored_mock.wan_config = {} restored_mock.keys.return_value = ["wan_state", "wan_config"] - def getitem_side_effect(key): - if key == "wan_state": - return restored_mock.wan_state - raise KeyError(key) + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_1(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + self.assertEqual(step, 1) + + +class WanCheckpointer2_2Test(unittest.TestCase): + """Tests for WAN 2.2 checkpointer.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_2_2_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager): + """Test loading from pretrained when no checkpoint exists.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = None + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertIsNone(step) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint without optimizer state.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNone(opt_state) + self.assertEqual(step, 1) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_with_optimizer_in_low_noise(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with optimizer state in low_noise_transformer.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] - restored_mock.__getitem__.side_effect = getitem_side_effect mock_manager.restore.return_value = restored_mock mock_create_manager.return_value = mock_manager @@ -106,7 +190,7 @@ def getitem_side_effect(key): mock_pipeline_instance = MagicMock() mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance - checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT) + checkpointer = WanCheckpointer2_2(config=self.config) pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) @@ -116,6 +200,104 @@ def getitem_side_effect(key): self.assertEqual(opt_state["learning_rate"], 0.001) self.assertEqual(step, 1) + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_with_optimizer_in_high_noise(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with optimizer state in high_noise_transformer.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + restored_mock.keys.return_value = ["low_noise_transformer_state", "high_noise_transformer_state", "wan_config"] + + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY) + mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value) + self.assertEqual(pipeline, mock_pipeline_instance) + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.002) + self.assertEqual(step, 1) + + +class WanCheckpointerEdgeCasesTest(unittest.TestCase): + """Tests for edge cases and error handling.""" + + def setUp(self): + self.config = MagicMock() + self.config.checkpoint_dir = "/tmp/wan_checkpoint_edge_test" + self.config.dataset_type = "test_dataset" + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_1") + def test_load_checkpoint_with_explicit_none_step(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint with explicit None step falls back to latest.""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 5 + metadata_mock = MagicMock() + metadata_mock.wan_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.wan_state = {"params": {}} + restored_mock.wan_config = {} + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_1(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=None) + + mock_manager.latest_step.assert_called_once() + self.assertEqual(step, 5) + + @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager") + @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline2_2") + def test_load_checkpoint_both_optimizers_present(self, mock_wan_pipeline, mock_create_manager): + """Test loading checkpoint when both transformers have optimizer state (prioritize low_noise).""" + mock_manager = MagicMock() + mock_manager.latest_step.return_value = 1 + metadata_mock = MagicMock() + metadata_mock.low_noise_transformer_state = {} + metadata_mock.high_noise_transformer_state = {} + mock_manager.item_metadata.return_value = metadata_mock + + restored_mock = MagicMock() + restored_mock.low_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.001}} + restored_mock.high_noise_transformer_state = {"params": {}, "opt_state": {"learning_rate": 0.002}} + restored_mock.wan_config = {} + mock_manager.restore.return_value = restored_mock + + mock_create_manager.return_value = mock_manager + + mock_pipeline_instance = MagicMock() + mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance + + checkpointer = WanCheckpointer2_2(config=self.config) + pipeline, opt_state, step = checkpointer.load_checkpoint(step=1) + + # Should prioritize low_noise_transformer's optimizer state + self.assertIsNotNone(opt_state) + self.assertEqual(opt_state["learning_rate"], 0.001) + if __name__ == "__main__": - unittest.main() + unittest.main(verbosity=2) diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index fb01a4f4..deecfd43 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -29,7 +29,7 @@ from maxdiffusion.schedulers import FlaxFlowMatchScheduler from flax.linen import partitioning as nn_partitioning from maxdiffusion import max_utils, max_logging, train_utils -from maxdiffusion.checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT) +from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer2_1 from maxdiffusion.input_pipeline.input_pipeline_interface import (make_data_iterator) from maxdiffusion.generate_wan import run as generate_wan from maxdiffusion.generate_wan import inference_generate_video @@ -85,12 +85,16 @@ def print_ssim(pretrained_video_path, posttrained_video_path): max_logging.log(f"SSIM score after training is {ssim_compare}") -class WanTrainer(WanCheckpointer): +class WanTrainer: def __init__(self, config): - WanCheckpointer.__init__(self, config, WAN_CHECKPOINT) if config.train_text_encoder: raise ValueError("this script currently doesn't support training text_encoders") + self.config = config + model_key = config.model_name + if model_key != 'wan2.1': + raise ValueError(f"Unsupported model_name: '{model_key}'. This trainer only supports 'wan2.1'.") + self.checkpointer = WanCheckpointer2_1(config=config) def post_training_steps(self, pipeline, params, train_states, msg=""): pass @@ -210,7 +214,7 @@ def prepare_sample_eval(features): def start_training(self): - pipeline, opt_state, step = self.load_checkpoint() + pipeline, opt_state, step = self.checkpointer.load_checkpoint() restore_args = {} if opt_state and step: restore_args = {"opt_state": opt_state, "step": step} @@ -231,7 +235,7 @@ def start_training(self): scheduler, scheduler_state = self.create_scheduler() pipeline.scheduler = scheduler pipeline.scheduler_state = scheduler_state - optimizer, learning_rate_scheduler = self._create_optimizer(pipeline.transformer, self.config, 1e-5) + optimizer, learning_rate_scheduler = self.checkpointer._create_optimizer(pipeline.transformer, self.config, 1e-5) # Returns pipeline with trained transformer state pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args) @@ -392,9 +396,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data if step != 0 and self.config.checkpoint_every != -1 and step % self.config.checkpoint_every == 0: max_logging.log(f"Saving checkpoint for step {step}") if self.config.save_optimizer: - self.save_checkpoint(step, pipeline, state) + self.checkpointer.save_checkpoint(step, pipeline, state) else: - self.save_checkpoint(step, pipeline, state.params) + self.checkpointer.save_checkpoint(step, pipeline, state.params) _metrics_queue.put(None) writer_thread.join() @@ -402,8 +406,8 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data writer.flush() if self.config.save_final_checkpoint: max_logging.log(f"Saving final checkpoint for step {step}") - self.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params) - self.checkpoint_manager.wait_until_finished() + self.checkpointer.save_checkpoint(self.config.max_train_steps - 1, pipeline, state.params) + self.checkpointer.checkpoint_manager.wait_until_finished() # load new state for trained tranformer pipeline.transformer = nnx.merge(state.graphdef, state.params, state.rest_of_state) return pipeline