diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 084c2c4ff..0fc73cc1a 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -347,6 +347,56 @@ def __init__( def should_log(self, step: int) -> bool: return step == 1 or step % self.job_config.metrics.log_freq == 0 + def val_log( + self, + step: int, + loss: float, + extra_metrics: dict[str, Any] | None = None, + ): + + time_delta = time.perf_counter() - self.time_last_log + + # tokens per second per device, abbreviated as tps + tps = self.ntokens_since_last_log / ( + time_delta * self.parallel_dims.non_data_parallel_size + ) + + time_end_to_end = time_delta / self.job_config.metrics.log_freq + + device_mem_stats = self.device_memory_monitor.get_peak_stats() + + metrics = { + "val_loss_metrics/avg_loss": loss, + "val_throughput(tps)": tps, + "val_time_metrics/end_to_end(s)": time_end_to_end, + "val_memory/max_active(GiB)": device_mem_stats.max_active_gib, + "val_memory/max_active(%)": device_mem_stats.max_active_pct, + "val_memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, + "val_memory/max_reserved(%)": device_mem_stats.max_reserved_pct, + "val_memory/num_alloc_retries": device_mem_stats.num_alloc_retries, + "val_memory/num_ooms": device_mem_stats.num_ooms, + } + + + if extra_metrics: + metrics.update(extra_metrics) + + self.logger.log(metrics, step) + + color = self.color + logger.info( + f"{color.magenta}val: {step:2} " + f"{color.green}val loss: {loss:7.4f} " + f"{color.yellow}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.blue}tps: {round(tps):,} {color.reset}" + ) + + self.ntokens_since_last_log = 0 + self.data_loading_times.clear() + self.time_last_log = time.perf_counter() + self.device_memory_monitor.reset_peak_stats() + def log( self, step: int, diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 21f43c8f1..e9cdecc98 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -21,6 +21,36 @@ from torchtitan.tools.utils import device_module, device_type +def dist_collect( + x: torch.Tensor, + mesh: DeviceMesh, + extra_pg: dist.ProcessGroup | None = None, +) -> torch.Tensor: + """Collect and sum tensors across devices. + + Unlike _dist_reduce, this function returns the full tensor after reduction + rather than extracting a scalar item. + + Args: + x (torch.Tensor): Input tensor. + mesh (DeviceMesh): Device mesh to use for reduction. + extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction. + Defaults to None. If provided, this all_reduce will be called for the extra + process group, and then the result will be all_reduced for the mesh. + + Returns: + torch.Tensor: The summed tensor collected from all devices. 
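+
+    Example (illustrative sketch; assumes an initialized process group and a
+    DeviceMesh ``mesh`` covering the ranks to reduce over)::
+
+        per_rank = torch.tensor([1.0, 0.0, 2.0], device="cuda")
+        total = dist_collect(per_rank, mesh)
+        # ``total`` now holds the element-wise sum across all ranks, e.g. the
+        # per-timestep loss sums and sample counts gathered during validation.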
+ """ + if isinstance(x, DTensor): + # functional collectives do not support DTensor inputs + x = x.full_tensor() + + if extra_pg is not None: + x = funcol.all_reduce(x, reduceOp="sum", group=extra_pg) + + return funcol.all_reduce(x, reduceOp="sum", group=mesh) + + def _dist_reduce( x: torch.Tensor, reduceOp: str, diff --git a/torchtitan/experiments/flux/__init__.py b/torchtitan/experiments/flux/__init__.py index f11353f0a..6e35c0983 100644 --- a/torchtitan/experiments/flux/__init__.py +++ b/torchtitan/experiments/flux/__init__.py @@ -9,7 +9,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers -from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_dataloader +from torchtitan.experiments.flux.dataset.flux_dataset import build_flux_train_dataloader, build_flux_val_dataloader from torchtitan.experiments.flux.loss import build_mse_loss from torchtitan.experiments.flux.model.autoencoder import AutoEncoderParams from torchtitan.experiments.flux.parallelize_flux import parallelize_flux @@ -104,8 +104,7 @@ } -register_train_spec( - TrainSpec( +train_spec = TrainSpec( name="flux", cls=FluxModel, config=flux_configs, @@ -113,8 +112,12 @@ pipelining_fn=None, build_optimizers_fn=build_optimizers, build_lr_schedulers_fn=build_lr_schedulers, - build_dataloader_fn=build_flux_dataloader, + build_dataloader_fn=build_flux_train_dataloader, build_tokenizer_fn=None, build_loss_fn=build_mse_loss, ) +# monkey patch a build_val_dataloader_fn to the train_spec +train_spec.build_val_dataloader_fn = build_flux_val_dataloader +register_train_spec( + train_spec ) diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index 1e19f300b..7c86d7773 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -6,6 +6,8 @@ import json import math +import os +import random from dataclasses import dataclass from typing import Any, Callable, Optional @@ -33,12 +35,13 @@ def _process_cc12m_image( img: PIL.Image.Image, output_size: int = 256, + skip_low_resolution: bool = True, ) -> Optional[torch.Tensor]: """Process CC12M image to the desired size.""" width, height = img.size # Skip low resolution images - if width < output_size or height < output_size: + if skip_low_resolution and (width < output_size or height < output_size): return None if width >= height: @@ -106,7 +109,8 @@ def _cc12m_wds_data_processor( result = { "image": img, "clip_tokens": clip_tokens, # type: List[int] - "t5_tokens": t5_tokens, # type: List[int] + "t5_tokens": t5_tokens, # type: List[int], + "txt": sample["txt"], } if include_sample_id: result["id"] = sample["__key__"] @@ -174,6 +178,13 @@ class TextToImageDatasetConfig: loader=lambda path: load_dataset(path, split="train", streaming=True), data_processor=_cc12m_wds_data_processor, ), + "cc12m-wds-30k": TextToImageDatasetConfig( + path="pixparse/cc12m-wds", + loader=lambda path: load_dataset(path, split="train", streaming=True).take( + 30_000 + ), + data_processor=_cc12m_wds_data_processor, + ), "cc12m-preprocessed": TextToImageDatasetConfig( path="outputs/preprocessed", loader=lambda path: load_dataset( @@ -285,35 +296,32 @@ def __init__( # Variables for checkpointing self._sample_idx = 0 - self._all_samples: list[dict[str, Any]] = [] + self._epoch = 0 + self._restored_checkpoint = False + + def reset(self): + self._sample_idx = 0 def _get_data_iter(self): if 
isinstance(self._data, Dataset) and self._sample_idx == len(self._data): return iter([]) - it = iter(self._data) - for _ in range(self._sample_idx): - next(it) - return it + return iter(self._data) def __iter__(self): - dataset_iterator = self._get_data_iter() + # Initialize the dataset iterator + iterator = self._get_data_iter() + + # Skip samples if we're resuming from a checkpoint + if self._restored_checkpoint: + logger.info(f"Restoring dataset state: skipping {self._sample_idx} samples") + for _ in range(self._sample_idx): + next(iterator) + self._restored_checkpoint = False + while True: try: - sample = next(dataset_iterator) - except StopIteration: - if not self.infinite: - logger.warning( - f"Dataset {self.dataset_name} has run out of data. \ - This might cause NCCL timeout if data parallelism is enabled." - ) - break - else: - # Reset offset for the next iteration if infinite - self._sample_idx = 0 - logger.info(f"Dataset {self.dataset_name} is being re-looped.") - dataset_iterator = self._get_data_iter() - continue + sample = next(iterator) except (UnicodeDecodeError, SyntaxError, OSError) as e: # Handle other exception, eg, dataset corruption logger.warning( @@ -321,6 +329,15 @@ def __iter__(self): Error {type(e).__name__}: {e}. The error could be the result of a streaming glitch." ) continue + except StopIteration: + # Handle the end of the iterator + self.reset() + if not self.infinite: + logger.warning(f"Dataset {self.dataset_name} has run out of data") + break + # Reset for next epoch if infinite + logger.warning(f"Dataset {self.dataset_name} is being re-looped") + iterator = self._get_data_iter() # Use the dataset-specific preprocessor sample_dict = self._data_processor( @@ -333,17 +350,16 @@ def __iter__(self): # skip low quality image or image with color channel = 1 if sample_dict["image"] is None: logger.warning( - f"Low quality image {sample['__key__']} is skipped in Flux Dataloader." + f"Low quality image {sample['__key__']} is skipped in Flux Dataloader" ) continue # Classifier-free guidance: Replace some of the strings with empty strings. # Distinct random seed is initialized at the beginning of training for each FSDP rank. 
dropout_prob = self.job_config.training.classifer_free_guidance_prob - if dropout_prob > 0.0: - if torch.rand(1).item() < dropout_prob: - sample_dict["t5_tokens"] = self._t5_empty_token - sample_dict["clip_tokens"] = self._clip_empty_token + if dropout_prob > 0.0 and random.random() < dropout_prob: + sample_dict["t5_tokens"] = self._t5_empty_token + sample_dict["clip_tokens"] = self._clip_empty_token self._sample_idx += 1 @@ -351,7 +367,8 @@ def __iter__(self): yield sample_dict, labels def load_state_dict(self, state_dict): - self._sample_idx = state_dict["sample_idx"] + self._sample_idx = state_dict.get("sample_idx", 0) + self._restored_checkpoint = True # Mark that we've loaded from a checkpoint def state_dict(self): return { @@ -359,7 +376,49 @@ def state_dict(self): } -def build_flux_dataloader( +def build_flux_train_dataloader( + dp_world_size: int, + dp_rank: int, + job_config: JobConfig, + tokenizer: FluxTokenizer | None, + infinite: bool = True, +) -> ParallelAwareDataloader: + return _build_flux_dataloader( + dataset_name=job_config.training.dataset, + dataset_path=job_config.training.dataset_path, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + job_config=job_config, + tokenizer=tokenizer, + infinite=infinite, + batch_size=job_config.training.batch_size, + ) + + +def build_flux_val_dataloader( + dp_world_size: int, + dp_rank: int, + job_config: JobConfig, + tokenizer: FluxTokenizer | None, + infinite: bool = False, +) -> ParallelAwareDataloader: + print(job_config.eval.dataset_path) + + return _build_flux_dataloader( + dataset_name=job_config.eval.dataset, + dataset_path=job_config.eval.dataset_path, + dp_world_size=dp_world_size, + dp_rank=dp_rank, + job_config=job_config, + tokenizer=tokenizer, + infinite=infinite, + batch_size=job_config.eval.batch_size, + ) + + +def _build_flux_dataloader( + dataset_name: str, + dataset_path: str, dp_world_size: int, dp_rank: int, job_config: JobConfig, @@ -367,12 +426,9 @@ def build_flux_dataloader( tokenizer: FluxTokenizer | None, infinite: bool = True, include_sample_id: bool = False, + batch_size: int = 4, ) -> ParallelAwareDataloader: """Build a data loader for HuggingFace datasets.""" - dataset_name = job_config.training.dataset - dataset_path = job_config.training.dataset_path - batch_size = job_config.training.batch_size - t5_encoder_name = job_config.encoder.t5_encoder clip_encoder_name = job_config.encoder.clip_encoder max_t5_encoding_len = job_config.encoder.max_t5_encoding_len diff --git a/torchtitan/experiments/flux/flux_argparser.py b/torchtitan/experiments/flux/flux_argparser.py index aeef69fbf..059500b5f 100644 --- a/torchtitan/experiments/flux/flux_argparser.py +++ b/torchtitan/experiments/flux/flux_argparser.py @@ -43,6 +43,24 @@ class Eval: """Frequency of evaluation/sampling during training""" save_img_folder: str = "img" """Directory to save image generated/sampled from the model""" + dataset: str | None = None + """Dataset to use for validation.""" + dataset_path: str | None = None + """ + Path to the dataset in the file system. 
+ """ + batch_size: int = 16 + """Batch size for validation.""" + +@dataclass +class Inference: + """Inference configuration""" + save_path: str = "inference_results" + """Path to save the inference results""" + prompts_path: str = "prompts.txt" + """Path to file with newline separated prompts to generate images for""" + batch_size: int = 16 + """Batch size for inference""" @dataclass @@ -54,3 +72,4 @@ class JobConfig: training: Training = field(default_factory=Training) encoder: Encoder = field(default_factory=Encoder) eval: Eval = field(default_factory=Eval) + inference: Inference = field(default_factory=Inference) diff --git a/torchtitan/experiments/flux/infer.py b/torchtitan/experiments/flux/infer.py new file mode 100644 index 000000000..e791d788a --- /dev/null +++ b/torchtitan/experiments/flux/infer.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +import os +from pathlib import Path + +import torch +from einops import rearrange +from PIL import ExifTags, Image +from torch.distributed.elastic.multiprocessing.errors import record + +from torchtitan.config_manager import ConfigManager +from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer +from torchtitan.experiments.flux.sampling import generate_image +from torchtitan.experiments.flux.train import FluxTrainer +from torchtitan.tools.logging import init_logger, logger + + +def torch_to_pil(x: torch.Tensor) -> Image.Image: + x = x.clamp(-1, 1) + x = rearrange(x, "c h w -> h w c") + return Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) + + +@record +def inference( + prompts: list[str], + trainer: FluxTrainer, + t5_tokenizer: FluxTokenizer, + clip_tokenizer: FluxTokenizer, + bs: int = 1, +): + """ + Run inference on the Flux model. 
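+
+    Prompts are encoded and denoised in chunks of ``bs``; the resulting images
+    are concatenated along the batch dimension and returned as a single tensor.
+
+    Example (sketch; assumes a constructed ``FluxTrainer`` and the tokenizers
+    built in ``__main__`` below)::
+
+        images = inference(["A photo of a cat"], trainer, t5_tokenizer,
+                           clip_tokenizer, bs=1)
+        torch_to_pil(images[0]).save("cat.png")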
+ """ + results = [] + with torch.no_grad(): + for i in range(0, len(prompts), bs): + images = generate_image( + device=trainer.device, + dtype=trainer._dtype, + job_config=trainer.job_config, + model=trainer.model_parts[0], + prompt=prompts[i : i + bs], + autoencoder=trainer.autoencoder, + t5_tokenizer=t5_tokenizer, + clip_tokenizer=clip_tokenizer, + t5_encoder=trainer.t5_encoder, + clip_encoder=trainer.clip_encoder, + ) + results.append(images.detach()) + results = torch.cat(results, dim=0) + return results + + +if __name__ == "__main__": + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + trainer = FluxTrainer(config) + world_size = int(os.environ["WORLD_SIZE"]) + global_id = int(os.environ["RANK"]) + original_prompts = open(config.inference.prompts_path).readlines() + total_prompts = len(original_prompts) + + # Each process processes its shard + prompts = original_prompts[global_id::world_size] + + trainer.checkpointer.load(step=config.checkpoint.load_step) + t5_tokenizer = FluxTokenizer( + config.encoder.t5_encoder, + max_length=config.encoder.max_t5_encoding_len, + ) + clip_tokenizer = FluxTokenizer(config.encoder.clip_encoder, max_length=77) + + if global_id == 0: + logger.info("Starting inference...") + + if prompts: + images = inference( + prompts, trainer, t5_tokenizer, clip_tokenizer, bs=config.inference.batch_size + ) + # pad the outputs to make sure all ranks have the same number of images for the gather step + images = torch.cat([images, torch.zeros(math.ceil(total_prompts / world_size) - images.shape[0], 3, 256, 256, device=trainer.device)]) + else: + # if there are not enough prompts for all ranks, pad with empty tensors + images = torch.zeros(math.ceil(total_prompts / world_size), 3, 256, 256, device=trainer.device) + + # Create a list of tensors to gather results + gathered_images = [ + torch.zeros_like(images, device=trainer.device) for _ in range(world_size) + ] + + # Gather images from all processes + torch.distributed.all_gather(gathered_images, images) + + # re-order the images to match the original ordering of prompts + if global_id == 0: + all_images = torch.zeros( + size=[total_prompts, 3, 256, 256], + dtype=torch.float32, + device=trainer.device, + ) + for in_rank_index in range(math.ceil(total_prompts / world_size)): + for rank_index in range(world_size): + global_idx = rank_index + in_rank_index * world_size + if global_idx >= total_prompts: + break + all_images[global_idx] = gathered_images[rank_index][in_rank_index] + logger.info("Inference done") + + # Computing FID activations + pil_images = [torch_to_pil(img) for img in all_images] + if config.inference.save_path: + path = Path(config.job.dump_folder, config.inference.save_path) + path.mkdir(parents=True, exist_ok=True) + images_to_save = pil_images[:2] + for i, img in enumerate(images_to_save): + exif_data = Image.Exif() + exif_data[ExifTags.Base.Software] = "AI generated;txt2img;flux" + exif_data[ExifTags.Base.Make] = "Black Forest Labs" + exif_data[ExifTags.Base.Model] = "Schnell" + img.save( + path / f"img_{i}.png", exif=exif_data, quality=95, subsampling=0 + ) + torch.distributed.destroy_process_group() diff --git a/torchtitan/experiments/flux/loss.py b/torchtitan/experiments/flux/loss.py index e3d2f000b..b1e38ac5a 100644 --- a/torchtitan/experiments/flux/loss.py +++ b/torchtitan/experiments/flux/loss.py @@ -14,9 +14,9 @@ LossFunction: TypeAlias = Callable[..., torch.Tensor] -def mse_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: +def 
mse_loss(pred: torch.Tensor, labels: torch.Tensor, reduction: str = "mean") -> torch.Tensor: """Common MSE loss function for Transformer models training.""" - return torch.nn.functional.mse_loss(pred.float(), labels.float().detach()) + return torch.nn.functional.mse_loss(pred.float(), labels.float().detach(), reduction=reduction) def build_mse_loss(job_config: JobConfig): diff --git a/torchtitan/experiments/flux/prompts.txt b/torchtitan/experiments/flux/prompts.txt new file mode 100644 index 000000000..8e7fe9cf4 --- /dev/null +++ b/torchtitan/experiments/flux/prompts.txt @@ -0,0 +1,32 @@ +A serene mountain landscape at sunset with a crystal clear lake reflecting the golden sky +A futuristic cityscape with flying cars and neon lights illuminating the night sky +A cozy cafe interior with steam rising from coffee cups and warm lighting +A magical forest with glowing mushrooms and fireflies dancing between ancient trees +A peaceful beach scene with turquoise waves and palm trees swaying in the breeze +A steampunk-inspired mechanical dragon soaring through clouds +A mystical library with floating books and magical artifacts +A Japanese garden in spring with cherry blossoms falling gently +A space station orbiting a colorful nebula +A medieval castle on a hilltop during a dramatic thunderstorm +A underwater scene with bioluminescent creatures and coral reefs +A desert oasis with a majestic palace and palm trees +A cyberpunk street market with holographic signs and diverse crowds +A cozy winter cabin surrounded by snow-covered pine trees +A fantasy tavern filled with unique characters and magical atmosphere +A tropical rainforest with exotic birds and waterfalls +A steampunk airship navigating through storm clouds +A peaceful zen garden with a traditional Japanese tea house +A magical potion shop with bubbling cauldrons and mysterious ingredients +A futuristic space colony on Mars with domed habitats +A mystical temple hidden in the clouds +A vintage train station with steam locomotives and period architecture +A magical bakery with floating pastries and enchanted ingredients +A peaceful countryside scene with rolling hills and a rustic farmhouse +A underwater city with advanced technology and marine life +A fantasy marketplace with magical creatures and exotic goods +A peaceful meditation garden with lotus flowers and koi ponds +A steampunk laboratory with intricate machinery and glowing elements +A magical treehouse village connected by rope bridges +A peaceful mountain monastery with prayer flags in the wind +A futuristic greenhouse with exotic plants and advanced technology +A mystical crystal cave with glowing formations and underground streams diff --git a/torchtitan/experiments/flux/run_inference.sh b/torchtitan/experiments/flux/run_inference.sh new file mode 100755 index 000000000..a9b0399a8 --- /dev/null +++ b/torchtitan/experiments/flux/run_inference.sh @@ -0,0 +1,28 @@ +#!/usr/bin/bash +# Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# All rights reserved. + +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +set -ex + +# use envs as local overrides for convenience +# e.g. 
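+# NGPU=2 ./torchtitan/experiments/flux/run_inference.sh --inference.prompts_path=./torchtitan/experiments/flux/prompts.txt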
+# LOG_RANK=0,1 NGPU=4 ./torchtitan/experiments/flux/run_inference.sh +NGPU=${NGPU:-"8"} +export LOG_RANK=${LOG_RANK:-0} +OUTPUT_DIR=${OUTPUT_DIR:-"./torchtitan/experiments/flux/inference_results"} +CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/experiments/flux/train_configs/debug_model.toml"} + +overrides="" +if [ $# -ne 0 ]; then + overrides="$*" +fi + + +PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True" \ +torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ +--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ +-m torchtitan.experiments.flux.infer --job.config_file ${CONFIG_FILE} --inference.save_path ${OUTPUT_DIR} $overrides diff --git a/torchtitan/experiments/flux/sampling.py b/torchtitan/experiments/flux/sampling.py index ab364eee9..5cdc4bcdf 100644 --- a/torchtitan/experiments/flux/sampling.py +++ b/torchtitan/experiments/flux/sampling.py @@ -70,6 +70,137 @@ def get_schedule( # ---------------------------------------- +def generate_and_save_images( + inputs, + clip_tokenizer, + t5_tokenizer, + clip_encoder, + t5_encoder, + model, + autoencoder, + img_size, + step, + dtype=torch.bfloat16, + device="cuda", + denoising_steps=50, + enable_classifer_free_guidance=False, + classifier_free_guidance_scale=None, + save_img_folder="img", +) -> torch.Tensor: + with torch.no_grad(): + if enable_classifer_free_guidance: + empty_batch = generate_empty_batch( + num_images=len(inputs["txt"]), + device=device, + dtype=dtype, + clip_tokenizer=clip_tokenizer, + t5_tokenizer=t5_tokenizer, + clip_encoder=clip_encoder, + t5_encoder=t5_encoder, + ) + else: + empty_batch = {"t5_encodings": None, "clip_encodings": None} + + img_height = 16 * (img_size // 16) + img_width = 16 * (img_size // 16) + images = generate_image_from_latent( + device=device, + dtype=dtype, + model=model, + autoencoder=autoencoder, + img_width=img_width, + img_height=img_height, + denoising_steps=denoising_steps, + clip_encodings=inputs["clip_encodings"], + t5_encodings=inputs["t5_encodings"], + enable_classifer_free_guidance=enable_classifer_free_guidance, + empty_t5_encodings=empty_batch["t5_encodings"], + empty_clip_encodings=empty_batch["clip_encodings"], + classifier_free_guidance_scale=classifier_free_guidance_scale, + ) + + for i, image in enumerate(images): + name = f"image_rank_{str(torch.distributed.get_rank())}_step{step}_{i}.png" + save_image( + name=name, + output_dir=save_img_folder, + x=image, + prompt=inputs["txt"][i], + ) + return images + + +def generate_empty_batch( + num_images: int, + device: torch.device, + dtype: torch.dtype, + clip_tokenizer: Tokenizer, + t5_tokenizer: Tokenizer, + clip_encoder: FluxEmbedder, + t5_encoder: FluxEmbedder, +): + empty_clip_tokens = clip_tokenizer.encode("") + empty_t5_tokens = t5_tokenizer.encode("") + empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1) + empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1) + return preprocess_data( + device=device, + dtype=dtype, + autoencoder=None, + clip_encoder=clip_encoder, + t5_encoder=t5_encoder, + batch={ + "clip_tokens": empty_clip_tokens, + "t5_tokens": empty_t5_tokens, + }, + ) + + +def generate_image_from_latent( + device: torch.device, + dtype: torch.dtype, + model: FluxModel, + autoencoder: AutoEncoder, + img_width: int, + img_height: int, + denoising_steps: int, + clip_encodings: torch.Tensor, + t5_encodings: torch.Tensor, + enable_classifer_free_guidance: bool = False, + empty_t5_encodings: torch.Tensor | None = None, + empty_clip_encodings: torch.Tensor | None = None, + 
classifier_free_guidance_scale: float | None = None, +) -> torch.Tensor: + if enable_classifer_free_guidance and ( + empty_t5_encodings is None or empty_clip_encodings is None + ): + raise ValueError( + "empty_t5_encodings and empty_clip_encodings must be provided if enable_classifer_free_guidance is True" + ) + + img = denoise( + device=device, + dtype=dtype, + model=model, + img_width=img_width, + img_height=img_height, + denoising_steps=denoising_steps, + clip_encodings=clip_encodings, + t5_encodings=t5_encodings, + enable_classifer_free_guidance=enable_classifer_free_guidance, + empty_t5_encodings=( + empty_t5_encodings if enable_classifer_free_guidance else None + ), + empty_clip_encodings=( + empty_clip_encodings if enable_classifer_free_guidance else None + ), + classifier_free_guidance_scale=classifier_free_guidance_scale, + ) + + img = autoencoder.decode(img.to(dtype)) + return img + + def generate_image( device: torch.device, dtype: torch.dtype, @@ -95,12 +226,15 @@ def generate_image( enable_classifer_free_guidance = job_config.eval.enable_classifer_free_guidance # Tokenize the prompt. Unsqueeze to add a batch dimension. - clip_tokens = clip_tokenizer.encode(prompt).unsqueeze(0) - t5_tokens = t5_tokenizer.encode(prompt).unsqueeze(0) + clip_tokens = clip_tokenizer.encode(prompt) + t5_tokens = t5_tokenizer.encode(prompt) + if len(prompt) == 1: + clip_tokens = clip_tokens.unsqueeze(0) + t5_tokens = t5_tokens.unsqueeze(0) batch = preprocess_data( device=device, - dtype=torch.bfloat16, + dtype=dtype, autoencoder=None, clip_encoder=clip_encoder, t5_encoder=t5_encoder, @@ -111,24 +245,21 @@ def generate_image( ) if enable_classifer_free_guidance: - empty_clip_tokens = clip_tokenizer.encode("").unsqueeze(0) - empty_t5_tokens = t5_tokenizer.encode("").unsqueeze(0) - empty_batch = preprocess_data( + empty_batch = generate_empty_batch( + num_images=len(prompt), device=device, - dtype=torch.bfloat16, - autoencoder=None, + dtype=dtype, + clip_tokenizer=clip_tokenizer, + t5_tokenizer=t5_tokenizer, clip_encoder=clip_encoder, t5_encoder=t5_encoder, - batch={ - "clip_tokens": empty_clip_tokens, - "t5_tokens": empty_t5_tokens, - }, ) - img = denoise( + return generate_image_from_latent( device=device, dtype=dtype, model=model, + autoencoder=autoencoder, img_width=img_width, img_height=img_height, denoising_steps=job_config.eval.denoising_steps, @@ -144,9 +275,6 @@ def generate_image( classifier_free_guidance_scale=job_config.eval.classifier_free_guidance_scale, ) - img = autoencoder.decode(img) - return img - def denoise( device: torch.device, @@ -176,9 +304,9 @@ def denoise( # create positional encodings POSITION_DIM = 3 latent_pos_enc = create_position_encoding_for_latents( - bsz, latent_height, latent_width, POSITION_DIM + 1, latent_height, latent_width, POSITION_DIM ).to(latents) - text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM).to(latents) + text_pos_enc = torch.zeros(1, t5_encodings.shape[1], POSITION_DIM).to(latents) if enable_classifer_free_guidance: latents = torch.cat([latents, latents], dim=0) @@ -190,7 +318,7 @@ def denoise( # this is ignored for schnell for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]): - t_vec = torch.full((bsz,), t_curr, dtype=dtype, device=device) + t_vec = torch.full((1,), t_curr, dtype=dtype, device=device) pred = model( img=latents, img_ids=latent_pos_enc, @@ -202,9 +330,12 @@ def denoise( if enable_classifer_free_guidance: pred_u, pred_c = pred.chunk(2) pred = pred_u + classifier_free_guidance_scale * (pred_c - pred_u) - + pred = 
pred.repeat(2, 1, 1) latents = latents + (t_prev - t_curr) * pred + if enable_classifer_free_guidance: + latents = latents.chunk(2)[1] + # convert sequences of patches into img-like latents latents = unpack_latents(latents, latent_height, latent_width) @@ -215,16 +346,17 @@ def save_image( name: str, output_dir: str, x: torch.Tensor, - add_sampling_metadata: bool, - prompt: str, + add_sampling_metadata: bool = False, + prompt: str | None = None, ): print(f"Saving {output_dir}/{name}") - if not os.path.exists(output_dir): - os.makedirs(output_dir) + os.makedirs(output_dir, exist_ok=True) output_name = os.path.join(output_dir, name) # bring into PIL format and save x = x.clamp(-1, 1) - x = rearrange(x[0], "c h w -> h w c") + if len(x.shape) == 4: + x = x[0] + x = rearrange(x, "c h w -> h w c") img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 0bbc3a043..7c98658df 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -1,21 +1,26 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import math import os +from datetime import time, timedelta from typing import Optional import torch +from torch.distributed.elastic.multiprocessing.errors import record +import torchtitan.components.ft as ft from torchtitan.config_manager import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import utils as dist_utils from torchtitan.experiments.flux.dataset.tokenizer import FluxTokenizer from torchtitan.experiments.flux.model.autoencoder import load_ae from torchtitan.experiments.flux.model.hf_embedder import FluxEmbedder from torchtitan.experiments.flux.parallelize_flux import parallelize_encoders -from torchtitan.experiments.flux.sampling import generate_image, save_image +from torchtitan.experiments.flux.sampling import generate_and_save_images from torchtitan.experiments.flux.utils import ( create_position_encoding_for_latents, pack_latents, @@ -23,6 +28,10 @@ unpack_latents, ) from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.profiling import ( + maybe_enable_memory_snapshot, + maybe_enable_profiling, +) from torchtitan.train import Trainer @@ -54,9 +63,21 @@ def __init__(self, job_config: JobConfig): model_config = self.train_spec.config[job_config.model.flavor] + # load components for pre-processing is the dataset is not preprocessed self.is_dataset_preprocessed = "preprocess" in job_config.training.dataset - # load components for pre-processing is the dataset is not preprocessed + self.val_dataloader = ( + self.train_spec.build_val_dataloader_fn( + dp_world_size=self.dataloader.dp_world_size, + dp_rank=self.dataloader.dp_rank, + tokenizer=None, + job_config=job_config, + infinite=False, + ) + if job_config.eval.dataset + else None + ) + self.autoencoder = load_ae( job_config.encoder.autoencoder_path, model_config.autoencoder_params, @@ -187,48 +208,276 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): model.eval() # We need to set reshard_after_forward before last forward pass. # So the model wieghts are sharded the same way for checkpoint saving. 
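+        # Resharding is only relevant when the model is FSDP-sharded, so the
+        # calls below are guarded on dp_shard being enabled.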
- model.final_layer.set_reshard_after_forward(True) + if self.parallel_dims.dp_shard_enabled: + model.final_layer.set_reshard_after_forward(True) self.eval_step() - model.final_layer.set_reshard_after_forward(False) + if self.parallel_dims.dp_shard_enabled: + model.final_layer.set_reshard_after_forward(False) model.train() - def eval_step(self, prompt: str = "A photo of a cat"): - """ - Evaluate the Flux model. - 1) generate and save images every few steps. Currently, we run the eval and on the same - prompts across all DP ranks. We will change this behavior to run on validation set prompts. - Due to random noise generation, results could be different across DP ranks cause we assign - different random seeds to each DP rank. - 2) [TODO] Calculate loss with fixed t value on validation set. + def eval_step( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + timesteps: torch.Tensor, + save_imgs: bool = False, + ): # prompt: str = "A photo of a cat"): """ + Calculate the validation loss for the Flux model. - image = generate_image( + This follows the original paper's evaluation protocol. For each sample, calculate the loss at 7 equally spaced + values for t in [0, 1] (excluding 1) and average it. This will make each batch size 7x larger, which may require + a different batch size. + + Returns: Average loss per timestep across all samples in the batch. + """ + input_dict["image"] = labels + input_dict = self.preprocess_fn( device=self.device, dtype=self._dtype, - job_config=self.job_config, - model=self.model_parts[0], - prompt=prompt, # TODO(jianiw): change this to a prompt from validation set autoencoder=self.autoencoder, - t5_tokenizer=FluxTokenizer( - self.job_config.encoder.t5_encoder, - max_length=self.job_config.encoder.max_t5_encoding_len, - ), - clip_tokenizer=FluxTokenizer( - self.job_config.encoder.clip_encoder, max_length=77 - ), - t5_encoder=self.t5_encoder, clip_encoder=self.clip_encoder, + t5_encoder=self.t5_encoder, + batch=input_dict, ) + labels = input_dict["img_encodings"] + # Keep these variables local to shorten the code as these are + # the major variables that are used in the training loop. 
+ model_parts = self.model_parts + assert len(self.model_parts) == 1 + model = model_parts[0] + + world_mesh = self.world_mesh + parallel_dims = self.parallel_dims + + # image in latent space transformed by self.auto_encoder + clip_encodings = input_dict["clip_encodings"] + t5_encodings = input_dict["t5_encodings"] + + bsz = labels.shape[0] + + with torch.no_grad(): + noise = torch.randn_like(labels) + timestep_values = (timesteps / 8.0).to(labels) + sigmas = timestep_values.view(-1, 1, 1, 1) + latents = (1 - sigmas) * labels + sigmas * noise + + bsz, _, latent_height, latent_width = latents.shape + + POSITION_DIM = 3 # constant for Flux flow model + # Create positional encodings + latent_pos_enc = create_position_encoding_for_latents( + bsz, latent_height, latent_width, POSITION_DIM + ) + text_pos_enc = torch.zeros(bsz, t5_encodings.shape[1], POSITION_DIM) - save_image( - name=f"image_rank{str(torch.distributed.get_rank())}_{self.step}.png", - output_dir=os.path.join( - self.job_config.job.dump_folder, self.job_config.eval.save_img_folder + # Patchify: Convert latent into a sequence of patches + latents = pack_latents(latents) + + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc.to(latents), + txt=t5_encodings.to(latents), + txt_ids=text_pos_enc.to(latents), + y=clip_encodings.to(latents), + timesteps=timestep_values.to(latents), + ) + + # Convert sequence of patches to latent shape + pred = unpack_latents(latent_noise_pred, latent_height, latent_width) + target = noise - labels + loss = self.loss_fn(pred, target, reduction="none") + + # average the loss across timesteps + # might be useful to report this in the future, but currently not mechanism in torchtitan + # for distributed averaging with numel > 1 + # loss_per_timestep = loss.view(7, -1).mean(dim=1) + + # Initialize a tensor to accumulate losses for each timestep (0-7) + loss_per_timestep = torch.zeros(8, device=loss.device) + # Reshape loss to have one value per sample + loss_per_sample = loss.mean(dim=(1, 2, 3)) + + # Get integer timestep values from the timestep_values + timestep_indices = timesteps.long() + + # Use scatter_add_ for vectorized accumulation of losses by timestep + loss_per_timestep.scatter_add_(0, timestep_indices, loss_per_sample) + + # Count samples per timestep for averaging (using bincount) + timestep_counts = torch.bincount(timestep_indices, minlength=8) + + # Avoid division by zero + timestep_counts = torch.maximum( + timestep_counts, torch.ones_like(timestep_counts) + ) + + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): + # Collect loss sums and counts from all devices + ft_pg = ( + self.ft_manager.replicate_pg if self.ft_manager.enabled else None + ) + # Use the new dist_collect function to gather tensors across devices + global_loss_per_timestep = dist_utils.dist_collect( + loss_per_timestep, world_mesh["dp_cp"], ft_pg + ) + global_timestep_counts = dist_utils.dist_collect( + timestep_counts, world_mesh["dp_cp"], ft_pg + ) + + else: + # For single device, just calculate locally + global_loss_per_timestep = loss_per_timestep + global_timestep_counts = timestep_counts + + if save_imgs: + t5_tokenizer = FluxTokenizer( + self.job_config.encoder.t5_encoder, + max_length=self.job_config.encoder.max_t5_encoding_len, + ) + clip_tokenizer = FluxTokenizer( + self.job_config.encoder.clip_encoder, + max_length=77, + ) + generate_and_save_images( + input_dict, + clip_tokenizer, + t5_tokenizer, + self.clip_encoder, + self.t5_encoder, 
+ self.model_parts[0], + self.autoencoder, + self.job_config.training.img_size, + self.step, + ) + # In the future, we could return avg_loss_per_timestep for more detailed reporting + return global_loss_per_timestep, global_timestep_counts + + @record + def train(self): + job_config = self.job_config + + self.checkpointer.load(step=job_config.checkpoint.load_step) + logger.info(f"Training starts at step {self.step + 1}.") + + with ( + maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, + maybe_enable_memory_snapshot( + job_config, global_step=self.step + ) as memory_profiler, + ft.maybe_semi_sync_training( + job_config, + ft_manager=self.ft_manager, + model=self.model_parts[0], + optimizer=self.optimizers, + sync_every=job_config.fault_tolerance.sync_steps, ), - x=image, - add_sampling_metadata=True, - prompt=prompt, + ): + data_iterator = self.batch_generator(self.dataloader) + for inputs, labels in data_iterator: + if self.step >= job_config.training.steps: + break + self.step += 1 + self.gc_handler.run(self.step) + self.train_step(inputs, labels) + self.checkpointer.save( + self.step, force=(self.step == job_config.training.steps) + ) + + if ( + self.step % job_config.eval.eval_freq == 0 + and job_config.eval.dataset + ): + self.eval() + + # signal the profiler that the next profiling step has started + if torch_profiler: + torch_profiler.step() + if memory_profiler: + memory_profiler.step() + + # reduce timeout after first train step for faster signal + # (assuming lazy init and compilation are finished) + if self.step == 1: + dist_utils.set_pg_timeouts( + timeout=timedelta( + seconds=job_config.comm.train_timeout_seconds + ), + world_mesh=self.world_mesh, + ) + + if torch.distributed.get_rank() == 0: + logger.info("Sleeping 2 seconds for other ranks to complete") + time.sleep(2) + + self.metrics_processor.close() + logger.info("Training completed") + + def eval(self): + def generate_val_timesteps(cur_val_timestep, samples): + """ + Generate timesteps for validation set + + This is a helper function to generate timesteps 0 through 7, repeating as necessary. 
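+
+            For example, with cur_val_timestep=6 and samples=5 this returns
+            timesteps [6, 7, 0, 1, 2] and a new cur_val_timestep of 3, so the
+            next batch continues the 0-7 cycle where this one left off.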
+ """ + first_offset = torch.arange(cur_val_timestep, 8, device=self.device)[ + :samples + ] + samples_left = samples - first_offset.numel() + val_timesteps = torch.arange( + 0, 8, dtype=torch.int8, device=self.device + ).repeat_interleave(math.ceil(samples_left / 8))[:samples_left] + val_timesteps = torch.cat([first_offset, val_timesteps]) + cur_val_timestep = (val_timesteps[-1].item() + 1) % 8 + return val_timesteps, cur_val_timestep + + logger.info("Starting validation...") + t5_tokenizer = FluxTokenizer( + config.encoder.t5_encoder, + max_length=config.encoder.max_t5_encoding_len, ) + clip_tokenizer = FluxTokenizer( + self.job_config.encoder.clip_encoder, max_length=77 + ) + # Follow procedure set out in Flux paper of stratified timestep sampling + val_data_iterator = iter(self.val_dataloader) + cur_val_timestep = 0 + eval_step = 0 + eval_samples = 0 + sum_loss_per_timestep = torch.zeros(8, device=self.device) + sum_timestep_counts = torch.zeros(8, device=self.device) + # Iterate through all validation batches + # TODO: not sure how to handle profiling with validation + while True: + try: + val_inputs, val_labels = self.next_batch(val_data_iterator) + except StopIteration: + break + eval_step += 1 + samples = len(val_labels) + val_timesteps, cur_val_timestep = generate_val_timesteps( + cur_val_timestep, samples + ) + loss, counts = self.eval_step( + val_inputs, + val_labels, + val_timesteps, + save_imgs=eval_step == 1 and self.job_config.eval.save_img_folder, + ) + eval_samples += samples + sum_loss_per_timestep += loss + sum_timestep_counts += counts + + # Different batches and timestepsmay have different number of samples, so we need to average the loss like this + # rather than taking the mean of the mean batch losses. + timestep_counts_proportions = sum_timestep_counts / sum_timestep_counts.sum() + avg_loss_per_timestep = sum_loss_per_timestep / sum_timestep_counts + avg_loss = (avg_loss_per_timestep * timestep_counts_proportions).sum() + self.metrics_processor.val_log(self.step, avg_loss) if __name__ == "__main__": diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index 6cc6d347f..b20d74b5a 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -136,7 +136,7 @@ class NoColor: def check_if_feature_in_pytorch( feature_name: str, - pull_request: str, + pull_request_link: str, min_nightly_version: Optional[str] = None, ) -> None: if "git" in torch.__version__: # pytorch is built from source diff --git a/torchtitan/train.py b/torchtitan/train.py index 8966e689f..6a09ee389 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -302,11 +302,12 @@ def next_batch( ) device_type = utils.device_type - for k, _ in input_dict.items(): - if k == "id": - continue - input_dict[k] = input_dict[k].to(device_type) + # Move tensors to the appropriate device + for k, v in input_dict.items(): + if isinstance(v, torch.Tensor): + input_dict[k] = v.to(device_type) labels = labels.to(device_type) + return input_dict, labels def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): @@ -403,16 +404,18 @@ def train(self): self.checkpointer.load(step=job_config.checkpoint.load_step) logger.info(f"Training starts at step {self.step + 1}.") - with maybe_enable_profiling( - job_config, global_step=self.step - ) as torch_profiler, maybe_enable_memory_snapshot( - job_config, global_step=self.step - ) as memory_profiler, ft.maybe_semi_sync_training( + with ( + maybe_enable_profiling(job_config, global_step=self.step) as torch_profiler, + 
maybe_enable_memory_snapshot( + job_config, global_step=self.step + ) as memory_profiler, + ft.maybe_semi_sync_training( job_config, ft_manager=self.ft_manager, model=self.model_parts[0], optimizer=self.optimizers, sync_every=job_config.fault_tolerance.sync_steps, + ) ): data_iterator = iter(self.dataloader) while self.step < job_config.training.steps: @@ -468,12 +471,12 @@ def close(self) -> None: trainer = Trainer(config) if config.checkpoint.create_seed_checkpoint: - assert ( - int(os.environ["WORLD_SIZE"]) == 1 - ), "Must create seed checkpoint using a single device, to disable sharding." - assert ( - config.checkpoint.enable_checkpoint - ), "Must enable checkpointing when creating a seed checkpoint." + assert int(os.environ["WORLD_SIZE"]) == 1, ( + "Must create seed checkpoint using a single device, to disable sharding." + ) + assert config.checkpoint.enable_checkpoint, ( + "Must enable checkpointing when creating a seed checkpoint." + ) trainer.checkpointer.save(curr_step=0, force=True) logger.info("Created seed checkpoint") else: