Changes from all commits (25 commits)
996388a
Changes for WAN 2.2
prishajain1 Oct 25, 2025
c094a73
changes return type of checkpoint_loader to tuple
prishajain1 Oct 25, 2025
33bf49c
opt_state=None added
prishajain1 Oct 25, 2025
8a752e7
added model_name in config file
prishajain1 Oct 25, 2025
1be0361
double noise computation fixed
prishajain1 Oct 27, 2025
731b07b
support for wan2.1 in run_inference added
prishajain1 Oct 27, 2025
11d30fc
Support for WAN 2.2 added
prishajain1 Nov 11, 2025
c2d0fe0
Merge branch 'main' into wanpipeline_classes
prishajain1 Nov 11, 2025
ce17ed0
Removed extra files
prishajain1 Nov 11, 2025
b7aad0a
Updated README and generate_wan.py
prishajain1 Nov 12, 2025
16d657a
Added tensorboard logging for inference metrics
prishajain1 Nov 18, 2025
cc78cac
Fixed duplicate pipeline loading
prishajain1 Nov 20, 2025
c46fd87
Merge conflicts
prishajain1 Nov 20, 2025
f5e6b11
Merge remote-tracking branch 'upstream/main' into wanpipeline_classes
prishajain1 Nov 20, 2025
e55ccd2
ruff errors
prishajain1 Nov 20, 2025
ebc7eec
Changes to Wan trainer for compatibility with checkpointer
prishajain1 Nov 20, 2025
d6cdb1e
flash block size changed for testing
prishajain1 Nov 20, 2025
b3edab6
Revert "flash block size changed for testing"
prishajain1 Nov 20, 2025
2c494d1
Raise error for unsupported model training
prishajain1 Nov 20, 2025
5e642f8
Explicitly instantiate WanPipeline and WanCheckpointer subclasses
prishajain1 Nov 22, 2025
531e64d
ruff errors
prishajain1 Nov 22, 2025
598b0bc
Added commit_id to tensorboard logging
prishajain1 Nov 24, 2025
81dab27
Commit hash logging
prishajain1 Nov 24, 2025
6d3b597
Added enable_jax_named_scopes param for wan 2.2
prishajain1 Nov 25, 2025
101335d
Merge remote-tracking branch 'upstream/main' into wanpipeline_classes
prishajain1 Nov 25, 2025
README.md (1 addition, 0 deletions)
@@ -17,6 +17,7 @@
[![Unit Tests](https://github.com/google/maxtext/actions/workflows/UnitTests.yml/badge.svg)](https://github.com/AI-Hypercomputer/maxdiffusion/actions/workflows/UnitTests.yml)

# What's new?
- **`2025/11/11`**: Wan2.2 txt2vid generation is now supported.
- **`2025/10/10`**: Wan2.1 txt2vid training and generation is now supported.
- **`2025/10/14`**: NVIDIA DGX Spark Flux support.
- **`2025/8/14`**: LTX-Video img2vid generation is now supported.
src/maxdiffusion/checkpointing/wan_checkpointer.py (127 additions, 16 deletions)
@@ -14,34 +14,37 @@
limitations under the License.
"""

from abc import ABC
from abc import ABC, abstractmethod
import json

import jax
import numpy as np
from typing import Optional, Tuple
from maxdiffusion.checkpointing.checkpointing_utils import (create_orbax_checkpoint_manager)
from ..pipelines.wan.wan_pipeline import WanPipeline
from ..pipelines.wan.wan_pipeline import WanPipeline2_1, WanPipeline2_2
from .. import max_logging, max_utils
import orbax.checkpoint as ocp
from etils import epath


WAN_CHECKPOINT = "WAN_CHECKPOINT"


class WanCheckpointer(ABC):

def __init__(self, config, checkpoint_type):
def __init__(self, config, checkpoint_type: str = WAN_CHECKPOINT):
self.config = config
self.checkpoint_type = checkpoint_type
self.opt_state = None

self.checkpoint_manager: ocp.CheckpointManager = create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
self.checkpoint_manager: ocp.CheckpointManager = (
create_orbax_checkpoint_manager(
self.config.checkpoint_dir,
enable_checkpointing=True,
save_interval_steps=1,
checkpoint_type=checkpoint_type,
dataset_type=config.dataset_type,
)
)
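For orientation, a minimal sketch of what a helper like `create_orbax_checkpoint_manager` plausibly wraps; the real helper lives in `checkpointing_utils` and takes extra arguments (`checkpoint_type`, `dataset_type`), so treat this as an assumption:

```python
# Hedged sketch: direct Orbax CheckpointManager construction, approximating
# the helper called above (save cadence only; item handling comes from the
# Composite args passed at save/restore time).
import orbax.checkpoint as ocp


def make_checkpoint_manager(checkpoint_dir: str, save_interval_steps: int = 1) -> ocp.CheckpointManager:
  options = ocp.CheckpointManagerOptions(save_interval_steps=save_interval_steps)
  return ocp.CheckpointManager(checkpoint_dir, options=options)
```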

def _create_optimizer(self, model, config, learning_rate):
@@ -51,6 +54,25 @@ def _create_optimizer(self, model, config, learning_rate):
tx = max_utils.create_optimizer(config, learning_rate_scheduler)
return tx, learning_rate_scheduler

@abstractmethod
def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
raise NotImplementedError

@abstractmethod
def load_diffusers_checkpoint(self):
raise NotImplementedError

@abstractmethod
def load_checkpoint(self, step=None) -> Tuple[Optional[WanPipeline2_1 | WanPipeline2_2], Optional[dict], Optional[int]]:
raise NotImplementedError

@abstractmethod
def save_checkpoint(self, train_step, pipeline, train_states: dict):
raise NotImplementedError
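With the four methods abstract, trainer code can hold a `WanCheckpointer` and never branch on the WAN version; a small caller-side sketch (`resume_or_init` is a hypothetical name):

```python
# Hypothetical caller: resume from the latest checkpoint when one exists,
# otherwise fall back to the default pipeline at step 0.
def resume_or_init(checkpointer: WanCheckpointer):
  pipeline, opt_state, step = checkpointer.load_checkpoint()
  return pipeline, opt_state, 0 if step is None else step
```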


class WanCheckpointer2_1(WanCheckpointer):

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
step = self.checkpoint_manager.latest_step()
@@ -85,24 +107,24 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
return restored_checkpoint, step

def load_diffusers_checkpoint(self):
pipeline = WanPipeline.from_pretrained(self.config)
pipeline = WanPipeline2_1.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipeline, Optional[dict], Optional[int]]:
def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_1, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipeline.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint["wan_state"].keys():
opt_state = restored_checkpoint["wan_state"]["opt_state"]
pipeline = WanPipeline2_1.from_checkpoint(self.config, restored_checkpoint)
if "opt_state" in restored_checkpoint.wan_state.keys():
opt_state = restored_checkpoint.wan_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step
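The switch from `restored_checkpoint["wan_state"]` to `restored_checkpoint.wan_state` matches how Orbax exposes `Composite` restore results as attributes; a hedged sketch of the opt-state probe on its own (names mirror the code above):

```python
# Hedged sketch: pull an optional optimizer state out of a Composite
# restore result; wan_state restores as a plain dict of arrays here.
from typing import Any, Optional


def extract_opt_state(restored_checkpoint: Any) -> Optional[dict]:
  return restored_checkpoint.wan_state.get("opt_state")
```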

def save_checkpoint(self, train_step, pipeline: WanPipeline, train_states: dict):
def save_checkpoint(self, train_step, pipeline: WanPipeline2_1, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
@@ -120,7 +142,96 @@ def config_to_json(model_or_config):
max_logging.log(f"Checkpoint for step {train_step} saved.")


def save_checkpoint_orig(self, train_step, pipeline: WanPipeline, train_states: dict):
class WanCheckpointer2_2(WanCheckpointer):

def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dict], Optional[int]]:
if step is None:
step = self.checkpoint_manager.latest_step()
max_logging.log(f"Latest WAN checkpoint step: {step}")
if step is None:
max_logging.log("No WAN checkpoint found.")
return None, None
max_logging.log(f"Loading WAN checkpoint from step {step}")
metadatas = self.checkpoint_manager.item_metadata(step)

# Handle low_noise_transformer
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata)
low_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_low_params,
)
)

# Handle high_noise_transformer
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata)
high_params_restore = ocp.args.PyTreeRestore(
restore_args=jax.tree.map(
lambda _: ocp.RestoreArgs(restore_type=np.ndarray),
abstract_tree_structure_high_params,
)
)

max_logging.log("Restoring WAN 2.2 checkpoint")
restored_checkpoint = self.checkpoint_manager.restore(
directory=epath.Path(self.config.checkpoint_dir),
step=step,
args=ocp.args.Composite(
low_noise_transformer_state=low_params_restore,
high_noise_transformer_state=high_params_restore,
wan_config=ocp.args.JsonRestore(),
),
)
max_logging.log(f"restored checkpoint {restored_checkpoint.keys()}")
max_logging.log(f"restored checkpoint low_noise_transformer_state {restored_checkpoint.low_noise_transformer_state.keys()}")
max_logging.log(f"restored checkpoint high_noise_transformer_state {restored_checkpoint.high_noise_transformer_state.keys()}")
max_logging.log(f"optimizer found in low_noise checkpoint {'opt_state' in restored_checkpoint.low_noise_transformer_state.keys()}")
max_logging.log(f"optimizer found in high_noise checkpoint {'opt_state' in restored_checkpoint.high_noise_transformer_state.keys()}")
max_logging.log(f"optimizer state saved in attribute self.opt_state {self.opt_state}")
return restored_checkpoint, step
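The metadata-to-`PyTreeRestore` recipe above repeats once per transformer; condensed into one hedged helper (the function name is illustrative, and it restores to host `np.ndarray`s exactly as the code above does):

```python
# Hedged sketch: restore one named PyTree item from an Orbax checkpoint
# as host numpy arrays, the recipe applied to both noise transformers.
import jax
import numpy as np
import orbax.checkpoint as ocp


def restore_item_as_numpy(manager: ocp.CheckpointManager, step: int, name: str):
  meta = getattr(manager.item_metadata(step), name)
  abstract = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, meta)
  restore_args = jax.tree.map(lambda _: ocp.RestoreArgs(restore_type=np.ndarray), abstract)
  restored = manager.restore(
      step,
      args=ocp.args.Composite(**{name: ocp.args.PyTreeRestore(restore_args=restore_args)}),
  )
  return getattr(restored, name)
```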

def load_diffusers_checkpoint(self):
pipeline = WanPipeline2_2.from_pretrained(self.config)
return pipeline

def load_checkpoint(self, step=None) -> Tuple[WanPipeline2_2, Optional[dict], Optional[int]]:
restored_checkpoint, step = self.load_wan_configs_from_orbax(step)
opt_state = None
if restored_checkpoint:
max_logging.log("Loading WAN pipeline from checkpoint")
pipeline = WanPipeline2_2.from_checkpoint(self.config, restored_checkpoint)
# Check for optimizer state in either transformer
if "opt_state" in restored_checkpoint.low_noise_transformer_state.keys():
opt_state = restored_checkpoint.low_noise_transformer_state["opt_state"]
elif "opt_state" in restored_checkpoint.high_noise_transformer_state.keys():
opt_state = restored_checkpoint.high_noise_transformer_state["opt_state"]
else:
max_logging.log("No checkpoint found, loading default pipeline.")
pipeline = self.load_diffusers_checkpoint()

return pipeline, opt_state, step

def save_checkpoint(self, train_step, pipeline: WanPipeline2_2, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
return json.loads(model_or_config.to_json_string())

max_logging.log(f"Saving checkpoint for step {train_step}")
items = {
"wan_config": ocp.args.JsonSave(config_to_json(pipeline.low_noise_transformer)),
}

items["low_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["low_noise_transformer"])
items["high_noise_transformer_state"] = ocp.args.PyTreeSave(train_states["high_noise_transformer"])

# Save the checkpoint
self.checkpoint_manager.save(train_step, args=ocp.args.Composite(**items))
max_logging.log(f"Checkpoint for step {train_step} saved.")

def save_checkpoint_orig(self, train_step, pipeline, train_states: dict):
"""Saves the training state and model configurations."""

def config_to_json(model_or_config):
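Per the "Explicitly instantiate WanPipeline and WanCheckpointer subclasses" and "Raise error for unsupported model training" commits, selection between the two checkpointers is explicit; a hedged sketch of such a factory (the `model_name` literals are assumptions):

```python
# Hedged sketch: explicit subclass selection keyed on config.model_name
# (added to the config file in this PR); the literal values are assumptions.
def create_wan_checkpointer(config) -> WanCheckpointer:
  if config.model_name == "wan2.1":
    return WanCheckpointer2_1(config)
  if config.model_name == "wan2.2":
    return WanCheckpointer2_2(config)
  raise ValueError(f"Unsupported model for training: {config.model_name}")
```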